From: Mike Bayer Date: Thu, 6 Apr 2006 01:15:46 +0000 (+0000) Subject: moves the binding of a TypeEngine object from "schema/statement creation" time into... X-Git-Tag: rel_0_1_6~17 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=680c27607328a8f89e446601f7bc7ed56394dc27;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git moves the binding of a TypeEngine object from "schema/statement creation" time into "compilation" time --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index dfc15a3832..40e9466512 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -189,7 +189,10 @@ class ANSICompiler(sql.Compiled): def visit_index(self, index): self.strings[index] = index.name - + + def visit_typeclause(self, typeclause): + self.strings[typeclause] = typeclause.type.engine_impl(self.engine).get_col_spec() + def visit_textclause(self, textclause): if textclause.parens and len(textclause.text): self.strings[textclause] = "(" + textclause.text + ")" diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 7d5cfed117..7dc48a54a6 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -238,7 +238,7 @@ class FBCompiler(ansisql.ANSICompiler): class FBSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - colspec += " " + column.type.get_col_spec() + colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 582ed90026..6a7ef91b39 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -460,7 +460,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): - colspec = column.name + " " + column.type.get_col_spec() + colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() # install a IDENTITY Sequence if we have an implicit IDENTITY column if column.primary_key and isinstance(column.type, types.Integer): diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index c55da97cb0..a25a21e9bf 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -263,7 +263,7 @@ class MySQLCompiler(ansisql.ANSICompiler): class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): - colspec = column.name + " " + column.type.get_col_spec() + colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index c673cb9617..a475d29b76 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -306,7 +306,7 @@ class OracleCompiler(ansisql.ANSICompiler): class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - colspec += " " + column.type.get_col_spec() + colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 72d4260127..a7285b4b5d 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -305,7 +305,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: - colspec += " " + column.type.get_col_spec() + colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 0e208854e3..a7536ee4e8 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -241,7 +241,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): - colspec = column.name + " " + column.type.get_col_spec() + colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 727ee30ad2..97c7107623 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -319,7 +319,7 @@ class SQLEngine(schema.SchemaEngine): self.positional = True else: raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle) - + def type_descriptor(self, typeobj): """provides a database-specific TypeEngine object, given the generic object which comes from the types module. Subclasses will usually use the adapt_type() @@ -808,7 +808,7 @@ class ResultProxy: rec = self.props[key.lower()] else: rec = self.props[key] - return rec[0].convert_result_value(row[rec[1]], self.engine) + return rec[0].engine_impl(self.engine).convert_result_value(row[rec[1]], self.engine) def __iter__(self): while True: diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py index 2ca3116c1d..38325bea35 100644 --- a/lib/sqlalchemy/ext/proxy.py +++ b/lib/sqlalchemy/ext/proxy.py @@ -36,11 +36,6 @@ class BaseProxyEngine(schema.SchemaEngine): return None return e.oid_column_name() - def type_descriptor(self, typeobj): - """Proxy point: return a ProxyTypeEngine - """ - return ProxyTypeEngine(self, typeobj) - def __getattr__(self, attr): # call get_engine() to give subclasses a chance to change # connection establishment behavior @@ -116,37 +111,3 @@ class ProxyEngine(BaseProxyEngine): self.storage.engine = engine -class ProxyType(object): - """ProxyType base class; used by ProxyTypeEngine to construct proxying - types - """ - def __init__(self, engine, typeobj): - self._engine = engine - self.typeobj = typeobj - - def __getattribute__(self, attr): - if attr.startswith('__') and attr.endswith('__'): - return object.__getattribute__(self, attr) - - engine = object.__getattribute__(self, '_engine').engine - typeobj = object.__getattribute__(self, 'typeobj') - return getattr(engine.type_descriptor(typeobj), attr) - - def __repr__(self): - return '' % (object.__getattribute__(self, 'typeobj')) - -class ProxyTypeEngine(object): - """Proxy type engine; creates dynamic proxy type subclass that is instance - of actual type, but proxies engine-dependant operations through the proxy - engine. - """ - def __new__(cls, engine, typeobj): - """Create a new subclass of ProxyType and typeobj - so that internal isinstance() calls will get the expected result. - """ - if isinstance(typeobj, type): - typeclass = typeobj - else: - typeclass = typeobj.__class__ - typed = type('ProxyTypeHelper', (ProxyType, typeclass), {}) - return typed(engine, typeobj) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index eabfee9bb7..24392b3d97 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -163,7 +163,6 @@ class Table(sql.TableClause, SchemaItem): if column.primary_key: self.primary_key.append(column) column.table = self - column.type = self.engine.type_descriptor(column.type) def append_index(self, index): self.indexes[index.name] = index diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index f0171571d4..f6e2d03c9a 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -139,17 +139,11 @@ def cast(clause, totype, **kwargs): or cast(table.c.timestamp, DATE) """ - engine = kwargs.get('engine', None) - if engine is None: - engine = getattr(clause, 'engine', None) - if engine is not None: - totype_desc = engine.type_descriptor(totype) - # handle non-column clauses (e.g. cast(1234, TEXT) - if not hasattr(clause, 'label'): - clause = literal(clause) - return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs) - else: - raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype)) + # handle non-column clauses (e.g. cast(1234, TEXT) + if not hasattr(clause, 'label'): + clause = literal(clause) + totype = sqltypes.to_instance(totype) + return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs) def exists(*args, **params): params['correlate'] = True @@ -295,7 +289,8 @@ class ClauseVisitor(object): def visit_clauselist(self, list):pass def visit_function(self, func):pass def visit_label(self, label):pass - + def visit_typeclause(self, typeclause):pass + class Compiled(ClauseVisitor): """represents a compiled SQL expression. the __str__ method of the Compiled object should produce the actual text of the statement. Compiled objects are specific to the @@ -671,13 +666,7 @@ class BindParamClause(ClauseElement, CompareMixin): self.key = key self.value = value self.shortname = shortname - self.type = type or sqltypes.NULLTYPE - def _get_convert_type(self, engine): - try: - return self._converted_type - except AttributeError: - self._converted_type = engine.type_descriptor(self.type) - return self._converted_type + self.type = sqltypes.to_instance(type) def accept_visitor(self, visitor): visitor.visit_bindparam(self) def _get_from_objects(self): @@ -685,7 +674,7 @@ class BindParamClause(ClauseElement, CompareMixin): def copy_container(self): return BindParamClause(self.key, self.value, self.shortname, self.type) def typeprocess(self, value, engine): - return self._get_convert_type(engine).convert_bind_param(value, engine) + return self.type.engine_impl(engine).convert_bind_param(value, engine) def compare(self, other): """compares this BindParamClause to the given clause. @@ -695,7 +684,14 @@ class BindParamClause(ClauseElement, CompareMixin): def _make_proxy(self, selectable, name = None): return self # return self.obj._make_proxy(selectable, name=self.name) - + +class TypeClause(ClauseElement): + """handles a type keyword in a SQL statement""" + def __init__(self, type): + self.type = type + def accept_visitor(self, visitor): + visitor.visit_typeclause(self) + class TextClause(ClauseElement): """represents literal a SQL text fragment. public constructor is the text() function. @@ -714,7 +710,7 @@ class TextClause(ClauseElement): self.typemap = typemap if typemap is not None: for key in typemap.keys(): - typemap[key] = engine.type_descriptor(typemap[key]) + typemap[key] = sqltypes.to_instance(typemap[key]) def repl(m): self.bindparams[m.group(1)] = bindparam(m.group(1)) return ":%s" % m.group(1) @@ -820,11 +816,9 @@ class Function(ClauseList, ColumnElement): """describes a SQL function. extends ClauseList to provide comparison operators.""" def __init__(self, name, *clauses, **kwargs): self.name = name - self.type = kwargs.get('type', sqltypes.NULLTYPE) + self.type = sqltypes.to_instance(kwargs.get('type', None)) self.packagenames = kwargs.get('packagenames', None) or [] self._engine = kwargs.get('engine', None) - if self._engine is not None: - self.type = self._engine.type_descriptor(self.type) ClauseList.__init__(self, parens=True, *clauses) key = property(lambda self:self.name) def append(self, clause): @@ -873,7 +867,7 @@ class BinaryClause(ClauseElement): self.left = left self.right = right self.operator = operator - self.type = type + self.type = sqltypes.to_instance(type) self.parens = False if isinstance(self.left, BinaryClause): self.left.parens = True @@ -1028,7 +1022,7 @@ class Label(ColumnElement): while isinstance(obj, Label): obj = obj.obj self.obj = obj - self.type = type or sqltypes.NullTypeEngine() + self.type = sqltypes.to_instance(type) obj.parens=True key = property(lambda s: s.name) @@ -1049,7 +1043,7 @@ class ColumnClause(ColumnElement): def __init__(self, text, selectable=None, type=None): self.key = self.name = self.text = text self.table = selectable - self.type = type or sqltypes.NullTypeEngine() + self.type = sqltypes.to_instance(type) self.__label = None def _get_label(self): if self.__label is None: diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index ecf791a378..7a3822a651 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -16,11 +16,22 @@ try: import cPickle as pickle except: import pickle - + class TypeEngine(object): - basetypes = [] def __init__(self, *args, **kwargs): pass + def _get_impl_dict(self): + try: + return self._impl_dict + except AttributeError: + self._impl_dict = {} + return self._impl_dict + impl_dict = property(_get_impl_dict) + def engine_impl(self, engine): + try: + return self.impl_dict[engine] + except: + return self.impl_dict.setdefault(engine, engine.type_descriptor(self)) def _get_impl(self): if hasattr(self, '_impl'): return self._impl @@ -41,7 +52,14 @@ class TypeEngine(object): return {} def adapt_args(self): return self - + +def to_instance(typeobj): + if typeobj is None: + return NULLTYPE + elif isinstance(typeobj, type): + return typeobj() + else: + return typeobj def adapt_type(typeobj, colspecs): if isinstance(typeobj, type): typeobj = typeobj() diff --git a/test/proxy_engine.py b/test/proxy_engine.py index 170e526d96..2a2cebc5b9 100644 --- a/test/proxy_engine.py +++ b/test/proxy_engine.py @@ -194,7 +194,7 @@ class ProxyEngineTest2(PersistTest): return 'a' def type_descriptor(self, typeobj): - if typeobj == types.Integer: + if isinstance(typeobj, types.Integer): return TypeEngineX2() else: return TypeEngineSTR() @@ -224,16 +224,16 @@ class ProxyEngineTest2(PersistTest): engine = ProxyEngine() engine.storage.engine = EngineA() - a = engine.type_descriptor(sqltypes.Integer) + a = sqltypes.Integer().engine_impl(engine) assert a.convert_bind_param(12, engine) == 24 assert a.convert_bind_param([1,2,3], engine) == [1, 2, 3, 1, 2, 3] - a2 = engine.type_descriptor(sqltypes.String) + a2 = sqltypes.String().engine_impl(engine) assert a2.convert_bind_param(12, engine) == "'12'" assert a2.convert_bind_param([1,2,3], engine) == "'[1, 2, 3]'" engine.storage.engine = EngineB() - b = engine.type_descriptor(sqltypes.Integer) + b = sqltypes.Integer().engine_impl(engine) assert b.convert_bind_param(12, engine) == 'monkey' assert b.convert_bind_param([1,2,3], engine) == 'monkey'