]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
detect {opensql} and {stop} sections
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2022 16:18:23 +0000 (12:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2022 18:16:34 +0000 (14:16 -0400)
this so that I can still have
{opensql} and {stop} sections in non-console python.

this isn't the norm but I would prefer if I dont have to
be 100% strict about it

also maintaining {sql} / {stop} being at the start
of a code line.  this is more prevalent in 1.4.

Change-Id: Iaf748b7ff1120e21f729c2fd794d9b8a33d83170

doc/build/core/constraints.rst
doc/build/core/ddl.rst
doc/build/core/metadata.rst
doc/build/orm/mapped_attributes.rst
doc/build/orm/queryguide/inheritance.rst
doc/build/orm/queryguide/relationships.rst
doc/build/orm/self_referential.rst
tools/format_docs_code.py

index f5a4b5f134d3770ff8454e1d9344cfe361fe5c2b..f2ee5a0a65ede344ea949c99c62cfed3d4cb7099 100644 (file)
@@ -372,17 +372,16 @@ MySQL.
     from sqlalchemy import CheckConstraint
 
     metadata_obj = MetaData()
-    mytable = Table('mytable', metadata_obj,
-
+    mytable = Table(
+        "mytable",
+        metadata_obj,
         # per-column CHECK constraint
-        Column('col1', Integer, CheckConstraint('col1>5')),
-
-        Column('col2', Integer),
-        Column('col3', Integer),
-
+        Column("col1", Integer, CheckConstraint("col1>5")),
+        Column("col2", Integer),
+        Column("col3", Integer),
         # table level CHECK constraint.  'name' is optional.
-        CheckConstraint('col2 > col3 + 5', name='check1')
-        )
+        CheckConstraint("col2 > col3 + 5", name="check1"),
+    )
 
     {sql}mytable.create(engine)
     CREATE TABLE mytable (
@@ -852,25 +851,24 @@ INDEX" is issued right after the create statements for the table:
 .. sourcecode:: python+sql
 
     metadata_obj = MetaData()
-    mytable = Table('mytable', metadata_obj,
+    mytable = Table(
+        "mytable",
+        metadata_obj,
         # an indexed column, with index "ix_mytable_col1"
-        Column('col1', Integer, index=True),
-
+        Column("col1", Integer, index=True),
         # a uniquely indexed column with index "ix_mytable_col2"
-        Column('col2', Integer, index=True, unique=True),
-
-        Column('col3', Integer),
-        Column('col4', Integer),
-
-        Column('col5', Integer),
-        Column('col6', Integer),
-        )
+        Column("col2", Integer, index=True, unique=True),
+        Column("col3", Integer),
+        Column("col4", Integer),
+        Column("col5", Integer),
+        Column("col6", Integer),
+    )
 
     # place an index on col3, col4
-    Index('idx_col34', mytable.c.col3, mytable.c.col4)
+    Index("idx_col34", mytable.c.col3, mytable.c.col4)
 
     # place a unique index on col5, col6
-    Index('myindex', mytable.c.col5, mytable.c.col6, unique=True)
+    Index("myindex", mytable.c.col5, mytable.c.col6, unique=True)
 
     {sql}mytable.create(engine)
     CREATE TABLE mytable (
@@ -910,7 +908,7 @@ The :class:`~sqlalchemy.schema.Index` object also supports its own ``create()``
 
 .. sourcecode:: python+sql
 
-    i = Index('someindex', mytable.c.col5)
+    i = Index("someindex", mytable.c.col5)
     {sql}i.create(engine)
     CREATE INDEX someindex ON mytable (col5){stop}
 
index 35e3c37f4c70e5e81fa32bff1e729437992af614..ea99ac7c5030d708293faa53696b6272da963542 100644 (file)
@@ -99,27 +99,29 @@ first looking within the PostgreSQL catalogs to see if it exists:
 
     def should_create(ddl, target, connection, **kw):
         row = connection.execute(
-            "select conname from pg_constraint where conname='%s'" %
-            ddl.element.name).scalar()
+            "select conname from pg_constraint where conname='%s'" % ddl.element.name
+        ).scalar()
         return not bool(row)
 
+
     def should_drop(ddl, target, connection, **kw):
         return not should_create(ddl, target, connection, **kw)
 
+
     event.listen(
         users,
         "after_create",
         DDL(
             "ALTER TABLE users ADD CONSTRAINT "
             "cst_user_name_length CHECK (length(user_name) >= 8)"
-        ).execute_if(callable_=should_create)
+        ).execute_if(callable_=should_create),
     )
     event.listen(
         users,
         "before_drop",
-        DDL(
-            "ALTER TABLE users DROP CONSTRAINT cst_user_name_length"
-        ).execute_if(callable_=should_drop)
+        DDL("ALTER TABLE users DROP CONSTRAINT cst_user_name_length").execute_if(
+            callable_=should_drop
+        ),
     )
 
     {sql}users.create(engine)
index 04fcdb0b9ebd30e4b2c592a19462fb1ce0aaf81e..1765ff43980d4070169e7f165528e2b101e0e28f 100644 (file)
@@ -174,22 +174,26 @@ will issue the CREATE statements:
 
 .. sourcecode:: python+sql
 
-    engine = create_engine('sqlite:///:memory:')
+    engine = create_engine("sqlite:///:memory:")
 
     metadata_obj = MetaData()
 
-    user = Table('user', metadata_obj,
-        Column('user_id', Integer, primary_key=True),
-        Column('user_name', String(16), nullable=False),
-        Column('email_address', String(60), key='email'),
-        Column('nickname', String(50), nullable=False)
+    user = Table(
+        "user",
+        metadata_obj,
+        Column("user_id", Integer, primary_key=True),
+        Column("user_name", String(16), nullable=False),
+        Column("email_address", String(60), key="email"),
+        Column("nickname", String(50), nullable=False),
     )
 
-    user_prefs = Table('user_prefs', metadata_obj,
-        Column('pref_id', Integer, primary_key=True),
-        Column('user_id', Integer, ForeignKey("user.user_id"), nullable=False),
-        Column('pref_name', String(40), nullable=False),
-        Column('pref_value', String(100))
+    user_prefs = Table(
+        "user_prefs",
+        metadata_obj,
+        Column("pref_id", Integer, primary_key=True),
+        Column("user_id", Integer, ForeignKey("user.user_id"), nullable=False),
+        Column("pref_name", String(40), nullable=False),
+        Column("pref_value", String(100)),
     )
 
     {sql}metadata_obj.create_all(engine)
@@ -225,14 +229,16 @@ default issue the CREATE or DROP regardless of the table being present:
 
 .. sourcecode:: python+sql
 
-    engine = create_engine('sqlite:///:memory:')
+    engine = create_engine("sqlite:///:memory:")
 
     metadata_obj = MetaData()
 
-    employees = Table('employees', metadata_obj,
-        Column('employee_id', Integer, primary_key=True),
-        Column('employee_name', String(60), nullable=False, key='name'),
-        Column('employee_dept', Integer, ForeignKey("departments.department_id"))
+    employees = Table(
+        "employees",
+        metadata_obj,
+        Column("employee_id", Integer, primary_key=True),
+        Column("employee_name", String(60), nullable=False, key="name"),
+        Column("employee_dept", Integer, ForeignKey("departments.department_id")),
     )
     {sql}employees.create(engine)
     CREATE TABLE employees(
index 1a8305b481af32b5e358e456957e4593a508a12f..3e8c4c395ee29ecc949a7e36cab4653852f99644 100644 (file)
@@ -182,6 +182,7 @@ that is, from the ``EmailAddress`` class directly:
 
     from sqlalchemy.orm import Session
     from sqlalchemy import select
+
     session = Session()
 
     {sql}address = session.scalars(
@@ -193,7 +194,7 @@ that is, from the ``EmailAddress`` class directly:
     ('address@example.com',)
     {stop}
 
-    address.email = 'otheraddress@example.com'
+    address.email = "otheraddress@example.com"
     {sql}session.commit()
     UPDATE address SET email=? WHERE address.id = ?
     ('otheraddress@example.com', 1)
index ea5316cecd131899fdb0923944a756a52f13c7b0..7905dde2e2036dc738c0a3299186e7f553bfbabb 100644 (file)
@@ -861,7 +861,7 @@ the ``Engineer`` entity is performed::
     WHERE employee.type IN (?) ORDER BY employee.id
     [...] ('engineer',)
     {stop}>>> for obj in objects:
-    ...    print(f"{obj}")
+    ...     print(f"{obj}")
     Engineer('SpongeBob')
     Engineer('Squidward')
 
@@ -917,7 +917,7 @@ efficient for single-inheritance mappers::
     FROM employee ORDER BY employee.id
     [...] ()
     {stop}>>> for obj in objects:
-    ...    print(f"{obj}")
+    ...     print(f"{obj}")
     Manager('Mr. Krabs')
     Engineer('SpongeBob')
     Engineer('Squidward')
index e99cb3f24aaa92840d6074cf234c662f3745bf34..4b37365322c834401da3a25fb21920f31dc1bb2b 100644 (file)
@@ -1117,9 +1117,9 @@ the specific :func:`_orm.aliased` construct to be passed:
     # construct a statement which expects the "addresses" results
 
     stmt = (
-       select(User).
-       outerjoin(User.addresses.of_type(adalias)).
-       options(contains_eager(User.addresses.of_type(adalias)))
+        select(User)
+        .outerjoin(User.addresses.of_type(adalias))
+        .options(contains_eager(User.addresses.of_type(adalias)))
     )
 
     # get results normally
index ba73a2ad930704dfd5249095056ea419bf15761c..b0ca90ae52f0e191ed825c9aabb1bb4c1279040e 100644 (file)
@@ -179,13 +179,12 @@ configured via :paramref:`~.relationships.join_depth`:
 .. sourcecode:: python+sql
 
     class Node(Base):
-        __tablename__ = 'node'
+        __tablename__ = "node"
         id = mapped_column(Integer, primary_key=True)
-        parent_id = mapped_column(Integer, ForeignKey('node.id'))
+        parent_id = mapped_column(Integer, ForeignKey("node.id"))
         data = mapped_column(String(50))
-        children = relationship("Node",
-                        lazy="joined",
-                        join_depth=2)
+        children = relationship("Node", lazy="joined", join_depth=2)
+
 
     session.scalars(select(Node)).all()
     {opensql}SELECT node_1.id AS node_1_id,
index b01290a7ca5d87b9f1c309751186109ea26add29..cccb1740052b37a33d8cdf6979de64bc3a5fffe1 100644 (file)
@@ -13,7 +13,15 @@ from black.mode import TargetVersion
 
 home = Path(__file__).parent.parent
 
-_Block = list[tuple[str, int, str | None, str]]
+_Block = list[
+    tuple[
+        str,
+        int,
+        str | None,
+        str | None,
+        str,
+    ]
+]
 
 
 def _format_block(
@@ -24,7 +32,7 @@ def _format_block(
 ) -> list[str]:
     if not is_doctest:
         # The first line may have additional padding. Remove then restore later
-        add_padding = start_space.match(input_block[0][3]).groups()[0]
+        add_padding = start_space.match(input_block[0][4]).groups()[0]
         skip = len(add_padding)
         code = "\n".join(
             c[skip:] if c.startswith(add_padding) else c
@@ -58,9 +66,11 @@ def _format_block(
     else:
         formatted_code_lines = formatted.splitlines()
         padding = input_block[0][2]
+        sql_prefix = input_block[0][3] or ""
+
         if is_doctest:
             formatted_lines = [
-                f"{padding}>>> {formatted_code_lines[0]}",
+                f"{padding}{sql_prefix}>>> {formatted_code_lines[0]}",
                 *(
                     f"{padding}...{' ' if fcl else ''}{fcl}"
                     for fcl in formatted_code_lines[1:]
@@ -68,8 +78,11 @@ def _format_block(
             ]
         else:
             formatted_lines = [
-                f"{padding}{add_padding}{fcl}" if fcl else fcl
-                for fcl in formatted_code_lines
+                f"{padding}{add_padding}{sql_prefix}{formatted_code_lines[0]}",
+                *(
+                    f"{padding}{add_padding}{fcl}" if fcl else fcl
+                    for fcl in formatted_code_lines[1:]
+                ),
             ]
             if not input_block[-1][0] and formatted_lines[-1]:
                 # last line was empty and black removed it. restore it
@@ -79,8 +92,10 @@ def _format_block(
 
 format_directive = re.compile(r"^\.\.\s*format\s*:\s*(on|off)\s*$")
 
-doctest_code_start = re.compile(r"^(\s+)>>>\s?(.+)")
+doctest_code_start = re.compile(r"^(\s+)({(?:opensql|sql|stop)})?>>>\s?(.+)")
 doctest_code_continue = re.compile(r"^\s+\.\.\.\s?(\s*.*)")
+sql_code_start = re.compile(r"^(\s+){(?:open)?sql}")
+sql_code_stop = re.compile(r"^(\s+){stop}")
 
 start_code_section = re.compile(
     r"^(((?!\.\.).+::)|(\.\.\s*sourcecode::(.*py.*)?)|(::))$"
@@ -101,6 +116,7 @@ def format_file(
     plain_code_section = False
     plain_padding = None
     plain_padding_len = None
+    sql_section = False
 
     errors = []
 
@@ -117,6 +133,7 @@ def format_file(
                 )
                 plain_block = None
             plain_code_section = True
+            assert not sql_section
             plain_padding = start_space.match(line).groups()[0]
             plain_padding_len = len(plain_padding)
             buffer.append(line)
@@ -126,14 +143,16 @@ def format_file(
             and line.strip()
             and not line.startswith(" " * (plain_padding_len + 1))
         ):
-            plain_code_section = False
+            plain_code_section = sql_section = False
         elif match := format_directive.match(line):
             disable_format = match.groups()[0] == "off"
 
         if doctest_block:
             assert not plain_block
             if match := doctest_code_continue.match(line):
-                doctest_block.append((line, line_no, None, match.groups()[0]))
+                doctest_block.append(
+                    (line, line_no, None, None, match.groups()[0])
+                )
                 continue
             else:
                 buffer.extend(
@@ -143,9 +162,13 @@ def format_file(
                 )
                 doctest_block = None
         elif plain_block:
-            if plain_code_section and not doctest_code_start.match(line):
+            if (
+                plain_code_section
+                and not doctest_code_start.match(line)
+                and not sql_code_start.match(line)
+            ):
                 plain_block.append(
-                    (line, line_no, None, line[plain_padding_len:])
+                    (line, line_no, None, None, line[plain_padding_len:])
                 )
                 continue
             else:
@@ -157,7 +180,7 @@ def format_file(
                 plain_block = None
 
         if line and (match := doctest_code_start.match(line)):
-            plain_code_section = False
+            plain_code_section = sql_section = False
             if plain_block:
                 buffer.extend(
                     _format_block(
@@ -165,15 +188,53 @@ def format_file(
                     )
                 )
                 plain_block = None
-            padding, code = match.groups()
-            doctest_block = [(line, line_no, padding, code)]
+            padding, code = match.group(1, 3)
+            doctest_block = [(line, line_no, padding, match.group(2), code)]
+        elif (
+            line
+            and plain_code_section
+            and (match := sql_code_start.match(line))
+        ):
+            if plain_block:
+                buffer.extend(
+                    _format_block(
+                        plain_block, exit_on_error, errors, is_doctest=False
+                    )
+                )
+                plain_block = None
+
+            sql_section = True
+            buffer.append(line)
+        elif line and sql_section and (match := sql_code_stop.match(line)):
+            sql_section = False
+            orig_line = line
+            line = line.replace("{stop}", "")
+            assert not doctest_block
+            # start of a plain block
+            if line.strip():
+                plain_block = [
+                    (
+                        line,
+                        line_no,
+                        plain_padding,
+                        "{stop}",
+                        line[plain_padding_len:],
+                    )
+                ]
+            else:
+                buffer.append(orig_line)
+
         elif (
-            line and not no_plain and not disable_format and plain_code_section
+            line
+            and not no_plain
+            and not disable_format
+            and plain_code_section
+            and not sql_section
         ):
             assert not doctest_block
             # start of a plain block
             plain_block = [
-                (line, line_no, plain_padding, line[plain_padding_len:])
+                (line, line_no, plain_padding, None, line[plain_padding_len:])
             ]
         else:
             buffer.append(line)
@@ -320,6 +381,7 @@ Another alterative is to use less than 4 spaces to indent the code block.
         target_versions=set(
             TargetVersion[val.upper()]
             for val in config.get("target_version", [])
+            if val != "py27"
         ),
         line_length=config.get("line_length", DEFAULT_LINE_LENGTH)
         if args.project_line_length