]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- an approach I like better, remove most adapt() methods and use a generic
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 15 Dec 2010 17:44:37 +0000 (12:44 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 15 Dec 2010 17:44:37 +0000 (12:44 -0500)
copier
- mssql reflection fix, but this will come in again from the tip merge

lib/sqlalchemy/dialects/access/base.py
lib/sqlalchemy/dialects/maxdb/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/langhelpers.py

index 75ea912874f0781118698fbe26697ca659337317..cf35b3e0a35282c43f3ae5cdb84fab54ce22614d 100644 (file)
@@ -50,15 +50,10 @@ class AcSmallInteger(types.SmallInteger):
         return "SMALLINT"
 
 class AcDateTime(types.DateTime):
-    def __init__(self, *a, **kw):
-        super(AcDateTime, self).__init__(False)
-
     def get_col_spec(self):
         return "DATETIME"
 
 class AcDate(types.Date):
-    def __init__(self, *a, **kw):
-        super(AcDate, self).__init__(False)
 
     def get_col_spec(self):
         return "DATETIME"
index 9a1e10f517fa87a0976426e03c41b061c90888a4..3d45bb670084c0bad185ad62e7a8829ffd244148 100644 (file)
@@ -116,15 +116,13 @@ class _StringType(sqltypes.String):
 class MaxString(_StringType):
     _type = 'VARCHAR'
 
-    def __init__(self, *a, **kw):
-        super(MaxString, self).__init__(*a, **kw)
-
 
 class MaxUnicode(_StringType):
     _type = 'VARCHAR'
 
     def __init__(self, length=None, **kw):
-        super(MaxUnicode, self).__init__(length=length, encoding='unicode')
+        kw['encoding'] = 'unicode'
+        super(MaxUnicode, self).__init__(length=length, **kw)
 
 
 class MaxChar(_StringType):
@@ -134,8 +132,8 @@ class MaxChar(_StringType):
 class MaxText(_StringType):
     _type = 'LONG'
 
-    def __init__(self, *a, **kw):
-        super(MaxText, self).__init__(*a, **kw)
+    def __init__(self, length=None, **kw):
+        super(MaxText, self).__init__(length, **kw)
 
     def get_col_spec(self):
         spec = 'LONG'
index 4c0a0089042622dfa6856a9f620480e227c69f04..c5f891fb7001694f4ae400b7884228dabfa090b4 100644 (file)
@@ -280,16 +280,15 @@ class _StringType(object):
 class TEXT(_StringType, sqltypes.TEXT):
     """MSSQL TEXT type, for variable-length text up to 2^31 characters."""
 
-    def __init__(self, *args, **kw):
+    def __init__(self, length=None, collation=None, **kw):
         """Construct a TEXT.
 
         :param collation: Optional, a column-level collation for this string
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        collation = kw.pop('collation', None)
         _StringType.__init__(self, collation)
-        sqltypes.Text.__init__(self, *args, **kw)
+        sqltypes.Text.__init__(self, length, **kw)
 
 class NTEXT(_StringType, sqltypes.UnicodeText):
     """MSSQL NTEXT type, for variable-length unicode text up to 2^30
@@ -297,24 +296,22 @@ class NTEXT(_StringType, sqltypes.UnicodeText):
 
     __visit_name__ = 'NTEXT'
     
-    def __init__(self, *args, **kwargs):
+    def __init__(self, length=None, collation=None, **kw):
         """Construct a NTEXT.
 
         :param collation: Optional, a column-level collation for this string
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        collation = kwargs.pop('collation', None)
         _StringType.__init__(self, collation)
-        length = kwargs.pop('length', None)
-        sqltypes.UnicodeText.__init__(self, length, **kwargs)
+        sqltypes.UnicodeText.__init__(self, length, **kw)
 
 
 class VARCHAR(_StringType, sqltypes.VARCHAR):
     """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
     of 8,000 characters."""
 
-    def __init__(self, *args, **kw):
+    def __init__(self, length=None, collation=None, **kw):
         """Construct a VARCHAR.
 
         :param length: Optinal, maximum data length, in characters.
@@ -335,16 +332,15 @@ class VARCHAR(_StringType, sqltypes.VARCHAR):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        collation = kw.pop('collation', None)
         _StringType.__init__(self, collation)
-        sqltypes.VARCHAR.__init__(self, *args, **kw)
+        sqltypes.VARCHAR.__init__(self, length, **kw)
 
 class NVARCHAR(_StringType, sqltypes.NVARCHAR):
     """MSSQL NVARCHAR type.
 
     For variable-length unicode character data up to 4,000 characters."""
 
-    def __init__(self, *args, **kw):
+    def __init__(self, length=None, collation=None, **kw):
         """Construct a NVARCHAR.
 
         :param length: Optional, Maximum data length, in characters.
@@ -353,15 +349,14 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        collation = kw.pop('collation', None)
         _StringType.__init__(self, collation)
-        sqltypes.NVARCHAR.__init__(self, *args, **kw)
+        sqltypes.NVARCHAR.__init__(self, length, **kw)
 
 class CHAR(_StringType, sqltypes.CHAR):
     """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
     of 8,000 characters."""
 
-    def __init__(self, *args, **kw):
+    def __init__(self, length=None, collation=None, **kw):
         """Construct a CHAR.
 
         :param length: Optinal, maximum data length, in characters.
@@ -382,16 +377,15 @@ class CHAR(_StringType, sqltypes.CHAR):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        collation = kw.pop('collation', None)
         _StringType.__init__(self, collation)
-        sqltypes.CHAR.__init__(self, *args, **kw)
+        sqltypes.CHAR.__init__(self, length, **kw)
 
 class NCHAR(_StringType, sqltypes.NCHAR):
     """MSSQL NCHAR type.
 
     For fixed-length unicode character data up to 4,000 characters."""
 
-    def __init__(self, *args, **kw):
+    def __init__(self, length=None, collation=None, **kw):
         """Construct an NCHAR.
 
         :param length: Optional, Maximum data length, in characters.
@@ -400,9 +394,8 @@ class NCHAR(_StringType, sqltypes.NCHAR):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        collation = kw.pop('collation', None)
         _StringType.__init__(self, collation)
-        sqltypes.NCHAR.__init__(self, *args, **kw)
+        sqltypes.NCHAR.__init__(self, length, **kw)
 
 class IMAGE(sqltypes.LargeBinary):
     __visit_name__ = 'IMAGE'
@@ -1150,8 +1143,8 @@ class MSDialect(default.DefaultDialect):
                 "and sch.name=:schname "
                 "and ind.is_primary_key=0", 
                 bindparams=[
-                    sql.bindparam('tabname', tablename, sqltypes.Unicode),
-                    sql.bindparam('schname', current_schema, sqltypes.Unicode)
+                    sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)),
+                    sql.bindparam('schname', current_schema, sqltypes.String(convert_unicode=True))
                 ]
             )
         )
@@ -1163,16 +1156,19 @@ class MSDialect(default.DefaultDialect):
                 'column_names':[]
             }
         rp = connection.execute(
-            sql.text("select ind_col.index_id, col.name from sys.columns as col "
-                        "join sys.index_columns as ind_col on "
-                        "ind_col.column_id=col.column_id "
-                        "join sys.tables as tab on tab.object_id=col.object_id "
-                        "join sys.schemas as sch on sch.schema_id=tab.schema_id "
-                        "where tab.name=:tabname "
-                        "and sch.name=:schname",
+            sql.text(
+                "select ind_col.index_id, ind_col.object_id, col.name "
+                "from sys.columns as col "
+                "join sys.tables as tab on tab.object_id=col.object_id "
+                "join sys.index_columns as ind_col on "
+                "(ind_col.column_id=col.column_id and "
+                "ind_col.object_id=tab.object_id) "
+                "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+                "where tab.name=:tabname "
+                "and sch.name=:schname",
                         bindparams=[
-                            sql.bindparam('tabname', tablename, sqltypes.Unicode),
-                            sql.bindparam('schname', current_schema, sqltypes.Unicode)
+                            sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)),
+                            sql.bindparam('schname', current_schema, sqltypes.String(convert_unicode=True))
                         ]),
             )
         for row in rp:
index 5c3289bfb020107a4ac572d927e63611f02c8f36..528e949651f5fd9e1be420a63cf2369b5c3cdb9f 100644 (file)
@@ -233,17 +233,11 @@ SET_RE = re.compile(
 class _NumericType(object):
     """Base for MySQL numeric types."""
 
-    def __init__(self, **kw):
-        self.unsigned = kw.pop('unsigned', False)
-        self.zerofill = kw.pop('zerofill', False)
+    def __init__(self, unsigned=False, zerofill=False, **kw):
+        self.unsigned = unsigned
+        self.zerofill = zerofill
         super(_NumericType, self).__init__(**kw)
     
-    def adapt(self, typeimpl, **kw):
-        return super(_NumericType, self).adapt(
-                        typeimpl, 
-                        unsigned=self.unsigned, 
-                        zerofill=self.zerofill)
-        
 class _FloatType(_NumericType, sqltypes.Float):
     def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         if isinstance(self, (REAL, DOUBLE)) and \
@@ -263,11 +257,6 @@ class _IntegerType(_NumericType, sqltypes.Integer):
         self.display_width = display_width
         super(_IntegerType, self).__init__(**kw)
 
-    def adapt(self, typeimpl, **kw):
-        return super(_IntegerType, self).adapt(
-                        typeimpl, 
-                        display_width=self.display_width)
-
 class _StringType(sqltypes.String):
     """Base for MySQL string types."""
 
@@ -288,17 +277,6 @@ class _StringType(sqltypes.String):
         self.national = national
         super(_StringType, self).__init__(**kw)
     
-    def adapt(self, typeimpl, **kw):
-        return super(_StringType, self).adapt(
-            typeimpl,
-            charset=self.charset,
-            collation=self.collation,
-            ascii=self.ascii,
-            binary=self.binary,
-            national=self.national,
-            **kw
-        )
-        
     def __repr__(self):
         attributes = inspect.getargspec(self.__init__)[0][1:]
         attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
index 256972696452469d85132ff0d751ff21668f447a..3d97b504ec3162ed28386b63bd0a512bb4d4f4a3 100644 (file)
@@ -215,10 +215,6 @@ class INTERVAL(sqltypes.TypeEngine):
         return INTERVAL(day_precision=interval.day_precision,
                         second_precision=interval.second_precision)
         
-    def adapt(self, impltype):
-        return impltype(day_precision=self.day_precision, 
-                        second_precision=self.second_precision)
-
     @property
     def _type_affinity(self):
         return sqltypes.Interval
index c9920c93021383e465ab6fac4df8195f53682e09..72b58a71c16ec25894caefc061b8d95bab2f2fda 100644 (file)
@@ -133,23 +133,12 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
         super(TIMESTAMP, self).__init__(timezone=timezone)
         self.precision = precision
 
-    def adapt(self, impltype, **kw):
-        return impltype(
-                precision=self.precision, 
-                timezone=self.timezone, 
-                **kw)
         
 class TIME(sqltypes.TIME):
     def __init__(self, timezone=False, precision=None):
         super(TIME, self).__init__(timezone=timezone)
         self.precision = precision
 
-    def adapt(self, impltype, **kw):
-        return impltype(
-                precision=self.precision, 
-                timezone=self.timezone, 
-                **kw)
-    
 class INTERVAL(sqltypes.TypeEngine):
     """Postgresql INTERVAL type.
     
@@ -161,9 +150,6 @@ class INTERVAL(sqltypes.TypeEngine):
     def __init__(self, precision=None):
         self.precision = precision
     
-    def adapt(self, impltype):
-        return impltype(self.precision)
-
     @classmethod
     def _adapt_from_generic_interval(cls, interval):
         return INTERVAL(precision=interval.second_precision)
@@ -176,6 +162,9 @@ PGInterval = INTERVAL
 
 class BIT(sqltypes.TypeEngine):
     __visit_name__ = 'BIT'
+    def __init__(self, length=1):
+        self.length= length
+        
 PGBit = BIT
 
 class UUID(sqltypes.TypeEngine):
@@ -226,9 +215,6 @@ class UUID(sqltypes.TypeEngine):
         else:
             return None
     
-    def adapt(self, impltype, **kw):
-        return impltype(as_uuid=self.as_uuid, **kw)
-        
 PGUuid = UUID
 
 class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
@@ -300,13 +286,6 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
     def is_mutable(self):
         return self.mutable
 
-    def adapt(self, impltype):
-        return impltype(
-            self.item_type,
-            mutable=self.mutable,
-            as_tuple=self.as_tuple
-        )
-        
     def bind_processor(self, dialect):
         item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect)
         if item_proc:
@@ -647,7 +626,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
             return "INTERVAL"
 
     def visit_BIT(self, type_):
-        return "BIT"
+        return "BIT(%d)" % type_.length
 
     def visit_UUID(self, type_):
         return "UUID"
@@ -1102,7 +1081,7 @@ class PGDialect(default.DefaultDialect):
             elif attype == 'double precision':
                 args = (53, )
             elif attype == 'integer':
-                args = (32, 0)
+                args = ()
             elif attype in ('timestamp with time zone', 
                             'time with time zone'):
                 kwargs['timezone'] = True
index 447938461e5d048bc84ee3c6e42479ba0966ed39..f5df0236718945dd37f427d1b3709d7368e6546c 100644 (file)
@@ -182,7 +182,7 @@ class TypeEngine(AbstractType):
         return dialect.type_descriptor(self)
         
     def adapt(self, cls, **kw):
-        return cls(**kw)
+        return util.constructor_copy(self, cls, **kw)
     
     def _coerce_compared_value(self, op, value):
         _coerced_type = _type_map.get(type(value), NULLTYPE)
@@ -221,7 +221,7 @@ class TypeEngine(AbstractType):
                         encode('ascii', 'backslashreplace')
         # end Py2K
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self):
         # supports getargspec of the __init__ method
         # used by generic __repr__
         pass
@@ -642,6 +642,9 @@ def adapt_type(typeobj, colspecs):
         return typeobj
     return typeobj.adapt(impltype)
 
+
+
+
 class NullType(TypeEngine):
     """An unknown type.
 
@@ -788,15 +791,6 @@ class String(Concatenable, TypeEngine):
         self.unicode_error = unicode_error
         self._warn_on_bytestring = _warn_on_bytestring
         
-    def adapt(self, impltype, **kw):
-        return impltype(
-                    length=self.length,
-                    convert_unicode=self.convert_unicode,
-                    unicode_error=self.unicode_error,
-                    _warn_on_bytestring=self._warn_on_bytestring,
-                    **kw
-                    )
-
     def bind_processor(self, dialect):
         if self.convert_unicode or dialect.convert_unicode:
             if dialect.supports_unicode_binds and \
@@ -816,10 +810,11 @@ class String(Concatenable, TypeEngine):
                     return None
             else:
                 encoder = codecs.getencoder(dialect.encoding)
+                warn_on_bytestring = self._warn_on_bytestring
                 def process(value):
                     if isinstance(value, unicode):
                         return encoder(value, self.unicode_error)[0]
-                    elif value is not None:
+                    elif warn_on_bytestring and value is not None:
                         util.warn("Unicode type received non-unicode bind "
                                   "param value")
                     return value
@@ -1092,13 +1087,6 @@ class Numeric(_DateAffinity, TypeEngine):
         self.scale = scale
         self.asdecimal = asdecimal
 
-    def adapt(self, impltype, **kw):
-        return impltype(
-                precision=self.precision, 
-                scale=self.scale, 
-                asdecimal=self.asdecimal,
-                **kw)
-
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
 
@@ -1190,10 +1178,6 @@ class Float(Numeric):
         self.precision = precision
         self.asdecimal = asdecimal
 
-    def adapt(self, impltype, **kw):
-        return impltype(precision=self.precision, 
-                        asdecimal=self.asdecimal, **kw)
-
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
             return processors.to_decimal_processor_factory(decimal.Decimal)
@@ -1240,9 +1224,6 @@ class DateTime(_DateAffinity, TypeEngine):
     def __init__(self, timezone=False):
         self.timezone = timezone
 
-    def adapt(self, impltype, **kw):
-        return impltype(timezone=self.timezone, **kw)
-
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
@@ -1300,9 +1281,6 @@ class Time(_DateAffinity,TypeEngine):
     def __init__(self, timezone=False):
         self.timezone = timezone
 
-    def adapt(self, impltype, **kw):
-        return impltype(timezone=self.timezone, **kw)
-
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
 
@@ -1362,9 +1340,6 @@ class _Binary(TypeEngine):
         else:
             return super(_Binary, self)._coerce_compared_value(op, value)
     
-    def adapt(self, impltype, **kw):
-        return impltype(length=self.length, **kw)
-
     def get_dbapi_type(self, dbapi):
         return dbapi.BINARY
     
index 9119e35b78204d06cd7a6a1e80f8a4e7bc9cfa28..ae1eb3ac5c9289955ba0bd8884fde24ae067b90d 100644 (file)
@@ -24,7 +24,8 @@ from langhelpers import iterate_attributes, class_hierarchy, \
     reset_memoized, group_expirable_memoized_property, importlater, \
     monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\
     duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\
-    classproperty, set_creation_order, warn_exception, warn, NoneType
+    classproperty, set_creation_order, warn_exception, warn, NoneType,\
+    constructor_copy
 
 from deprecations import warn_deprecated, warn_pending_deprecation, \
     deprecated, pending_deprecation
index d85793ee0760560dd125da65070ac4a28515ef8c..945e2a6bd402709d465db76a55f0e4f0bb82ca72 100644 (file)
@@ -516,6 +516,19 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True):
         else:
             kw[key] = type_(kw[key])
 
+
+def constructor_copy(obj, cls, **kw):
+    """Instantiate cls using the __dict__ of obj as constructor arguments.
+    
+    Uses inspect to match the named arguments of ``cls``.
+    
+    """
+    
+    names = get_cls_kwargs(cls)
+    kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__)
+    return cls(**kw)
+
+
 def duck_type_collection(specimen, default=None):
     """Given an instance or class, guess if it is or is acting as one of
     the basic collection types: list, set and dict.  If the __emulates__