]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Handle PostgreSQL enums in remote schemas
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2018 16:04:14 +0000 (11:04 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2018 18:53:16 +0000 (13:53 -0500)
Fixed issue where a :class:`.postgresql.ENUM` or a custom domain present
in a remote schema would not be recognized within column reflection if
the name of the enum/domain or the name of the schema required quoting.
A new parsing scheme now fully parses out quoted or non-quoted tokens
including support for SQL-escaped quotes.

Fixed issue where multiple :class:`.postgresql.ENUM` objects referred to
by the same :class:`.MetaData` object would fail to be created if
multiple objects had the same name under different schema names.  The
internal memoization the Postgresql dialect uses to track if it has
created a particular :class:`.postgresql.ENUM` in the database during
a DDL creation sequence now takes schema name into account.

Fixes: #4416
Change-Id: I8cf03069e10b12f409e9b6796e24fc5850979955

doc/build/changelog/unreleased_12/4416.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_utils.py
test/dialect/postgresql/test_reflection.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_12/4416.rst b/doc/build/changelog/unreleased_12/4416.rst
new file mode 100644 (file)
index 0000000..3eac7f2
--- /dev/null
@@ -0,0 +1,19 @@
+.. change::
+   :tags: bug, postgresql
+   :tickets: 4416
+
+   Fixed issue where a :class:`.postgresql.ENUM` or a custom domain present
+   in a remote schema would not be recognized within column reflection if
+   the name of the enum/domain or the name of the schema required quoting.
+   A new parsing scheme now fully parses out quoted or non-quoted tokens
+   including support for SQL-escaped quotes.
+
+.. change::
+   :tags: bug, postgresql
+
+   Fixed issue where multiple :class:`.postgresql.ENUM` objects referred to
+   by the same :class:`.MetaData` object would fail to be created if
+   multiple objects had the same name under different schema names.  The
+   internal memoization the Postgresql dialect uses to track if it has
+   created a particular :class:`.postgresql.ENUM` in the database during
+   a DDL creation sequence now takes schema name into account.
index ce809db9f7f656330b530dc6124b70f363e9c579..d68ab8ef58c977fc20086eff937db63af6b87fff 100644 (file)
@@ -1344,8 +1344,8 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
                 pg_enums = ddl_runner.memo['_pg_enums']
             else:
                 pg_enums = ddl_runner.memo['_pg_enums'] = set()
-            present = self.name in pg_enums
-            pg_enums.add(self.name)
+            present = (self.schema, self.name) in pg_enums
+            pg_enums.add((self.schema, self.name))
             return present
         else:
             return False
@@ -2580,20 +2580,26 @@ class PGDialect(default.DefaultDialect):
                      )
         c = connection.execute(s, table_oid=table_oid)
         rows = c.fetchall()
+
+        # dictionary with (name, ) if default search path or (schema, name)
+        # as keys
         domains = self._load_domains(connection)
+
+        # dictionary with (name, ) if default search path or (schema, name)
+        # as keys
         enums = dict(
-            (
-                "%s.%s" % (rec['schema'], rec['name'])
-                if not rec['visible'] else rec['name'], rec) for rec in
-            self._load_enums(connection, schema='*')
+            ((rec['name'], ), rec)
+            if rec['visible'] else ((rec['schema'], rec['name']), rec)
+            for rec in self._load_enums(connection, schema='*')
         )
 
         # format columns
         columns = []
-        for name, format_type, default, notnull, attnum, table_oid, \
+
+        for name, format_type, default_, notnull, attnum, table_oid, \
                 comment in rows:
             column_info = self._get_column_info(
-                name, format_type, default, notnull, domains, enums,
+                name, format_type, default_, notnull, domains, enums,
                 schema, comment)
             columns.append(column_info)
         return columns
@@ -2602,7 +2608,8 @@ class PGDialect(default.DefaultDialect):
                          notnull, domains, enums, schema, comment):
         def _handle_array_type(attype):
             return (
-                attype.replace('[]', ''), # strip '[]' from integer[], etc.
+                # strip '[]' from integer[], etc.
+                re.sub(r'\[\]$', '', attype),
                 attype.endswith('[]'),
             )
 
@@ -2610,12 +2617,12 @@ class PGDialect(default.DefaultDialect):
         # with time zone, geometry(POLYGON), etc.
         attype = re.sub(r'\(.*\)', '', format_type)
 
-        # strip quotes from case sensitive enum names
-        attype = re.sub(r'^"|"$', '', attype)
-
         # strip '[]' from integer[], etc. and check if an array
         attype, is_array = _handle_array_type(attype)
 
+        # strip quotes from case sensitive enum or domain names
+        enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
         nullable = not notnull
 
         charlen = re.search(r'\(([\d,]+)\)', format_type)
@@ -2668,21 +2675,24 @@ class PGDialect(default.DefaultDialect):
             args = (int(charlen),)
 
         while True:
+            # looping here to suit nested domains
             if attype in self.ischema_names:
                 coltype = self.ischema_names[attype]
                 break
-            elif attype in enums:
-                enum = enums[attype]
+            elif enum_or_domain_key in enums:
+                enum = enums[enum_or_domain_key]
                 coltype = ENUM
                 kwargs['name'] = enum['name']
                 if not enum['visible']:
                     kwargs['schema'] = enum['schema']
                 args = tuple(enum['labels'])
                 break
-            elif attype in domains:
-                domain = domains[attype]
+            elif enum_or_domain_key in domains:
+                domain = domains[enum_or_domain_key]
                 attype = domain['attype']
                 attype, is_array = _handle_array_type(attype)
+                # strip quotes from case sensitive enum or domain names
+                enum_or_domain_key = tuple(util.quoted_token_parser(attype))
                 # A table can't override whether the domain is nullable.
                 nullable = domain['nullable']
                 if domain['default'] and not default:
@@ -3166,16 +3176,16 @@ class PGDialect(default.DefaultDialect):
         for domain in c.fetchall():
             # strip (30) from character varying(30)
             attype = re.search(r'([^\(]+)', domain['attype']).group(1)
+            # 'visible' just means whether or not the domain is in a
+            # schema that's on the search path -- or not overridden by
+            # a schema with higher precedence. If it's not visible,
+            # it will be prefixed with the schema-name when it's used.
             if domain['visible']:
-                # 'visible' just means whether or not the domain is in a
-                # schema that's on the search path -- or not overridden by
-                # a schema with higher precedence. If it's not visible,
-                # it will be prefixed with the schema-name when it's used.
-                name = domain['name']
+                key = (domain['name'], )
             else:
-                name = "%s.%s" % (domain['schema'], domain['name'])
+                key = (domain['schema'], domain['name'])
 
-            domains[name] = {
+            domains[key] = {
                 'attype': attype,
                 'nullable': domain['nullable'],
                 'default': domain['default']
index 031376d7809241dc68e4c0f5015bf831384ac0fa..9229d079715a33481ab6618fb4a1779de611b1f3 100644 (file)
@@ -34,7 +34,7 @@ from .langhelpers import iterate_attributes, class_hierarchy, \
     classproperty, set_creation_order, warn_exception, warn, NoneType,\
     constructor_copy, methods_equivalent, chop_traceback, asint,\
     generic_repr, counter, PluginLoader, hybridproperty, hybridmethod, \
-    safe_reraise,\
+    safe_reraise, quoted_token_parser,\
     get_callable_argspec, only_once, attrsetter, ellipses_string, \
     warn_limited, map_bits, MemoizedSlots, EnsureKWArgType, wrap_callable
 
index 8815ed8378a206a8e9264b1216a1c748da4813ca..6a1db7a9886150b77068c2dde04a2b8d5100c4c1 100644 (file)
@@ -1422,3 +1422,48 @@ def wrap_callable(wrapper, fn):
             _f.__doc__ = fn.__doc__
 
         return _f
+
+
+def quoted_token_parser(value):
+    """Parse a dotted identifier with accomodation for quoted names.
+
+    Includes support for SQL-style double quotes as a literal character.
+
+    E.g.::
+
+        >>> quoted_token_parser("name")
+        ["name"]
+        >>> quoted_token_parser("schema.name")
+        ["schema", "name"]
+        >>> quoted_token_parser('"Schema"."Name"')
+        ['Schema', 'Name']
+        >>> quoted_token_parser('"Schema"."Name""Foo"')
+        ['Schema', 'Name""Foo']
+
+    """
+
+    if '"' not in value:
+        return value.split(".")
+
+    # 0 = outside of quotes
+    # 1 = inside of quotes
+    state = 0
+    result = [[]]
+    idx = 0
+    lv = len(value)
+    while idx < lv:
+        char = value[idx]
+        if char == '"':
+            if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
+                result[-1].append('"')
+                idx += 1
+            else:
+                state ^= 1
+        elif char == "." and state == 0:
+            result.append([])
+        else:
+            result[-1].append(char)
+        idx += 1
+
+    return ["".join(token) for token in result]
+
index 4f462c052431b12431660dcab76be5da07265e9d..bf65d4fc97e3a38c193fb38fb7b0f95f2b99fd85 100644 (file)
@@ -2373,3 +2373,122 @@ class TestProperties(fixtures.TestBase):
 
             eq_(props._data, p._data)
             eq_(props.keys(), p.keys())
+
+
+class QuotedTokenParserTest(fixtures.TestBase):
+    def _test(self, string, expected):
+        eq_(
+            langhelpers.quoted_token_parser(string),
+            expected
+        )
+
+    def test_single(self):
+        self._test(
+            "name",
+            ["name"]
+        )
+
+    def test_dotted(self):
+        self._test(
+            "schema.name", ["schema", "name"]
+        )
+
+    def test_dotted_quoted_left(self):
+        self._test(
+            '"Schema".name', ["Schema", "name"]
+        )
+
+    def test_dotted_quoted_left_w_quote_left_edge(self):
+        self._test(
+            '"""Schema".name', ['"Schema', "name"]
+        )
+
+    def test_dotted_quoted_left_w_quote_right_edge(self):
+        self._test(
+            '"Schema""".name', ['Schema"', "name"]
+        )
+
+    def test_dotted_quoted_left_w_quote_middle(self):
+        self._test(
+            '"Sch""ema".name', ['Sch"ema', "name"]
+        )
+
+    def test_dotted_quoted_right(self):
+        self._test(
+            'schema."SomeName"', ["schema", "SomeName"]
+        )
+
+    def test_dotted_quoted_right_w_quote_left_edge(self):
+        self._test(
+            'schema."""name"', ['schema', '"name']
+        )
+
+    def test_dotted_quoted_right_w_quote_right_edge(self):
+        self._test(
+            'schema."name"""', ['schema', 'name"']
+        )
+
+    def test_dotted_quoted_right_w_quote_middle(self):
+        self._test(
+            'schema."na""me"', ['schema', 'na"me']
+        )
+
+    def test_quoted_single_w_quote_left_edge(self):
+        self._test(
+            '"""name"', ['"name']
+        )
+
+    def test_quoted_single_w_quote_right_edge(self):
+        self._test(
+            '"name"""', ['name"']
+        )
+
+    def test_quoted_single_w_quote_middle(self):
+        self._test(
+            '"na""me"', ['na"me']
+        )
+
+    def test_dotted_quoted_left_w_dot_left_edge(self):
+        self._test(
+            '".Schema".name', ['.Schema', "name"]
+        )
+
+    def test_dotted_quoted_left_w_dot_right_edge(self):
+        self._test(
+            '"Schema.".name', ['Schema.', "name"]
+        )
+
+    def test_dotted_quoted_left_w_dot_middle(self):
+        self._test(
+            '"Sch.ema".name', ['Sch.ema', "name"]
+        )
+
+    def test_dotted_quoted_right_w_dot_left_edge(self):
+        self._test(
+            'schema.".name"', ['schema', '.name']
+        )
+
+    def test_dotted_quoted_right_w_dot_right_edge(self):
+        self._test(
+            'schema."name."', ['schema', 'name.']
+        )
+
+    def test_dotted_quoted_right_w_dot_middle(self):
+        self._test(
+            'schema."na.me"', ['schema', 'na.me']
+        )
+
+    def test_quoted_single_w_dot_left_edge(self):
+        self._test(
+            '".name"', ['.name']
+        )
+
+    def test_quoted_single_w_dot_right_edge(self):
+        self._test(
+            '"name."', ['name.']
+        )
+
+    def test_quoted_single_w_dot_middle(self):
+        self._test(
+            '"na.me"', ['na.me']
+        )
index 2a9887e0e15475f04d7257a8775959908d434ac8..5c4214430f8ab6ed30371de0b4a8abce287a54bd 100644 (file)
@@ -16,6 +16,8 @@ from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.dialects.postgresql import ARRAY, INTERVAL, INTEGER, TSRANGE
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 import re
+from operator import itemgetter
+import itertools
 
 
 class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
@@ -217,12 +219,15 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
     @classmethod
     def setup_class(cls):
         con = testing.db.connect()
-        for ddl in \
-                'CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42', \
-                'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0', \
-                "CREATE TYPE testtype AS ENUM ('test')", \
-                'CREATE DOMAIN enumdomain AS testtype', \
-                'CREATE DOMAIN arraydomain AS INTEGER[]':
+        for ddl in [
+            'CREATE SCHEMA "SomeSchema"',
+            'CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42',
+            'CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0',
+            "CREATE TYPE testtype AS ENUM ('test')",
+            'CREATE DOMAIN enumdomain AS testtype',
+            'CREATE DOMAIN arraydomain AS INTEGER[]',
+            'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0'
+        ]:
             try:
                 con.execute(ddl)
             except exc.DBAPIError as e:
@@ -240,12 +245,17 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
 
         con.execute('CREATE TABLE array_test (id integer, data arraydomain)')
 
+        con.execute(
+            'CREATE TABLE quote_test '
+            '(id integer, data "SomeSchema"."Quoted.Domain")')
+
     @classmethod
     def teardown_class(cls):
         con = testing.db.connect()
         con.execute('DROP TABLE testtable')
         con.execute('DROP TABLE test_schema.testtable')
         con.execute('DROP TABLE crosschema')
+        con.execute('DROP TABLE quote_test')
         con.execute('DROP DOMAIN testdomain')
         con.execute('DROP DOMAIN test_schema.testdomain')
         con.execute("DROP TABLE enum_test")
@@ -253,6 +263,8 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
         con.execute("DROP TYPE testtype")
         con.execute('DROP TABLE array_test')
         con.execute('DROP DOMAIN arraydomain')
+        con.execute('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
+        con.execute('DROP SCHEMA "SomeSchema"')
 
     def test_table_is_reflected(self):
         metadata = MetaData(testing.db)
@@ -289,6 +301,14 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             INTEGER
         )
 
+    def test_quoted_remote_schema_domain_is_reflected(self):
+        metadata = MetaData(testing.db)
+        table = Table('quote_test', metadata, autoload=True)
+        eq_(
+            table.c.data.type.__class__,
+            INTEGER
+        )
+
     def test_table_is_reflected_test_schema(self):
         metadata = MetaData(testing.db)
         table = Table('testtable', metadata, autoload=True,
@@ -972,38 +992,85 @@ class ReflectionTest(fixtures.TestBase):
 
     @testing.provide_metadata
     def test_inspect_enums_case_sensitive(self):
-        enum_type = postgresql.ENUM(
-            'CapsOne', 'CapsTwo', name='UpperCase', metadata=self.metadata)
-        enum_type.create(testing.db)
-        inspector = reflection.Inspector.from_engine(testing.db)
-        eq_(inspector.get_enums(), [
-            {
-                'visible': True,
-                'labels': ['CapsOne', 'CapsTwo'],
-                'name': 'UpperCase',
-                'schema': 'public'
-            }])
+        sa.event.listen(
+            self.metadata, "before_create",
+            sa.DDL('create schema "TestSchema"'))
+        sa.event.listen(
+            self.metadata, "after_drop",
+            sa.DDL('drop schema "TestSchema" cascade'))
+
+        for enum in 'lower_case', 'UpperCase', 'Name.With.Dot':
+            for schema in None, 'test_schema', 'TestSchema':
+
+                postgresql.ENUM(
+                    'CapsOne', 'CapsTwo', name=enum,
+                    schema=schema, metadata=self.metadata)
+
+        self.metadata.create_all(testing.db)
+        inspector = inspect(testing.db)
+        for schema in None, 'test_schema', 'TestSchema':
+            eq_(sorted(
+                inspector.get_enums(schema=schema),
+                key=itemgetter("name")), [
+                {
+                    'visible': schema is None,
+                    'labels': ['CapsOne', 'CapsTwo'],
+                    'name': "Name.With.Dot",
+                    'schema': 'public' if schema is None else schema
+                },
+                {
+                    'visible': schema is None,
+                    'labels': ['CapsOne', 'CapsTwo'],
+                    'name': "UpperCase",
+                    'schema': 'public' if schema is None else schema
+                },
+                {
+                    'visible': schema is None,
+                    'labels': ['CapsOne', 'CapsTwo'],
+                    'name': "lower_case",
+                    'schema': 'public' if schema is None else schema
+                }
+            ])
 
     @testing.provide_metadata
     def test_inspect_enums_case_sensitive_from_table(self):
-        enum_type = postgresql.ENUM(
-            'CapsOne', 'CapsTwo', name='UpperCase', metadata=self.metadata)
+        sa.event.listen(
+            self.metadata, "before_create",
+            sa.DDL('create schema "TestSchema"'))
+        sa.event.listen(
+            self.metadata, "after_drop",
+            sa.DDL('drop schema "TestSchema" cascade'))
 
-        t = Table('t', self.metadata, Column('q', enum_type))
+        counter = itertools.count()
+        for enum in 'lower_case', 'UpperCase', 'Name.With.Dot':
+            for schema in None, 'test_schema', 'TestSchema':
 
-        enum_type.create(testing.db)
-        t.create(testing.db)
+                    enum_type = postgresql.ENUM(
+                        'CapsOne', 'CapsTwo', name=enum,
+                        metadata=self.metadata, schema=schema)
 
-        inspector = reflection.Inspector.from_engine(testing.db)
-        cols = inspector.get_columns("t")
-        cols[0]['type'] = (cols[0]['type'].name, cols[0]['type'].enums)
-        eq_(cols, [
-            {
-                'name': 'q',
-                'type': ('UpperCase', ['CapsOne', 'CapsTwo']),
-                'nullable': True, 'default': None,
-                'autoincrement': False, 'comment': None}
-        ])
+                    Table(
+                        't%d' % next(counter),
+                        self.metadata, Column('q', enum_type))
+
+        self.metadata.create_all(testing.db)
+
+        inspector = inspect(testing.db)
+        counter = itertools.count()
+        for enum in 'lower_case', 'UpperCase', 'Name.With.Dot':
+            for schema in None, 'test_schema', 'TestSchema':
+                cols = inspector.get_columns("t%d" % next(counter))
+                cols[0]['type'] = (
+                    cols[0]['type'].schema,
+                    cols[0]['type'].name, cols[0]['type'].enums)
+                eq_(cols, [
+                    {
+                        'name': 'q',
+                        'type': (
+                            schema, enum, ['CapsOne', 'CapsTwo']),
+                        'nullable': True, 'default': None,
+                        'autoincrement': False, 'comment': None}
+                ])
 
     @testing.provide_metadata
     def test_inspect_enums_star(self):
index f5108920db4208baf8b9de9f13b7e0fd3fb1807e..2ea7d3024f165e039274acd66590c7a1d50725fa 100644 (file)
@@ -350,6 +350,27 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         t1.create()  # does not create ENUM
         t2.create()  # does not create ENUM
 
+    @testing.provide_metadata
+    def test_generate_multiple_schemaname_on_metadata(self):
+        metadata = self.metadata
+
+        Enum('one', 'two', 'three', name="myenum", metadata=metadata)
+        Enum('one', 'two', 'three', name="myenum", metadata=metadata,
+             schema="test_schema")
+
+        metadata.create_all(checkfirst=False)
+        assert 'myenum' in [
+            e['name'] for e in inspect(testing.db).get_enums()]
+        assert 'myenum' in [
+            e['name'] for
+            e in inspect(testing.db).get_enums(schema="test_schema")]
+        metadata.drop_all(checkfirst=False)
+        assert 'myenum' not in [
+            e['name'] for e in inspect(testing.db).get_enums()]
+        assert 'myenum' not in [
+            e['name'] for
+            e in inspect(testing.db).get_enums(schema="test_schema")]
+
     @testing.provide_metadata
     def test_drops_on_table(self):
         metadata = self.metadata