Added "add()" and "add_all()" to scoped_session
methods. Workaround for 0.4.7::
-
+
from sqlalchemy.orm.scoping import ScopedSession, instrument
+
setattr(ScopedSession, "add", instrument("add"))
setattr(ScopedSession, "add_all", instrument("add_all"))
:tickets:
Custom collections can now specify a @converter method to translate
- objects used in "bulk" assignment into a stream of values, as in::
+ objects used in "bulk" assignment into a stream of values, as in:
+
+ .. sourcecode:: python
+
obj.col = [newval1, newval2]
# or
obj.dictcol = {'foo': newval1, 'bar': newval2}
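For context, a minimal sketch of what such a ``@converter`` hook can look
like on a custom collection class (the class and method names here are
illustrative, not part of the changelog):

.. sourcecode:: python

    from sqlalchemy.orm.collections import collection


    class ValidatingList(list):
        """Hypothetical custom collection with a bulk-assignment converter."""

        @collection.converter
        def _convert(self, value):
            # called on bulk assignment (obj.col = [...]); must return an
            # iterable of the individual values to be adapted into the
            # collection
            return iter(value)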
direct SQLite access, a ``ResultProxy``, and a simple mapped
ORM object:
-::
+.. sourcecode:: text
sqlite select/native: 0.260s
>>> column("x").startswith("total%score", autoescape=True)
-Renders as::
+Renders as:
+
+.. sourcecode:: sql
x LIKE :x_1 || '%' ESCAPE '/'
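The bound parameter itself is escaped as well; assuming the ``/`` escape
character rendered above, the value sent for ``x_1`` would be along the
lines of:

.. sourcecode:: text

    x_1 = 'total/%score'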
:tags: postgresql, removed
:tickets: 7258
- Removed support for multiple deprecated drivers::
+ Removed support for multiple deprecated drivers:
- pypostgresql for PostgreSQL. This is available as an
external driver at https://github.com/PyGreSQL
When we run a block like the above with logging turned on, the logging
will attempt to indicate that while a DBAPI level ``.commit()`` is called,
-it probably will have no effect due to autocommit mode::
+it probably will have no effect due to autocommit mode:
+
+.. sourcecode:: text
INFO sqlalchemy.engine.Engine BEGIN (implicit)
...
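A minimal sketch of how such a block might be set up, assuming an existing
``engine`` (the statement itself is illustrative):

.. sourcecode:: python

    from sqlalchemy import text

    # opt this connection into driver-level autocommit
    with engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    ) as conn:
        conn.execute(text("SELECT 1"))
        conn.commit()  # logged as COMMIT, but a no-op under autocommit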
Each badge is described in more detail below.
The first statements we see for the above program will be the SQLite dialect
-checking for the existence of the "a" and "b" tables::
+checking for the existence of the "a" and "b" tables:
+
+.. sourcecode:: text
INFO sqlalchemy.engine.Engine PRAGMA temp.table_info("a")
INFO sqlalchemy.engine.Engine [raw sql] ()
In the case where the cycle cannot be resolved, such as if we hadn't applied
-a name to either constraint here, we will receive the following error::
+a name to either constraint here, we will receive the following error:
+
+.. sourcecode:: text
sqlalchemy.exc.CircularDependencyError: Can't sort tables for DROP;
an unresolvable foreign key dependency exists between tables:
:paramref:`_schema.ForeignKeyConstraint.use_alter` and
:paramref:`_schema.ForeignKey.use_alter`, when used in conjunction with a drop
operation, will require that the constraint is named, else an error
-like the following is generated::
+like the following is generated:
+
+.. sourcecode:: text
sqlalchemy.exc.CompileError: Can't emit DROP CONSTRAINT for constraint
ForeignKeyConstraint(...); it has no name
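Naming the constraint and marking it ``use_alter`` allows the dependency to
be broken out into ``ALTER`` statements; a short sketch (table and column
names are illustrative):

.. sourcecode:: python

    from sqlalchemy import (
        Column,
        ForeignKeyConstraint,
        Integer,
        MetaData,
        Table,
    )

    metadata_obj = MetaData()

    node = Table(
        "node",
        metadata_obj,
        Column("node_id", Integer, primary_key=True),
        Column("element_id", Integer),
        ForeignKeyConstraint(
            ["element_id"],
            ["element.element_id"],
            # a name is required so DROP CONSTRAINT can refer to it
            name="fk_element_id",
            use_alter=True,
        ),
    )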
CheckConstraint("col2 > col3 + 5", name="check1"),
)
- {sql}mytable.create(engine)
- CREATE TABLE mytable (
+ mytable.create(engine)
+ {opensql}CREATE TABLE mytable (
col1 INTEGER CHECK (col1>5),
col2 INTEGER,
col3 INTEGER,
# place a unique index on col5, col6
Index("myindex", mytable.c.col5, mytable.c.col6, unique=True)
- {sql}mytable.create(engine)
- CREATE TABLE mytable (
+ mytable.create(engine)
+ {opensql}CREATE TABLE mytable (
col1 INTEGER,
col2 INTEGER,
col3 INTEGER,
.. sourcecode:: python+sql
i = Index("someindex", mytable.c.col5)
- {sql}i.create(engine)
- CREATE INDEX someindex ON mytable (col5){stop}
+ i.create(engine)
+ {opensql}CREATE INDEX someindex ON mytable (col5){stop}
.. _schema_indexes_functional:
class JSONEncodedDict(TypeDecorator):
"""Represents an immutable structure as a json-encoded string.
Usage::
JSONEncodedDict(255)
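For completeness, the body of this recipe follows the standard
:class:`.TypeDecorator` pattern; a sketch consistent with the documented
example:

.. sourcecode:: python

    import json

    from sqlalchemy.types import TypeDecorator, VARCHAR


    class JSONEncodedDict(TypeDecorator):
        impl = VARCHAR

        cache_ok = True

        def process_bind_param(self, value, dialect):
            # serialize the dict to a JSON string on the way in
            if value is not None:
                value = json.dumps(value)
            return value

        def process_result_value(self, value, dialect):
            # deserialize the JSON string back to a dict on the way out
            if value is not None:
                value = json.loads(value)
            return value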
),
)
- {sql}users.create(engine)
- CREATE TABLE users (
+ users.create(engine)
+ {opensql}CREATE TABLE users (
user_id SERIAL NOT NULL,
user_name VARCHAR(40) NOT NULL,
PRIMARY KEY (user_id)
)
- select conname from pg_constraint where conname='cst_user_name_length'
- ALTER TABLE users ADD CONSTRAINT cst_user_name_length CHECK (length(user_name) >= 8){stop}
+ SELECT conname FROM pg_constraint WHERE conname='cst_user_name_length'
+ ALTER TABLE users ADD CONSTRAINT cst_user_name_length CHECK (length(user_name) >= 8)
+ {stop}
- {sql}users.drop(engine)
- select conname from pg_constraint where conname='cst_user_name_length'
+ users.drop(engine)
+ {opensql}SELECT conname FROM pg_constraint WHERE conname='cst_user_name_length'
ALTER TABLE users DROP CONSTRAINT cst_user_name_length
DROP TABLE users{stop}
.. sourcecode:: python+sql
from sqlalchemy.schema import CreateTable
+
with engine.connect() as conn:
- {sql} conn.execute(CreateTable(mytable))
- CREATE TABLE mytable (
+ conn.execute(CreateTable(mytable))
+ {opensql}CREATE TABLE mytable (
col1 INTEGER,
col2 INTEGER,
col3 INTEGER,
Column("pref_value", String(100)),
)
- {sql}metadata_obj.create_all(engine)
- PRAGMA table_info(user){}
+ metadata_obj.create_all(engine)
+ {opensql}PRAGMA table_info(user){}
CREATE TABLE user(
user_id INTEGER NOT NULL PRIMARY KEY,
user_name VARCHAR(16) NOT NULL,
Column("employee_name", String(60), nullable=False, key="name"),
Column("employee_dept", Integer, ForeignKey("departments.department_id")),
)
- {sql}employees.create(engine)
- CREATE TABLE employees(
- employee_id SERIAL NOT NULL PRIMARY KEY,
- employee_name VARCHAR(60) NOT NULL,
- employee_dept INTEGER REFERENCES departments(department_id)
+ employees.create(engine)
+ {opensql}CREATE TABLE employees(
+ employee_id SERIAL NOT NULL PRIMARY KEY,
+ employee_name VARCHAR(60) NOT NULL,
+ employee_dept INTEGER REFERENCES departments(department_id)
)
{}
.. sourcecode:: python+sql
- {sql}employees.drop(engine)
- DROP TABLE employees
+ employees.drop(engine)
+ {opensql}DROP TABLE employees
{}
To enable the "check first for the table existing" logic, add the
* :meth:`_sql.ColumnOperators.startswith`::
- The string containment operators
>>> print(column("x").startswith("word"))
x LIKE :x_1 || '%'
do_a_thing(e)
Restart the database while the script runs to demonstrate the transparent
-reconnect operation::
+reconnect operation:
+
+.. sourcecode:: text
$ python reconnect_test.py
ping: 1
:meth:`_engine.Connection.detach` method on either :class:`_engine.Connection`
or the proxied connection, which will de-associate the connection from the pool
such that it will be closed and discarded when :meth:`_engine.Connection.close`
-is called::
+is called:
+
+.. sourcecode:: python
conn = engine.connect()
conn.detach() # detaches the DBAPI connection from the connection pool
for which ``greenlet`` does not supply a `pre-built binary wheel <https://pypi.org/project/greenlet/#files>`_.
Notably, **this includes Apple M1**. To install including ``greenlet``,
add the ``asyncio`` `setuptools extra <https://packaging.python.org/en/latest/tutorials/installing-packages/#installing-setuptools-extras>`_
-to the ``pip install`` command::
+to the ``pip install`` command:
+
+.. sourcecode:: text
pip install sqlalchemy[asyncio]
ORDER BY anon_1.users_id
Depending on database specifics, there is
-a chance we may get a result like the following for the two queries::
+a chance we may get a result like the following for the two queries:
+
+.. sourcecode:: text
-- query #1
+--------+
participating in caching.
For user defined datatypes such as those which extend :class:`_types.TypeDecorator`
-and :class:`_types.UserDefinedType`, the warnings will look like::
+and :class:`_types.UserDefinedType`, the warnings will look like:
+
+.. sourcecode:: text
sqlalchemy.ext.SAWarning: MyType will not produce a cache key because the
``cache_ok`` attribute is not set to True. This can have significant
For custom and third party SQL elements, such as those constructed using
the techniques described at :ref:`sqlalchemy.ext.compiler_toplevel`, these
-warnings will look like::
+warnings will look like:
+
+.. sourcecode:: text
sqlalchemy.exc.SAWarning: Class MyClass will not make use of SQL
compilation caching as it does not set the 'inherit_cache' attribute to
False which will disable this warning.
For custom and third party dialects which make use of the :class:`.Dialect`
-class hierarchy, the warnings will look like::
+class hierarchy, the warnings will look like:
+
+.. sourcecode:: text
sqlalchemy.exc.SAWarning: Dialect database:driver will not make use of SQL
compilation caching as it does not set the 'supports_statement_cache'
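In each case the resolution is to set the named attribute explicitly, once
the object is confirmed safe to cache; a hedged sketch for the first two
cases (class names are illustrative):

.. sourcecode:: python

    from sqlalchemy.sql.expression import ColumnClause
    from sqlalchemy.types import String, TypeDecorator


    class MyType(TypeDecorator):
        impl = String
        cache_ok = True  # this type produces a safe cache key


    class MyColumn(ColumnClause):
        inherit_cache = True  # reuse ColumnClause's cache key scheme

For a third-party dialect, the analogous step is setting
``supports_statement_cache = True`` on the :class:`.Dialect` subclass.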
session.scalars(select(FooClass).where(FooClass.somevalue == 8)).all()
The output of profiling can be used to give an idea where time is
-being spent. A section of profiling output looks like this::
+being spent. A section of profiling output looks like this:
+
+.. sourcecode:: text
13726 function calls (13042 primitive calls) in 0.014 seconds
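Output in this form can be generated with the standard library profiler; a
minimal sketch, where ``go()`` stands in for whatever code is being
measured:

.. sourcecode:: python

    import cProfile
    import pstats


    def go():
        # the operation under test, e.g. the session.scalars() call above
        ...


    profiler = cProfile.Profile()
    profiler.enable()
    go()
    profiler.disable()

    # print the 25 most expensive calls by cumulative time
    pstats.Stats(profiler).sort_stats("cumulative").print_stats(25)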
The specifics of these calls can tell us where the time is being spent.
If for example, you see time being spent within ``cursor.execute()``,
-e.g. against the DBAPI::
+e.g. against the DBAPI:
+
+.. sourcecode:: text
2 0.102 0.102 0.204 0.102 {method 'execute' of 'sqlite3.Cursor' objects}
rows (or ``fetchmany()`` if the :meth:`_query.Query.yield_per` option is used).
An inordinately large number of rows would be indicated
-by a very slow call to ``fetchall()`` at the DBAPI level::
+by a very slow call to ``fetchall()`` at the DBAPI level:
+
+.. sourcecode:: text
2 0.300 0.600 0.300 0.600 {method 'fetchall' of 'sqlite3.Cursor' objects}
On the other hand, a fast call to ``fetchall()`` at the DBAPI level, but then
slowness when SQLAlchemy's :class:`_engine.CursorResult` is asked to do a ``fetchall()``,
may indicate slowness in processing of datatypes, such as unicode conversions
-and similar::
+and similar:
+
+.. sourcecode:: text
# the DBAPI cursor is fast...
2 0.020 0.040 0.020 0.040 {method 'fetchall' of 'sqlite3.Cursor' objects}
time.sleep(0.001)
return value
-the profiling output of this intentionally slow operation can be seen like this::
+the profiling output of this intentionally slow operation looks like this:
+
+.. sourcecode:: text
200 0.001 0.000 0.237 0.001 lib/sqlalchemy/sql/type_api.py:911(process)
200 0.001 0.000 0.236 0.001 test.py:28(process_result_value)
To detect slowness in ORM fetching of rows (which is the most common area
of performance concern), calls like ``populate_state()`` and ``_instance()`` will
-illustrate individual ORM object populations::
+illustrate individual ORM object populations:
+
+.. sourcecode:: text
# the ORM calls _instance for each ORM-loaded row it sees, and
# populate_state for each ORM-loaded row that results in the population
list(Iterates())
-output::
+output:
+
+.. sourcecode:: text
ITER!
LEN!
for obj in walk(a1):
print(obj)
-Output::
+Output:
+
+.. sourcecode:: text
<__main__.A object at 0x10303b190>
<__main__.B object at 0x103025210>
column("a") & column("b") & column("c") & column("d")
-would produce::
+would produce:
+
+.. sourcecode:: sql
(((a AND b) AND c) AND d)
.. sourcecode:: pycon+sql
>>> from sqlalchemy.schema import CreateTable
- {sql}>>> print(CreateTable(Vertex.__table__))
- CREATE TABLE vertices (
+ >>> print(CreateTable(Vertex.__table__))
+ {opensql}CREATE TABLE vertices (
id INTEGER NOT NULL,
x1 INTEGER NOT NULL,
y1 INTEGER NOT NULL,
>>> v = Vertex(start=Point(3, 4), end=Point(5, 6))
>>> session.add(v)
- {sql}>>> session.commit()
- BEGIN (implicit)
+ >>> session.commit()
+ {opensql}BEGIN (implicit)
INSERT INTO vertices (x1, y1, x2, y2) VALUES (?, ?, ?, ?)
[generated in ...] (3, 4, 5, 6)
COMMIT
.. sourcecode:: pycon+sql
>>> stmt = select(Vertex.start, Vertex.end)
- {sql}>>> session.execute(stmt).all()
- SELECT vertices.x1, vertices.y1, vertices.x2, vertices.y2
+ >>> session.execute(stmt).all()
+ {opensql}SELECT vertices.x1, vertices.y1, vertices.x2, vertices.y2
FROM vertices
[...] ()
{stop}[(Point(x=3, y=4), Point(x=5, y=6))]
.. sourcecode:: pycon+sql
>>> stmt = select(Vertex).where(Vertex.start == Point(3, 4)).where(Vertex.end < Point(7, 8))
- {sql}>>> session.scalars(stmt).all()
- SELECT vertices.id, vertices.x1, vertices.y1, vertices.x2, vertices.y2
+ >>> session.scalars(stmt).all()
+ {opensql}SELECT vertices.id, vertices.x1, vertices.y1, vertices.x2, vertices.y2
FROM vertices
WHERE vertices.x1 = ? AND vertices.y1 = ? AND vertices.x2 < ? AND vertices.y2 < ?
[...] (3, 4, 7, 8)
.. sourcecode:: pycon+sql
- {sql}>>> v1 = session.scalars(select(Vertex)).one()
- SELECT vertices.id, vertices.x1, vertices.y1, vertices.x2, vertices.y2
+ >>> v1 = session.scalars(select(Vertex)).one()
+ {opensql}SELECT vertices.id, vertices.x1, vertices.y1, vertices.x2, vertices.y2
FROM vertices
[...] ()
{stop}
>>> v1.end = Point(x=10, y=14)
- {sql}>>> session.commit()
- UPDATE vertices SET x2=?, y2=? WHERE vertices.id = ?
+ >>> session.commit()
+ {opensql}UPDATE vertices SET x2=?, y2=? WHERE vertices.id = ?
[...] (10, 14, 1)
COMMIT
created only after the web request begins and torn down just before the web request ends.
So it is a common practice to use :class:`.scoped_session` as a quick way
to integrate the :class:`.Session` with a web application. The sequence
-diagram below illustrates this flow::
+diagram below illustrates this flow:
+
+.. sourcecode:: text
Web Server Web Framework SQLAlchemy ORM Code
-------------- -------------- ------------------------------
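The registry setup behind this flow is typically just a few lines; a sketch
assuming an existing ``engine``:

.. sourcecode:: python

    from sqlalchemy.orm import scoped_session, sessionmaker

    # one registry per application; each thread/request gets its own Session
    Session = scoped_session(sessionmaker(bind=engine))

    # anywhere in ORM code, Session() returns the thread-local Session;
    # Session.remove() tears it down at the end of the request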
>>> with Session(e) as session:
... session.add(User())
- {sql}... session.commit()
- BEGIN (implicit)
+ ... session.commit()
+ {opensql}BEGIN (implicit)
INSERT INTO user_account (created_at) VALUES (utc_timestamp())
[generated in 0.00010s] ()
COMMIT
The asyncio extension requires Python 3 only. It also depends
upon the `greenlet <https://pypi.org/project/greenlet/>`_ library. This
-dependency is installed by default on common machine platforms including::
+dependency is installed by default on common machine platforms including:
+
+.. sourcecode:: text
x86_64 aarch64 ppc64le amd64 win32
regardless of what platform is in use, the
``[asyncio]`` `setuptools extra <https://packaging.python.org/en/latest/tutorials/installing-packages/#installing-setuptools-extras>`_
may be installed
-as follows, which will include also instruct ``pip`` to install ``greenlet``::
+as follows, which will also instruct ``pip`` to install ``greenlet``:
+
+.. sourcecode:: text
pip install sqlalchemy[asyncio]
asyncio.run(go())
-The above example prints something along the lines of::
+The above example prints something along the lines of:
+
+.. sourcecode:: text
New DBAPI connection: <AdaptedConnection <asyncpg.connection.Connection ...>>
execute from event
place on the **exterior** of SQLAlchemy's usual flow from end-user API to
DBAPI function.
- The flow of messaging may be visualized as follows::
+ The flow of messaging may be visualized as follows:
+
+ .. sourcecode:: text
SQLAlchemy SQLAlchemy SQLAlchemy SQLAlchemy plain
asyncio asyncio ORM/Core asyncio asyncio
q(s).params(id=id_).one()
The difference in Python function call count for an iteration of 10000
-calls to each block are::
+calls to each block is:
+
+.. sourcecode:: text
test_baked_query : test a baked query of the full entity.
(10000 iterations); total fn calls 1951294
test_orm_query : test a straight ORM query of the full entity.
(10000 iterations); total fn calls 7900535
-In terms of number of seconds on a powerful laptop, this comes out as::
+In terms of number of seconds on a powerful laptop, this comes out as:
+
+.. sourcecode:: text
test_baked_query : test a baked query of the full entity.
(10000 iterations); total time 2.174126 sec
The Mypy_ package itself is a dependency.
-Mypy may be installed using the "mypy" extras hook using pip::
+Mypy may be installed with the "mypy" extras hook using pip:
+
+.. sourcecode:: text
pip install sqlalchemy[mypy]
id = Column(Integer, primary_key=True)
user_id = Column(ForeignKey("user.id"))
-The plugin will deliver the message as follows::
+The plugin will deliver the message as follows:
+
+.. sourcecode:: text
$ mypy test3.py --strict
test3.py:20: error: [SQLAlchemy Mypy plugin] Can't infer type from
user = relationship(User)
-The above mapping will produce the following error::
+The above mapping will produce the following error:
+
+.. sourcecode:: text
test3.py:22: error: [SQLAlchemy Mypy plugin] Can't infer scalar or
collection for ORM mapped expression assigned to attribute 'user'
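The message is resolved by giving the attribute an explicit annotation so
the plugin can distinguish scalar from collection; a sketch of the scalar
case:

.. sourcecode:: python

    class Address(Base):
        __tablename__ = "address"

        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("user.id"))

        # annotated explicitly as a scalar, satisfying the plugin
        user: "User" = relationship(User)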
start_date: Mapped[datetime]
Above, the ``start_date`` column declared on both ``Engineer`` and ``Manager``
-will result in an error::
+will result in an error:
+
+.. sourcecode:: text
sqlalchemy.exc.ArgumentError: Column 'start_date' on class
<class '__main__.Manager'> conflicts with existing
state = mapped_column(String)
zip = mapped_column(String)
-The above mapping, when we attempt to use it, will produce the error::
+The above mapping, when we attempt to use it, will produce the error:
+
+.. sourcecode:: text
sqlalchemy.exc.AmbiguousForeignKeysError: Could not determine join
condition between parent/child tables on relationship
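The ambiguity is resolved by telling each :func:`_orm.relationship` which
foreign key column it applies to; a sketch along the lines of the
documented fix:

.. sourcecode:: python

    class Customer(Base):
        __tablename__ = "customer"

        id = mapped_column(Integer, primary_key=True)
        billing_address_id = mapped_column(ForeignKey("address.id"))
        shipping_address_id = mapped_column(ForeignKey("address.id"))

        # each relationship names the FK column it should join on
        billing_address = relationship(
            "Address", foreign_keys=[billing_address_id]
        )
        shipping_address = relationship(
            "Address", foreign_keys=[shipping_address_id]
        )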
magazine_id = mapped_column(ForeignKey("magazine.id"), primary_key=True)
magazine = relationship("Magazine")
-When the above mapping is configured, we will see this warning emitted::
+When the above mapping is configured, we will see this warning emitted:
+
+.. sourcecode:: text
SAWarning: relationship 'Article.writer' will copy column
writer.magazine_id to column article.magazine_id,
session = Session()
- {sql}address = session.scalars(
- select(EmailAddress).where(EmailAddress.email == 'address@example.com'
+ address = session.scalars(
+ select(EmailAddress).where(EmailAddress.email == "address@example.com")
).one()
- SELECT address.email AS address_email, address.id AS address_id
+ {opensql}SELECT address.email AS address_email, address.id AS address_id
FROM address
WHERE address.email = ?
('address@example.com',)
{stop}
address.email = "otheraddress@example.com"
- {sql}session.commit()
- UPDATE address SET email=? WHERE address.id = ?
+ session.commit()
+ {opensql}UPDATE address SET email=? WHERE address.id = ?
('otheraddress@example.com', 1)
COMMIT
{stop}
.. sourcecode:: python+sql
- {sql}address = session.scalars(select(EmailAddress).where(EmailAddress.email == 'address')).one()
- SELECT address.email AS address_email, address.id AS address_id
+ address = session.scalars(
+ select(EmailAddress).where(EmailAddress.email == "address")
+ ).one()
+ {opensql}SELECT address.email AS address_email, address.id AS address_id
FROM address
WHERE substr(address.email, ?, length(address.email) - ?) = ?
(0, 12, 'address')
* Two tables each contain a foreign key referencing the other
table, with a row in each table referencing the other.
-For example::
+For example:
+
+.. sourcecode:: text
user
---------------------------------
user_id name related_user_id
1 'ed' 1
-Or::
+Or:
+
+.. sourcecode:: text
widget entry
------------------------------------------- ---------------------------------
>>> w1.favorite_entry = e1
>>> w1.entries = [e1]
>>> session.add_all([w1, e1])
- {sql}>>> session.commit()
- BEGIN (implicit)
+ >>> session.commit()
+ {opensql}BEGIN (implicit)
INSERT INTO widget (favorite_entry_id, name) VALUES (?, ?)
(None, 'somewidget')
INSERT INTO entry (widget_id, name) VALUES (?, ?)
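The two-step INSERT-then-UPDATE seen here is driven by ``post_update`` on
the mutually dependent relationship; roughly, the mapping looks like this
(a sketch consistent with the documented pattern):

.. sourcecode:: python

    class Widget(Base):
        __tablename__ = "widget"

        widget_id = mapped_column(Integer, primary_key=True)
        favorite_entry_id = mapped_column(
            ForeignKey(
                "entry.entry_id", use_alter=True, name="fk_favorite_entry"
            )
        )
        name = mapped_column(String(50))

        # post_update emits a second UPDATE after the INSERTs,
        # breaking the row-level dependency cycle
        favorite_entry = relationship(
            "Entry",
            primaryjoin="Widget.favorite_entry_id == Entry.entry_id",
            post_update=True,
        )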
data = mapped_column(String(50))
children = relationship("Node")
-With this structure, a graph such as the following::
+With this structure, a graph such as the following:
+
+.. sourcecode:: text
root --+---> child1
+---> child2 --+--> subchild1
| +--> subchild2
+---> child3
-Would be represented with data such as::
+Would be represented with data such as:
+
+.. sourcecode:: text
id parent_id data
--- ------- ----
session.commit() # commits
- result = session.execute(<some SELECT statement>)
+ result = session.execute("<some SELECT statement>")
# remaining transactional state from the .execute() call is
# discarded
the two packages, at the level of the :class:`_orm.sessionmaker` vs.
the :class:`_engine.Engine`, as well as the :class:`_orm.Session` vs.
the :class:`_engine.Connection`. The following sections detail
-these scenarios based on the following scheme::
+these scenarios based on the following scheme:
+
+.. sourcecode:: text
ORM (using future Session) Core (using future engine)
----------------------------------------- -----------------------------------
# make a specific Session that will use the "autocommit" engine
with Session(bind=autocommit_engine) as session:
# work with session
+ ...
For the case where the :class:`.Session` or :class:`.sessionmaker` is
configured with multiple "binds", we can either re-specify the ``binds``
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter
from collections.abc import Iterator
+from functools import partial
from pathlib import Path
import re
+from typing import NamedTuple
from black import format_str
from black.const import DEFAULT_LINE_LENGTH
home = Path(__file__).parent.parent
-_Block = list[
- tuple[
- str,
- int,
- str | None,
- str | None,
- str,
- ]
-]
+
+class BlockLine(NamedTuple):
+ line: str
+ line_no: int
+ code: str
+ padding: str | None = None # relevant only on first line of block
+ sql_marker: str | None = None
+
+
+_Block = list[BlockLine]
def _format_block(
exit_on_error: bool,
errors: list[tuple[int, str, Exception]],
is_doctest: bool,
+ file: str,
) -> list[str]:
if not is_doctest:
# The first line may have additional padding. Remove then restore later
- add_padding = start_space.match(input_block[0][4]).groups()[0]
+ add_padding = start_space.match(input_block[0].code).groups()[0]
skip = len(add_padding)
code = "\n".join(
- c[skip:] if c.startswith(add_padding) else c
- for *_, c in input_block
+ l.code[skip:] if l.code.startswith(add_padding) else l.code
+ for l in input_block
)
else:
add_padding = None
- code = "\n".join(c for *_, c in input_block)
+ code = "\n".join(l.code for l in input_block)
try:
formatted = format_str(code, mode=BLACK_MODE)
except Exception as e:
- start_line = input_block[0][1]
+ start_line = input_block[0].line_no
+ first_error = not errors
errors.append((start_line, code, e))
- if is_doctest:
- print(
- "Could not format code block starting at "
- f"line {start_line}:\n{code}\nError: {e}"
- )
- if exit_on_error:
- print("Exiting since --exit-on-error was passed")
- raise
- else:
- print("Ignoring error")
- elif VERBOSE:
- print(
- "Could not format code block starting at "
- f"line {start_line}:\n---\n{code}\n---Error: {e}"
- )
- return [line for line, *_ in input_block]
+ type_ = "doctest" if is_doctest else "plain"
+ if first_error:
+ print() # add newline
+ print(
+ f"--- {file}:{start_line} Could not format {type_} code "
+ f"block:\n{code}\n---Error: {e}"
+ )
+ if exit_on_error:
+ print("Exiting since --exit-on-error was passed")
+ raise
+ else:
+ print("Ignoring error")
+ return [l.line for l in input_block]
else:
formatted_code_lines = formatted.splitlines()
- padding = input_block[0][2]
- sql_prefix = input_block[0][3] or ""
+ padding = input_block[0].padding
+ sql_prefix = input_block[0].sql_marker or ""
if is_doctest:
formatted_lines = [
for fcl in formatted_code_lines[1:]
),
]
- if not input_block[-1][0] and formatted_lines[-1]:
+ if not input_block[-1].line and formatted_lines[-1]:
# last line was empty and black removed it. restore it
formatted_lines.append("")
return formatted_lines
doctest_code_start = re.compile(r"^(\s+)({(?:opensql|sql|stop)})?>>>\s?(.+)")
doctest_code_continue = re.compile(r"^\s+\.\.\.\s?(\s*.*)")
-sql_code_start = re.compile(r"^(\s+){(?:open)?sql}")
+
+sql_code_start = re.compile(r"^(\s+)({(?:open)?sql})")
sql_code_stop = re.compile(r"^(\s+){stop}")
start_code_section = re.compile(
def format_file(
- file: Path, exit_on_error: bool, check: bool, no_plain: bool
+ file: Path, exit_on_error: bool, check: bool
) -> tuple[bool, int]:
buffer = []
if not check:
errors = []
+ do_doctest_format = partial(
+ _format_block,
+ exit_on_error=exit_on_error,
+ errors=errors,
+ is_doctest=True,
+ file=str(file),
+ )
+
+ def doctest_format():
+ nonlocal doctest_block
+ if doctest_block:
+ buffer.extend(do_doctest_format(doctest_block))
+ doctest_block = None
+
+ do_plain_format = partial(
+ _format_block,
+ exit_on_error=exit_on_error,
+ errors=errors,
+ is_doctest=False,
+ file=str(file),
+ )
+
+ def plain_format():
+ nonlocal plain_block
+ if plain_block:
+ buffer.extend(do_plain_format(plain_block))
+ plain_block = None
+
disable_format = False
for line_no, line in enumerate(original.splitlines(), 1):
- # start_code_section requires no spaces at the start
- if start_code_section.match(line.strip()):
- if plain_block:
- buffer.extend(
- _format_block(
- plain_block, exit_on_error, errors, is_doctest=False
- )
- )
- plain_block = None
+ if (
+ line
+ and not disable_format
+ and start_code_section.match(line.strip())
+ ):
+ # start_code_section regexp requires no spaces at the start
+ plain_format()
plain_code_section = True
assert not sql_section
plain_padding = start_space.match(line).groups()[0]
):
plain_code_section = sql_section = False
elif match := format_directive.match(line):
+ assert not plain_code_section
disable_format = match.groups()[0] == "off"
if doctest_block:
assert not plain_block
if match := doctest_code_continue.match(line):
doctest_block.append(
- (line, line_no, None, None, match.groups()[0])
+ BlockLine(line, line_no, match.groups()[0])
)
continue
else:
- buffer.extend(
- _format_block(
- doctest_block, exit_on_error, errors, is_doctest=True
- )
- )
- doctest_block = None
+ doctest_format()
elif plain_block:
if (
plain_code_section
and not sql_code_start.match(line)
):
plain_block.append(
- (line, line_no, None, None, line[plain_padding_len:])
+ BlockLine(line, line_no, line[plain_padding_len:])
)
continue
else:
- buffer.extend(
- _format_block(
- plain_block, exit_on_error, errors, is_doctest=False
- )
- )
- plain_block = None
+ plain_format()
if line and (match := doctest_code_start.match(line)):
+ # the line is in a doctest
plain_code_section = sql_section = False
- if plain_block:
- buffer.extend(
- _format_block(
- plain_block, exit_on_error, errors, is_doctest=False
- )
- )
- plain_block = None
- padding, code = match.group(1, 3)
- doctest_block = [(line, line_no, padding, match.group(2), code)]
- elif (
- line
- and plain_code_section
- and (match := sql_code_start.match(line))
- ):
- if plain_block:
- buffer.extend(
- _format_block(
- plain_block, exit_on_error, errors, is_doctest=False
- )
- )
- plain_block = None
-
- sql_section = True
- buffer.append(line)
- elif line and sql_section and (match := sql_code_stop.match(line)):
- sql_section = False
- orig_line = line
- line = line.replace("{stop}", "")
+ plain_format()
+ padding, sql_marker, code = match.groups()
+ doctest_block = [
+ BlockLine(line, line_no, code, padding, sql_marker)
+ ]
+ elif line and plain_code_section:
+ assert not disable_format
assert not doctest_block
- # start of a plain block
- if line.strip():
+ if match := sql_code_start.match(line):
+ plain_format()
+ sql_section = True
+ buffer.append(line)
+ elif sql_section:
+ if match := sql_code_stop.match(line):
+ sql_section = False
+ no_stop_line = line.replace("{stop}", "")
+ # start of a plain block
+ if no_stop_line.strip():
+ assert not plain_block
+ plain_block = [
+ BlockLine(
+ line,
+ line_no,
+ no_stop_line[plain_padding_len:],
+ plain_padding,
+ "{stop}",
+ )
+ ]
+ continue
+ buffer.append(line)
+ else:
+ # start of a plain block
+ assert not doctest_block
plain_block = [
- (
+ BlockLine(
line,
line_no,
- plain_padding,
- "{stop}",
line[plain_padding_len:],
+ plain_padding,
)
]
- else:
- buffer.append(orig_line)
-
- elif (
- line
- and not no_plain
- and not disable_format
- and plain_code_section
- and not sql_section
- ):
- assert not doctest_block
- # start of a plain block
- plain_block = [
- (line, line_no, plain_padding, None, line[plain_padding_len:])
- ]
else:
buffer.append(line)
- if doctest_block:
- buffer.extend(
- _format_block(
- doctest_block, exit_on_error, errors, is_doctest=True
- )
- )
- if plain_block:
- buffer.extend(
- _format_block(plain_block, exit_on_error, errors, is_doctest=False)
- )
+ doctest_format()
+ plain_format()
if buffer:
- # if there is nothing in the buffer something strange happened so
- # don't do anything
buffer.append("")
updated = "\n".join(buffer)
equal = original == updated
# write only if there are changes to write
file.write_text(updated, "utf-8", newline="\n")
else:
+ # if there is nothing in the buffer something strange happened so
+ # don't do anything
if not check:
print(".. Nothing to write")
equal = bool(original) is False
yield from (home / directory).glob("./**/*.rst")
-def main(
- file: str | None,
- directory: str,
- exit_on_error: bool,
- check: bool,
- no_plain: bool,
-):
+def main(file: str | None, directory: str, exit_on_error: bool, check: bool):
if file is not None:
- result = [format_file(Path(file), exit_on_error, check, no_plain)]
+ result = [format_file(Path(file), exit_on_error, check)]
else:
result = [
- format_file(doc, exit_on_error, check, no_plain)
+ format_file(doc, exit_on_error, check)
for doc in iter_files(directory)
]
else "no formatting errors reported",
)
- # interim, until we fix all formatting errors
- if not to_reformat:
- exit(0)
exit(1)
Plain code blocks may lead to false positives. To disable formatting on a \
file section, the comment ``.. format: off`` disables formatting until \
``.. format: on`` is encountered or the file ends.
-Another alterative is to use less than 4 spaces to indent the code block.
""",
formatter_class=RawDescriptionHelpFormatter,
)
parser.add_argument(
"-e",
"--exit-on-error",
- help="Exit in case of black format error instead of ignoring it. "
- "This option is only valid for doctest code blocks",
+ help="Exit in case of black format error instead of ignoring it.",
action="store_true",
)
parser.add_argument(
"of using the black default of 88",
action="store_true",
)
- parser.add_argument(
- "-v",
- "--verbose",
- help="Increase verbosity",
- action="store_true",
- )
- parser.add_argument(
- "-n",
- "--no-plain",
- help="Disable plain code blocks formatting that's more difficult "
- "to parse compared to doctest code blocks",
- action="store_true",
- )
args = parser.parse_args()
config = parse_pyproject_toml(home / "pyproject.toml")
if args.project_line_length
else DEFAULT_LINE_LENGTH,
)
- VERBOSE = args.verbose
-
- main(
- args.file,
- args.directory,
- args.exit_on_error,
- args.check,
- args.no_plain,
- )
+
+ main(args.file, args.directory, args.exit_on_error, args.check)