]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure type lengths are int in oracle
authorFederico Caselli <cfederico87@gmail.com>
Sun, 26 Jun 2022 10:31:45 +0000 (12:31 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sun, 26 Jun 2022 14:35:12 +0000 (16:35 +0200)
Repair change introduced by the multi reflection that caused
char length of varchar like types or precisions in numberic
like types to be set as float.
This will fix the test errors in alembic that are
currently broken, as shown in
I9ad803df1d3ccf2a5111266b781061936717b8c8

Change-Id: Idd5975efaeadfe6327a1cd3b6667d82e836a2cb1

lib/sqlalchemy/dialects/oracle/base.py
test/dialect/oracle/test_reflection.py

index fee0988895cfca968a38483b83f2e905d074bfe6..6e40e4df21cd11c9a731011a488ea99a1f44b426 100644 (file)
@@ -2238,15 +2238,21 @@ class OracleDialect(default.DefaultDialect):
             all_objects=all_objects,
         )
 
+        def maybe_int(value):
+            if isinstance(value, float) and value.is_integer():
+                return int(value)
+            else:
+                return value
+
         for row_dict in result:
             table_name = self.normalize_name(row_dict["table_name"])
             orig_colname = row_dict["column_name"]
             colname = self.normalize_name(orig_colname)
             coltype = row_dict["data_type"]
-            precision = row_dict["data_precision"]
+            precision = maybe_int(row_dict["data_precision"])
 
             if coltype == "NUMBER":
-                scale = row_dict["data_scale"]
+                scale = maybe_int(row_dict["data_scale"])
                 if precision is None and scale == 0:
                     coltype = INTEGER()
                 else:
@@ -2266,9 +2272,8 @@ class OracleDialect(default.DefaultDialect):
                     coltype = FLOAT(binary_precision=precision)
 
             elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"):
-                coltype = self.ischema_names.get(coltype)(
-                    row_dict["char_length"]
-                )
+                char_length = maybe_int(row_dict["char_length"])
+                coltype = self.ischema_names.get(coltype)(char_length)
             elif "WITH TIME ZONE" in coltype:
                 coltype = TIMESTAMP(timezone=True)
             else:
index 53eb94df306c79d537632ecab63269bce5f7c267..901db9f4e800a1694cddc144635ae37e56230dd9 100644 (file)
@@ -1,6 +1,7 @@
 # coding: utf-8
 
 
+from sqlalchemy import CHAR
 from sqlalchemy import Double
 from sqlalchemy import exc
 from sqlalchemy import FLOAT
@@ -14,9 +15,11 @@ from sqlalchemy import inspect
 from sqlalchemy import INTEGER
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
+from sqlalchemy import NCHAR
 from sqlalchemy import Numeric
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import select
+from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import Unicode
@@ -27,7 +30,11 @@ from sqlalchemy.dialects.oracle.base import BINARY_FLOAT
 from sqlalchemy.dialects.oracle.base import DOUBLE_PRECISION
 from sqlalchemy.dialects.oracle.base import NUMBER
 from sqlalchemy.dialects.oracle.base import REAL
+from sqlalchemy.dialects.oracle.types import NVARCHAR2
+from sqlalchemy.dialects.oracle.types import VARCHAR2
 from sqlalchemy.engine import ObjectKind
+from sqlalchemy.sql.sqltypes import NVARCHAR
+from sqlalchemy.sql.sqltypes import VARCHAR
 from sqlalchemy.testing import assert_warns
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import config
@@ -1006,20 +1013,22 @@ class TypeReflectionTest(fixtures.TestBase):
         for i, (reflected_col, spec) in enumerate(zip(table.c, specs)):
             expected_spec = spec[1]
             reflected_type = reflected_col.type
-            is_(type(reflected_type), type(expected_spec))
+            is_(type(reflected_type), type(expected_spec), spec[0])
             for attr in attributes:
+                r_attr = getattr(reflected_type, attr)
+                e_attr = getattr(expected_spec, attr)
+                col = f"c{i+1}"
                 eq_(
-                    getattr(reflected_type, attr),
-                    getattr(expected_spec, attr),
-                    "Column %s: Attribute %s value of %s does not "
-                    "match %s for type %s"
-                    % (
-                        "c%i" % (i + 1),
-                        attr,
-                        getattr(reflected_type, attr),
-                        getattr(expected_spec, attr),
-                        spec[0],
-                    ),
+                    r_attr,
+                    e_attr,
+                    f"Column {col}: Attribute {attr} value of {r_attr} "
+                    f"does not match {e_attr} for type {spec[0]}",
+                )
+                eq_(
+                    type(r_attr),
+                    type(e_attr),
+                    f"Column {col}: Attribute {attr} type do not match "
+                    f"{type(r_attr)} != {type(e_attr)} for db type {spec[0]}",
                 )
 
     def test_integer_types(self, metadata, connection):
@@ -1061,6 +1070,21 @@ class TypeReflectionTest(fixtures.TestBase):
         ]
         self._run_test(metadata, connection, specs, ["precision"])
 
+    def test_string_types(
+        self,
+        metadata,
+        connection,
+    ):
+        specs = [
+            (String(125), VARCHAR(125)),
+            (String(42).with_variant(VARCHAR2(42), "oracle"), VARCHAR(42)),
+            (Unicode(125), VARCHAR(125)),
+            (Unicode(42).with_variant(NVARCHAR2(42), "oracle"), NVARCHAR(42)),
+            (CHAR(125), CHAR(125)),
+            (NCHAR(42), NCHAR(42)),
+        ]
+        self._run_test(metadata, connection, specs, ["length"])
+
 
 class IdentityReflectionTest(fixtures.TablesTest):
     __only_on__ = "oracle"