]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Repaired support for Postgresql UUID types in conjunction with
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Feb 2015 00:00:07 +0000 (19:00 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Feb 2015 00:03:08 +0000 (19:03 -0500)
the ARRAY type when using psycopg2.  The psycopg2 dialect now
employs use of the psycopg2.extras.register_uuid() hook
so that UUID values are always passed to/from the DBAPI as
UUID() objects.   The :paramref:`.UUID.as_uuid` flag is still
honored, except with psycopg2 we need to convert returned
UUID objects back into strings when this is disabled.
fixes #2940

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/dialects/postgresql/psycopg2.py
test/dialect/postgresql/test_types.py

index 2af1cd35f0fa7c81f4f24f8be72d2dac0591d485..10d003f09bd3653304246f3e272a1eda7c901c48 100644 (file)
 .. changelog::
     :version: 0.9.9
 
+    .. change::
+        :tags: bug, postgresql
+        :tickets: 2940
+
+        Repaired support for Postgresql UUID types in conjunction with
+        the ARRAY type when using psycopg2.  The psycopg2 dialect now
+        employs use of the psycopg2.extras.register_uuid() hook
+        so that UUID values are always passed to/from the DBAPI as
+        UUID() objects.   The :paramref:`.UUID.as_uuid` flag is still
+        honored, except with psycopg2 we need to convert returned
+        UUID objects back into strings when this is disabled.
+
     .. change::
         :tags: bug, postgresql
         :pullreq: github:145
index 26e45fed2eb27d88f350dfb323ed154cbc925b9c..4f1e04f20f610d85e048f69f5d573207522f332a 100644 (file)
@@ -312,10 +312,15 @@ from ... import types as sqltypes
 from .base import PGDialect, PGCompiler, \
     PGIdentifierPreparer, PGExecutionContext, \
     ENUM, ARRAY, _DECIMAL_TYPES, _FLOAT_TYPES,\
-    _INT_TYPES
+    _INT_TYPES, UUID
 from .hstore import HSTORE
 from .json import JSON, JSONB
 
+try:
+    from uuid import UUID as _python_UUID
+except ImportError:
+    _python_UUID = None
+
 
 logger = logging.getLogger('sqlalchemy.dialects.postgresql')
 
@@ -388,6 +393,26 @@ class _PGJSONB(JSONB):
         else:
             return super(_PGJSONB, self).result_processor(dialect, coltype)
 
+
+class _PGUUID(UUID):
+    def bind_processor(self, dialect):
+        if not self.as_uuid and dialect.use_native_uuid:
+            nonetype = type(None)
+
+            def process(value):
+                if value is not None:
+                    value = _python_UUID(value)
+                return value
+            return process
+
+    def result_processor(self, dialect, coltype):
+        if not self.as_uuid and dialect.use_native_uuid:
+            def process(value):
+                if value is not None:
+                    value = str(value)
+                return value
+            return process
+
 # When we're handed literal SQL, ensure it's a SELECT query. Since
 # 8.3, combining cursors and "FOR UPDATE" has been fine.
 SERVER_SIDE_CURSOR_RE = re.compile(
@@ -488,18 +513,20 @@ class PGDialect_psycopg2(PGDialect):
             sqltypes.Enum: _PGEnum,  # needs force_unicode
             HSTORE: _PGHStore,
             JSON: _PGJSON,
-            JSONB: _PGJSONB
+            JSONB: _PGJSONB,
+            UUID: _PGUUID
         }
     )
 
     def __init__(self, server_side_cursors=False, use_native_unicode=True,
                  client_encoding=None,
-                 use_native_hstore=True,
+                 use_native_hstore=True, use_native_uuid=True,
                  **kwargs):
         PGDialect.__init__(self, **kwargs)
         self.server_side_cursors = server_side_cursors
         self.use_native_unicode = use_native_unicode
         self.use_native_hstore = use_native_hstore
+        self.use_native_uuid = use_native_uuid
         self.supports_unicode_binds = use_native_unicode
         self.client_encoding = client_encoding
         if self.dbapi and hasattr(self.dbapi, '__version__'):
@@ -575,6 +602,11 @@ class PGDialect_psycopg2(PGDialect):
                 self.set_isolation_level(conn, self.isolation_level)
             fns.append(on_connect)
 
+        if self.dbapi and self.use_native_uuid:
+            def on_connect(conn):
+                extras.register_uuid(None, conn)
+            fns.append(on_connect)
+
         if self.dbapi and self.use_native_unicode:
             def on_connect(conn):
                 extensions.register_type(extensions.UNICODE, conn)
index 866bc7d545f9a4ec20e4d74b9c3ab7836d589e3a..36f4fdc3fef9379aa3d8d103c7eb24360486fa2c 100644 (file)
@@ -1035,7 +1035,7 @@ class UUIDTest(fixtures.TestBase):
         import uuid
         self._test_round_trip(
             Table('utable', MetaData(),
-                  Column('data', postgresql.UUID())
+                  Column('data', postgresql.UUID(as_uuid=False))
                   ),
             str(uuid.uuid4()),
             str(uuid.uuid4())
@@ -1057,13 +1057,32 @@ class UUIDTest(fixtures.TestBase):
         )
 
     @testing.fails_on('postgresql+zxjdbc',
-                      'column "data" is of type uuid[] but expression is of type character varying')
+                      'column "data" is of type uuid[] but '
+                      'expression is of type character varying')
     @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
     def test_uuid_array(self):
         import uuid
         self._test_round_trip(
-            Table('utable', MetaData(),
-                Column('data', postgresql.ARRAY(postgresql.UUID()))
+            Table(
+                'utable', MetaData(),
+                Column('data', postgresql.ARRAY(postgresql.UUID(as_uuid=True)))
+            ),
+            [uuid.uuid4(), uuid.uuid4()],
+            [uuid.uuid4(), uuid.uuid4()],
+        )
+
+    @testing.fails_on('postgresql+zxjdbc',
+                      'column "data" is of type uuid[] but '
+                      'expression is of type character varying')
+    @testing.fails_on('postgresql+pg8000', 'No support for UUID type')
+    def test_uuid_string_array(self):
+        import uuid
+        self._test_round_trip(
+            Table(
+                'utable', MetaData(),
+                Column(
+                    'data',
+                    postgresql.ARRAY(postgresql.UUID(as_uuid=False)))
             ),
             [str(uuid.uuid4()), str(uuid.uuid4())],
             [str(uuid.uuid4()), str(uuid.uuid4())],
@@ -1088,7 +1107,7 @@ class UUIDTest(fixtures.TestBase):
     def teardown(self):
         self.conn.close()
 
-    def _test_round_trip(self, utable, value1, value2):
+    def _test_round_trip(self, utable, value1, value2, exp_value2=None):
         utable.create(self.conn)
         self.conn.execute(utable.insert(), {'data': value1})
         self.conn.execute(utable.insert(), {'data': value2})
@@ -1096,7 +1115,10 @@ class UUIDTest(fixtures.TestBase):
             select([utable.c.data]).
             where(utable.c.data != value1)
         )
-        eq_(r.fetchone()[0], value2)
+        if exp_value2:
+            eq_(r.fetchone()[0], exp_value2)
+        else:
+            eq_(r.fetchone()[0], value2)
         eq_(r.fetchone(), None)