]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added support for rendering ``SMALLSERIAL`` when a :class:`.SmallInteger`
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Oct 2013 23:06:21 +0000 (19:06 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Oct 2013 23:06:21 +0000 (19:06 -0400)
type is used on a primary key autoincrement column, based on server
version detection of Postgresql version 9.2 or greater.
[ticket:2840]

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_dialect.py

index 00b6e1fb3a0174a744619103746e51988185a7d2..4e7b4e6f45bb12fee6fcb781027f80b741e82a57 100644 (file)
 .. changelog::
     :version: 0.9.0
 
+    .. change::
+        :tags: feature, postgresql
+        :tickets: 2840
+
+        Added support for rendering ``SMALLSERIAL`` when a :class:`.SmallInteger`
+        type is used on a primary key autoincrement column, based on server
+        version detection of Postgresql version 9.2 or greater.
+
     .. change::
         :tags: feature, mysql
         :tickets: 2817
index 06ee8c3a23c82b2874aa72e314d7e43a016f2105..5efa2e983f26b423c6a0ab758828b049d7133d83 100644 (file)
@@ -1039,12 +1039,15 @@ class PGCompiler(compiler.SQLCompiler):
 
 class PGDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
+
         colspec = self.preparer.format_column(column)
         impl_type = column.type.dialect_impl(self.dialect)
         if column.primary_key and \
             column is column.table._autoincrement_column and \
-            not isinstance(impl_type, sqltypes.SmallInteger) and \
             (
+                self.dialect.supports_smallserial or
+                not isinstance(impl_type, sqltypes.SmallInteger)
+            ) and (
                 column.default is None or
                 (
                     isinstance(column.default, schema.Sequence) and
@@ -1052,6 +1055,8 @@ class PGDDLCompiler(compiler.DDLCompiler):
                 )):
             if isinstance(impl_type, sqltypes.BigInteger):
                 colspec += " BIGSERIAL"
+            elif isinstance(impl_type, sqltypes.SmallInteger):
+                colspec += " SMALLSERIAL"
             else:
                 colspec += " SERIAL"
         else:
@@ -1330,6 +1335,7 @@ class PGDialect(default.DefaultDialect):
 
     supports_native_enum = True
     supports_native_boolean = True
+    supports_smallserial = True
 
     supports_sequences = True
     sequences_optional = True
@@ -1370,6 +1376,10 @@ class PGDialect(default.DefaultDialect):
             # psycopg2, others may have placed ENUM here as well
             self.colspecs.pop(ENUM, None)
 
+        # http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
+        self.supports_smallserial = self.server_version_info >= (9, 2)
+
+
     def on_connect(self):
         if self.isolation_level is not None:
             def connect(conn):
index 3d48230f3404e4f95205b0faa9a05853fc340d65..aa11662a06ab86797f474bcef97d0c0c3a868678 100644 (file)
@@ -203,17 +203,30 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         assert_raises(exc.InvalidRequestError, testing.db.execute, stmt)
 
     def test_serial_integer(self):
-        for type_, expected in [
-            (Integer, 'SERIAL'),
-            (BigInteger, 'BIGSERIAL'),
-            (SmallInteger, 'SMALLINT'),
-            (postgresql.INTEGER, 'SERIAL'),
-            (postgresql.BIGINT, 'BIGSERIAL'),
+
+        for version, type_, expected in [
+            (None, Integer, 'SERIAL'),
+            (None, BigInteger, 'BIGSERIAL'),
+            ((9, 1), SmallInteger, 'SMALLINT'),
+            ((9, 2), SmallInteger, 'SMALLSERIAL'),
+            (None, postgresql.INTEGER, 'SERIAL'),
+            (None, postgresql.BIGINT, 'BIGSERIAL'),
         ]:
             m = MetaData()
 
             t = Table('t', m, Column('c', type_, primary_key=True))
-            ddl_compiler = testing.db.dialect.ddl_compiler(testing.db.dialect, schema.CreateTable(t))
+
+            if version:
+                dialect = postgresql.dialect()
+                dialect._get_server_version_info = Mock(return_value=version)
+                dialect.initialize(testing.db.connect())
+            else:
+                dialect = testing.db.dialect
+
+            ddl_compiler = dialect.ddl_compiler(
+                                dialect,
+                                schema.CreateTable(t)
+                            )
             eq_(
                 ddl_compiler.get_column_specification(t.c.c),
                 "c %s NOT NULL" % expected