from ...sql.functions import random
from ...sql.functions import rollup
from ...sql.functions import sysdate
+ from ...sql.schema import IdentityOptions
from ...sql.schema import Sequence as Sequence_SchemaItem
from ...sql.type_api import TypeEngine
from ...sql.visitors import ExternallyTraversible
self.get_column_specification(create.element),
)
+ def get_identity_options(self, identity_options: IdentityOptions) -> str:
+ """mariadb-specific sequence option; this will move to a
+ mariadb-specific module in 2.1
+
+ """
+ text = super().get_identity_options(identity_options)
+ text = text.replace("NO CYCLE", "NOCYCLE")
+ return text
+
class MySQLTypeCompiler(compiler.GenericTypeCompiler):
def _extend_numeric(self, type_: _NumericCommonType, spec: str) -> str:
from .schema import Column
from .schema import Constraint
from .schema import ForeignKeyConstraint
+ from .schema import IdentityOptions
from .schema import Index
from .schema import PrimaryKeyConstraint
from .schema import Table
def visit_drop_constraint_comment(self, drop, **kw):
raise exc.UnsupportedCompilationError(self, type(drop))
- def get_identity_options(self, identity_options):
+ def get_identity_options(self, identity_options: IdentityOptions) -> str:
text = []
if identity_options.increment is not None:
text.append("INCREMENT BY %d" % identity_options.increment)
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema
from sqlalchemy import select
+from sqlalchemy import Sequence
from sqlalchemy import SmallInteger
from sqlalchemy import sql
from sqlalchemy import String
from sqlalchemy.sql import delete
from sqlalchemy.sql import table
from sqlalchemy.sql import update
+from sqlalchemy.sql.ddl import CreateSequence
from sqlalchemy.sql.expression import bindparam
from sqlalchemy.sql.expression import literal_column
from sqlalchemy.testing import assert_raises_message
")ENGINE=InnoDB",
)
+ @testing.combinations(
+ (Sequence("foo_seq"), "CREATE SEQUENCE foo_seq"),
+ (Sequence("foo_seq", cycle=True), "CREATE SEQUENCE foo_seq CYCLE"),
+ (Sequence("foo_seq", cycle=False), "CREATE SEQUENCE foo_seq NOCYCLE"),
+ (
+ Sequence(
+ "foo_seq",
+ start=1,
+ increment=2,
+ nominvalue=True,
+ nomaxvalue=True,
+ cycle=False,
+ cache=100,
+ ),
+ (
+ "CREATE SEQUENCE foo_seq INCREMENT BY 2 START WITH 1 NO"
+ " MINVALUE NO MAXVALUE CACHE 100 NOCYCLE"
+ ),
+ ),
+ argnames="seq, expected",
+ )
+ @testing.variation("use_mariadb", [True, False])
+ def test_mariadb_sequence_behaviors(self, seq, expected, use_mariadb):
+ """test #13073"""
+ self.assert_compile(
+ CreateSequence(seq),
+ expected,
+ dialect="mariadb" if use_mariadb else "mysql",
+ )
+
def test_create_table_with_partition(self):
t1 = Table(
"testtable",
from sqlalchemy import false
from sqlalchemy import ForeignKey
from sqlalchemy import func
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy import schema
from sqlalchemy import select
+from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.dialects.mysql import limit
from sqlalchemy.dialects.mysql import TIMESTAMP
+from sqlalchemy.sql.ddl import CreateSequence
+from sqlalchemy.sql.ddl import DropSequence
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import combinations
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.fixtures import fixture_session
t.create(connection)
+class MariaDBSequenceTest(fixtures.TestBase):
+ __only_on__ = "mariadb"
+ __backend__ = True
+
+ __requires__ = ("sequences",)
+
+ @testing.fixture
+ def create_seq(self, connection):
+ seqs = set()
+
+ def go(seq):
+ seqs.add(seq)
+ connection.execute(CreateSequence(seq))
+
+ yield go
+
+ for seq in seqs:
+ connection.execute(DropSequence(seq, if_exists=True))
+
+ def test_has_sequence_and_exists_flag(self, connection, create_seq):
+ seq = Sequence("has_seq_test")
+ is_false(inspect(connection).has_sequence("has_seq_test"))
+
+ create_seq(seq)
+ is_true(inspect(connection).has_sequence("has_seq_test"))
+
+ connection.execute(CreateSequence(seq, if_not_exists=True))
+
+ connection.execute(DropSequence(seq))
+ is_false(inspect(connection).has_sequence("has_seq_test"))
+ connection.execute(DropSequence(seq, if_exists=True))
+
+ @testing.combinations(
+ (Sequence("foo_seq"), (1, 2, 3, 4, 5, 6, 7), False),
+ (
+ Sequence("foo_seq", maxvalue=3, cycle=True),
+ (1, 2, 3, 1, 2, 3, 1),
+ False,
+ ),
+ (Sequence("foo_seq", maxvalue=3, cycle=False), (1, 2, 3), True),
+ argnames="seq, expected, runout",
+ )
+ def test_sequence_roundtrip(
+ self, connection, create_seq, seq, expected, runout
+ ):
+ """tests related to #13073"""
+
+ create_seq(seq)
+
+ eq_(
+ [
+ connection.scalar(seq.next_value())
+ for i in range(len(expected))
+ ],
+ list(expected),
+ )
+
+ if runout:
+ with expect_raises_message(exc.DBAPIError, ".*has run out"):
+ connection.scalar(seq.next_value())
+
+
class MatchTest(fixtures.TablesTest):
__only_on__ = "mysql", "mariadb"
__backend__ = True