From: Mike Bayer Date: Wed, 15 Dec 2010 17:44:37 +0000 (-0500) Subject: - an approach I like better, remove most adapt() methods and use a generic X-Git-Tag: rel_0_7b1~159 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5a832a49e37ca9259fbad286335367927d0ec60e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - an approach I like better, remove most adapt() methods and use a generic copier - mssql reflection fix, but this will come in again from the tip merge --- diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index 75ea912874..cf35b3e0a3 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -50,15 +50,10 @@ class AcSmallInteger(types.SmallInteger): return "SMALLINT" class AcDateTime(types.DateTime): - def __init__(self, *a, **kw): - super(AcDateTime, self).__init__(False) - def get_col_spec(self): return "DATETIME" class AcDate(types.Date): - def __init__(self, *a, **kw): - super(AcDate, self).__init__(False) def get_col_spec(self): return "DATETIME" diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index 9a1e10f517..3d45bb6700 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -116,15 +116,13 @@ class _StringType(sqltypes.String): class MaxString(_StringType): _type = 'VARCHAR' - def __init__(self, *a, **kw): - super(MaxString, self).__init__(*a, **kw) - class MaxUnicode(_StringType): _type = 'VARCHAR' def __init__(self, length=None, **kw): - super(MaxUnicode, self).__init__(length=length, encoding='unicode') + kw['encoding'] = 'unicode' + super(MaxUnicode, self).__init__(length=length, **kw) class MaxChar(_StringType): @@ -134,8 +132,8 @@ class MaxChar(_StringType): class MaxText(_StringType): _type = 'LONG' - def __init__(self, *a, **kw): - super(MaxText, self).__init__(*a, **kw) + def __init__(self, length=None, **kw): + super(MaxText, self).__init__(length, **kw) def get_col_spec(self): spec = 'LONG' diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 4c0a008904..c5f891fb70 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -280,16 +280,15 @@ class _StringType(object): class TEXT(_StringType, sqltypes.TEXT): """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a TEXT. :param collation: Optional, a column-level collation for this string value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.Text.__init__(self, *args, **kw) + sqltypes.Text.__init__(self, length, **kw) class NTEXT(_StringType, sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 @@ -297,24 +296,22 @@ class NTEXT(_StringType, sqltypes.UnicodeText): __visit_name__ = 'NTEXT' - def __init__(self, *args, **kwargs): + def __init__(self, length=None, collation=None, **kw): """Construct a NTEXT. :param collation: Optional, a column-level collation for this string value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kwargs.pop('collation', None) _StringType.__init__(self, collation) - length = kwargs.pop('length', None) - sqltypes.UnicodeText.__init__(self, length, **kwargs) + sqltypes.UnicodeText.__init__(self, length, **kw) class VARCHAR(_StringType, sqltypes.VARCHAR): """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum of 8,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a VARCHAR. :param length: Optinal, maximum data length, in characters. @@ -335,16 +332,15 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.VARCHAR.__init__(self, *args, **kw) + sqltypes.VARCHAR.__init__(self, length, **kw) class NVARCHAR(_StringType, sqltypes.NVARCHAR): """MSSQL NVARCHAR type. For variable-length unicode character data up to 4,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a NVARCHAR. :param length: Optional, Maximum data length, in characters. @@ -353,15 +349,14 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.NVARCHAR.__init__(self, *args, **kw) + sqltypes.NVARCHAR.__init__(self, length, **kw) class CHAR(_StringType, sqltypes.CHAR): """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum of 8,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct a CHAR. :param length: Optinal, maximum data length, in characters. @@ -382,16 +377,15 @@ class CHAR(_StringType, sqltypes.CHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.CHAR.__init__(self, *args, **kw) + sqltypes.CHAR.__init__(self, length, **kw) class NCHAR(_StringType, sqltypes.NCHAR): """MSSQL NCHAR type. For fixed-length unicode character data up to 4,000 characters.""" - def __init__(self, *args, **kw): + def __init__(self, length=None, collation=None, **kw): """Construct an NCHAR. :param length: Optional, Maximum data length, in characters. @@ -400,9 +394,8 @@ class NCHAR(_StringType, sqltypes.NCHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - collation = kw.pop('collation', None) _StringType.__init__(self, collation) - sqltypes.NCHAR.__init__(self, *args, **kw) + sqltypes.NCHAR.__init__(self, length, **kw) class IMAGE(sqltypes.LargeBinary): __visit_name__ = 'IMAGE' @@ -1150,8 +1143,8 @@ class MSDialect(default.DefaultDialect): "and sch.name=:schname " "and ind.is_primary_key=0", bindparams=[ - sql.bindparam('tabname', tablename, sqltypes.Unicode), - sql.bindparam('schname', current_schema, sqltypes.Unicode) + sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)), + sql.bindparam('schname', current_schema, sqltypes.String(convert_unicode=True)) ] ) ) @@ -1163,16 +1156,19 @@ class MSDialect(default.DefaultDialect): 'column_names':[] } rp = connection.execute( - sql.text("select ind_col.index_id, col.name from sys.columns as col " - "join sys.index_columns as ind_col on " - "ind_col.column_id=col.column_id " - "join sys.tables as tab on tab.object_id=col.object_id " - "join sys.schemas as sch on sch.schema_id=tab.schema_id " - "where tab.name=:tabname " - "and sch.name=:schname", + sql.text( + "select ind_col.index_id, ind_col.object_id, col.name " + "from sys.columns as col " + "join sys.tables as tab on tab.object_id=col.object_id " + "join sys.index_columns as ind_col on " + "(ind_col.column_id=col.column_id and " + "ind_col.object_id=tab.object_id) " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name=:tabname " + "and sch.name=:schname", bindparams=[ - sql.bindparam('tabname', tablename, sqltypes.Unicode), - sql.bindparam('schname', current_schema, sqltypes.Unicode) + sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)), + sql.bindparam('schname', current_schema, sqltypes.String(convert_unicode=True)) ]), ) for row in rp: diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 5c3289bfb0..528e949651 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -233,17 +233,11 @@ SET_RE = re.compile( class _NumericType(object): """Base for MySQL numeric types.""" - def __init__(self, **kw): - self.unsigned = kw.pop('unsigned', False) - self.zerofill = kw.pop('zerofill', False) + def __init__(self, unsigned=False, zerofill=False, **kw): + self.unsigned = unsigned + self.zerofill = zerofill super(_NumericType, self).__init__(**kw) - def adapt(self, typeimpl, **kw): - return super(_NumericType, self).adapt( - typeimpl, - unsigned=self.unsigned, - zerofill=self.zerofill) - class _FloatType(_NumericType, sqltypes.Float): def __init__(self, precision=None, scale=None, asdecimal=True, **kw): if isinstance(self, (REAL, DOUBLE)) and \ @@ -263,11 +257,6 @@ class _IntegerType(_NumericType, sqltypes.Integer): self.display_width = display_width super(_IntegerType, self).__init__(**kw) - def adapt(self, typeimpl, **kw): - return super(_IntegerType, self).adapt( - typeimpl, - display_width=self.display_width) - class _StringType(sqltypes.String): """Base for MySQL string types.""" @@ -288,17 +277,6 @@ class _StringType(sqltypes.String): self.national = national super(_StringType, self).__init__(**kw) - def adapt(self, typeimpl, **kw): - return super(_StringType, self).adapt( - typeimpl, - charset=self.charset, - collation=self.collation, - ascii=self.ascii, - binary=self.binary, - national=self.national, - **kw - ) - def __repr__(self): attributes = inspect.getargspec(self.__init__)[0][1:] attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 2569726964..3d97b504ec 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -215,10 +215,6 @@ class INTERVAL(sqltypes.TypeEngine): return INTERVAL(day_precision=interval.day_precision, second_precision=interval.second_precision) - def adapt(self, impltype): - return impltype(day_precision=self.day_precision, - second_precision=self.second_precision) - @property def _type_affinity(self): return sqltypes.Interval diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index c9920c9302..72b58a71c1 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -133,23 +133,12 @@ class TIMESTAMP(sqltypes.TIMESTAMP): super(TIMESTAMP, self).__init__(timezone=timezone) self.precision = precision - def adapt(self, impltype, **kw): - return impltype( - precision=self.precision, - timezone=self.timezone, - **kw) class TIME(sqltypes.TIME): def __init__(self, timezone=False, precision=None): super(TIME, self).__init__(timezone=timezone) self.precision = precision - def adapt(self, impltype, **kw): - return impltype( - precision=self.precision, - timezone=self.timezone, - **kw) - class INTERVAL(sqltypes.TypeEngine): """Postgresql INTERVAL type. @@ -161,9 +150,6 @@ class INTERVAL(sqltypes.TypeEngine): def __init__(self, precision=None): self.precision = precision - def adapt(self, impltype): - return impltype(self.precision) - @classmethod def _adapt_from_generic_interval(cls, interval): return INTERVAL(precision=interval.second_precision) @@ -176,6 +162,9 @@ PGInterval = INTERVAL class BIT(sqltypes.TypeEngine): __visit_name__ = 'BIT' + def __init__(self, length=1): + self.length= length + PGBit = BIT class UUID(sqltypes.TypeEngine): @@ -226,9 +215,6 @@ class UUID(sqltypes.TypeEngine): else: return None - def adapt(self, impltype, **kw): - return impltype(as_uuid=self.as_uuid, **kw) - PGUuid = UUID class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): @@ -300,13 +286,6 @@ class ARRAY(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): def is_mutable(self): return self.mutable - def adapt(self, impltype): - return impltype( - self.item_type, - mutable=self.mutable, - as_tuple=self.as_tuple - ) - def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect) if item_proc: @@ -647,7 +626,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): return "INTERVAL" def visit_BIT(self, type_): - return "BIT" + return "BIT(%d)" % type_.length def visit_UUID(self, type_): return "UUID" @@ -1102,7 +1081,7 @@ class PGDialect(default.DefaultDialect): elif attype == 'double precision': args = (53, ) elif attype == 'integer': - args = (32, 0) + args = () elif attype in ('timestamp with time zone', 'time with time zone'): kwargs['timezone'] = True diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 447938461e..f5df023671 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -182,7 +182,7 @@ class TypeEngine(AbstractType): return dialect.type_descriptor(self) def adapt(self, cls, **kw): - return cls(**kw) + return util.constructor_copy(self, cls, **kw) def _coerce_compared_value(self, op, value): _coerced_type = _type_map.get(type(value), NULLTYPE) @@ -221,7 +221,7 @@ class TypeEngine(AbstractType): encode('ascii', 'backslashreplace') # end Py2K - def __init__(self, *args, **kwargs): + def __init__(self): # supports getargspec of the __init__ method # used by generic __repr__ pass @@ -642,6 +642,9 @@ def adapt_type(typeobj, colspecs): return typeobj return typeobj.adapt(impltype) + + + class NullType(TypeEngine): """An unknown type. @@ -788,15 +791,6 @@ class String(Concatenable, TypeEngine): self.unicode_error = unicode_error self._warn_on_bytestring = _warn_on_bytestring - def adapt(self, impltype, **kw): - return impltype( - length=self.length, - convert_unicode=self.convert_unicode, - unicode_error=self.unicode_error, - _warn_on_bytestring=self._warn_on_bytestring, - **kw - ) - def bind_processor(self, dialect): if self.convert_unicode or dialect.convert_unicode: if dialect.supports_unicode_binds and \ @@ -816,10 +810,11 @@ class String(Concatenable, TypeEngine): return None else: encoder = codecs.getencoder(dialect.encoding) + warn_on_bytestring = self._warn_on_bytestring def process(value): if isinstance(value, unicode): return encoder(value, self.unicode_error)[0] - elif value is not None: + elif warn_on_bytestring and value is not None: util.warn("Unicode type received non-unicode bind " "param value") return value @@ -1092,13 +1087,6 @@ class Numeric(_DateAffinity, TypeEngine): self.scale = scale self.asdecimal = asdecimal - def adapt(self, impltype, **kw): - return impltype( - precision=self.precision, - scale=self.scale, - asdecimal=self.asdecimal, - **kw) - def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -1190,10 +1178,6 @@ class Float(Numeric): self.precision = precision self.asdecimal = asdecimal - def adapt(self, impltype, **kw): - return impltype(precision=self.precision, - asdecimal=self.asdecimal, **kw) - def result_processor(self, dialect, coltype): if self.asdecimal: return processors.to_decimal_processor_factory(decimal.Decimal) @@ -1240,9 +1224,6 @@ class DateTime(_DateAffinity, TypeEngine): def __init__(self, timezone=False): self.timezone = timezone - def adapt(self, impltype, **kw): - return impltype(timezone=self.timezone, **kw) - def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -1300,9 +1281,6 @@ class Time(_DateAffinity,TypeEngine): def __init__(self, timezone=False): self.timezone = timezone - def adapt(self, impltype, **kw): - return impltype(timezone=self.timezone, **kw) - def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -1362,9 +1340,6 @@ class _Binary(TypeEngine): else: return super(_Binary, self)._coerce_compared_value(op, value) - def adapt(self, impltype, **kw): - return impltype(length=self.length, **kw) - def get_dbapi_type(self, dbapi): return dbapi.BINARY diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 9119e35b78..ae1eb3ac5c 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -24,7 +24,8 @@ from langhelpers import iterate_attributes, class_hierarchy, \ reset_memoized, group_expirable_memoized_property, importlater, \ monkeypatch_proxied_specials, asbool, bool_or_str, coerce_kw_type,\ duck_type_collection, assert_arg_type, symbol, dictlike_iteritems,\ - classproperty, set_creation_order, warn_exception, warn, NoneType + classproperty, set_creation_order, warn_exception, warn, NoneType,\ + constructor_copy from deprecations import warn_deprecated, warn_pending_deprecation, \ deprecated, pending_deprecation diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index d85793ee07..945e2a6bd4 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -516,6 +516,19 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): else: kw[key] = type_(kw[key]) + +def constructor_copy(obj, cls, **kw): + """Instantiate cls using the __dict__ of obj as constructor arguments. + + Uses inspect to match the named arguments of ``cls``. + + """ + + names = get_cls_kwargs(cls) + kw.update((k, obj.__dict__[k]) for k in names if k in obj.__dict__) + return cls(**kw) + + def duck_type_collection(specimen, default=None): """Given an instance or class, guess if it is or is acting as one of the basic collection types: list, set and dict. If the __emulates__