From: rusher Date: Wed, 14 Jan 2026 14:03:00 +0000 (-0500) Subject: correct mariadb sequence behavior when cycle=False X-Git-Tag: rel_2_0_46~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d317d60ae42bb63f19675c46d36e0926263a4280;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git correct mariadb sequence behavior when cycle=False Fixed the SQL compilation for the mariadb sequence "NOCYCLE" keyword that is to be emitted when the :paramref:`.Sequence.cycle` parameter is set to False on a :class:`.Sequence`. Pull request courtesy Diego Dupin. Fixes: #13073 Closes: #13074 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13074 Pull-request-sha: ead18a04018db6d574a3bc4bd71f21c23256737c Change-Id: Ie1640c969aaa64e41da334fe0eff21e0d12a8bf0 (cherry picked from commit c36643fbb933c0defd00b9caa7a184c24e2a544b) --- diff --git a/doc/build/changelog/unreleased_20/13073.rst b/doc/build/changelog/unreleased_20/13073.rst new file mode 100644 index 0000000000..2716b8154c --- /dev/null +++ b/doc/build/changelog/unreleased_20/13073.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, mariadb + :tickets: 13070 + + Fixed the SQL compilation for the mariadb sequence "NOCYCLE" keyword that + is to be emitted when the :paramref:`.Sequence.cycle` parameter is set to + False on a :class:`.Sequence`. Pull request courtesy Diego Dupin. diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index be68e961ea..e52f76bd99 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1182,6 +1182,7 @@ if TYPE_CHECKING: from ...sql.functions import random from ...sql.functions import rollup from ...sql.functions import sysdate + from ...sql.schema import IdentityOptions from ...sql.schema import Sequence as Sequence_SchemaItem from ...sql.type_api import TypeEngine from ...sql.visitors import ExternallyTraversible @@ -2330,6 +2331,15 @@ class MySQLDDLCompiler(compiler.DDLCompiler): self.get_column_specification(create.element), ) + def get_identity_options(self, identity_options: IdentityOptions) -> str: + """mariadb-specific sequence option; this will move to a + mariadb-specific module in 2.1 + + """ + text = super().get_identity_options(identity_options) + text = text.replace("NO CYCLE", "NOCYCLE") + return text + class MySQLTypeCompiler(compiler.GenericTypeCompiler): def _extend_numeric(self, type_: _NumericType, spec: str) -> str: diff --git a/lib/sqlalchemy/dialects/mysql/mariadb.py b/lib/sqlalchemy/dialects/mysql/mariadb.py index 23c8ac21b4..70d9cb8094 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadb.py +++ b/lib/sqlalchemy/dialects/mysql/mariadb.py @@ -8,7 +8,6 @@ from __future__ import annotations from typing import Any -from typing import Callable from .base import MariaDBIdentifierPreparer from .base import MySQLDialect @@ -51,7 +50,7 @@ class MariaDBDialect(MySQLDialect): type_compiler_cls = MariaDBTypeCompiler -def loader(driver: str) -> Callable[[], type[MariaDBDialect]]: +def loader(driver: str) -> type[MariaDBDialect]: dialect_mod = __import__( "sqlalchemy.dialects.mysql.%s" % driver ).dialects.mysql diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 806fb45c0a..68d1b804a7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -113,6 +113,7 @@ if typing.TYPE_CHECKING: from .schema import Column from .schema import Constraint from .schema import ForeignKeyConstraint + from .schema import IdentityOptions from .schema import Index from .schema import PrimaryKeyConstraint from .schema import Table @@ -6983,7 +6984,7 @@ class DDLCompiler(Compiled): def visit_drop_constraint_comment(self, drop, **kw): raise exc.UnsupportedCompilationError(self, type(drop)) - def get_identity_options(self, identity_options): + def get_identity_options(self, identity_options: IdentityOptions) -> str: text = [] if identity_options.increment is not None: text.append("INCREMENT BY %d" % identity_options.increment) diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 4364872baf..00d1a52695 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -36,6 +36,7 @@ from sqlalchemy import NVARCHAR from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema from sqlalchemy import select +from sqlalchemy import Sequence from sqlalchemy import SmallInteger from sqlalchemy import sql from sqlalchemy import String @@ -61,6 +62,7 @@ from sqlalchemy.sql import column from sqlalchemy.sql import delete from sqlalchemy.sql import table from sqlalchemy.sql import update +from sqlalchemy.sql.ddl import CreateSequence from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import literal_column from sqlalchemy.testing import assert_raises_message @@ -1078,6 +1080,36 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): ")ENGINE=InnoDB", ) + @testing.combinations( + (Sequence("foo_seq"), "CREATE SEQUENCE foo_seq"), + (Sequence("foo_seq", cycle=True), "CREATE SEQUENCE foo_seq CYCLE"), + (Sequence("foo_seq", cycle=False), "CREATE SEQUENCE foo_seq NOCYCLE"), + ( + Sequence( + "foo_seq", + start=1, + increment=2, + nominvalue=True, + nomaxvalue=True, + cycle=False, + cache=100, + ), + ( + "CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 1 NO" + " MINVALUE NO MAXVALUE CACHE 100 NOCYCLE" + ), + ), + argnames="seq, expected", + ) + @testing.variation("use_mariadb", [True, False]) + def test_mariadb_sequence_behaviors(self, seq, expected, use_mariadb): + """test #13073""" + self.assert_compile( + CreateSequence(seq), + expected, + dialect="mariadb" if use_mariadb else "mysql", + ) + def test_create_table_with_partition(self): t1 = Table( "testtable", diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index 890c9edbf9..a4ef21f8b8 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -10,24 +10,31 @@ from sqlalchemy import exc from sqlalchemy import false from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import schema from sqlalchemy import select +from sqlalchemy import Sequence from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import true from sqlalchemy.dialects.mysql import TIMESTAMP +from sqlalchemy.sql.ddl import CreateSequence +from sqlalchemy.sql.ddl import DropSequence from sqlalchemy.testing import assert_raises from sqlalchemy.testing import combinations from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true class IdiosyncrasyTest(fixtures.TestBase): @@ -110,6 +117,68 @@ class ServerDefaultCreateTest(fixtures.TestBase): t.create(connection) +class MariaDBSequenceTest(fixtures.TestBase): + __only_on__ = "mariadb" + __backend__ = True + + __requires__ = ("sequences",) + + @testing.fixture + def create_seq(self, connection): + seqs = set() + + def go(seq): + seqs.add(seq) + connection.execute(CreateSequence(seq)) + + yield go + + for seq in seqs: + connection.execute(DropSequence(seq, if_exists=True)) + + def test_has_sequence_and_exists_flag(self, connection, create_seq): + seq = Sequence("has_seq_test") + is_false(inspect(connection).has_sequence("has_seq_test")) + + create_seq(seq) + is_true(inspect(connection).has_sequence("has_seq_test")) + + connection.execute(CreateSequence(seq, if_not_exists=True)) + + connection.execute(DropSequence(seq)) + is_false(inspect(connection).has_sequence("has_seq_test")) + connection.execute(DropSequence(seq, if_exists=True)) + + @testing.combinations( + (Sequence("foo_seq"), (1, 2, 3, 4, 5, 6, 7), False), + ( + Sequence("foo_seq", maxvalue=3, cycle=True), + (1, 2, 3, 1, 2, 3, 1), + False, + ), + (Sequence("foo_seq", maxvalue=3, cycle=False), (1, 2, 3), True), + argnames="seq, expected, runout", + ) + def test_sequence_roundtrip( + self, connection, create_seq, seq, expected, runout + ): + """tests related to #13073""" + + create_seq(seq) + + eq_( + [ + connection.scalar(seq.next_value()) + for i in range(len(expected)) + ], + list(expected), + ) + + if runout: + with expect_raises_message(exc.DBAPIError, ".*has run out"): + connection.scalar(seq.next_value()) + + class MatchTest(fixtures.TablesTest): __only_on__ = "mysql", "mariadb" __backend__ = True