]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
moves the binding of a TypeEngine object from "schema/statement creation" time into...
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 01:15:46 +0000 (01:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Apr 2006 01:15:46 +0000 (01:15 +0000)
13 files changed:
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/ext/proxy.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/types.py
test/proxy_engine.py

index dfc15a3832ff3db9ffe39bdcc4227c8d422540da..40e9466512eccfa41a97e4d14bc3c37f7e6df35f 100644 (file)
@@ -189,7 +189,10 @@ class ANSICompiler(sql.Compiled):
 
     def visit_index(self, index):
         self.strings[index] = index.name
-        
+    
+    def visit_typeclause(self, typeclause):
+        self.strings[typeclause] = typeclause.type.engine_impl(self.engine).get_col_spec()
+            
     def visit_textclause(self, textclause):
         if textclause.parens and len(textclause.text):
             self.strings[textclause] = "(" + textclause.text + ")"
index 7d5cfed1170f067b6cd679e5237137855f3c724f..7dc48a54a653cc11c0c752206337d4ac7bdbfd66 100644 (file)
@@ -238,7 +238,7 @@ class FBCompiler(ansisql.ANSICompiler):
 class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name 
-        colspec += " " + column.type.get_col_spec()
+        colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
index 582ed900260f7bbca27a86b7d9a043273a697a99..6a7ef91b39d712d0f99fcb67aecf5b7ef81c2b32 100644 (file)
@@ -460,7 +460,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
         
 class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, first_pk=False):
-        colspec = column.name + " " + column.type.get_col_spec()
+        colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
 
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
         if column.primary_key and isinstance(column.type, types.Integer):
index c55da97cb0263a66620053b5e1517d0e30044f01..a25a21e9bf06ee2de09dd331592c8fc8a50f880e 100644 (file)
@@ -263,7 +263,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
         
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, first_pk=False):
-        colspec = column.name + " " + column.type.get_col_spec()
+        colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
index c673cb9617e9a2ae78d139caf94d08699a4ad87c..a475d29b769fee17f41b0afcb902807ade3a91e0 100644 (file)
@@ -306,7 +306,7 @@ class OracleCompiler(ansisql.ANSICompiler):
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
-        colspec += " " + column.type.get_col_spec()
+        colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
index 72d4260127bf31ec1d2b116493ccb9f389daaeb6..a7285b4b5d1c2edde36036ebf1f2d83c9dbb31e0 100644 (file)
@@ -305,7 +305,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
         if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
         else:
-            colspec += " " + column.type.get_col_spec()
+            colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
index 0e208854e3f91d9ee530d4e8387c3ecafb4da747..a7536ee4e884aec1e1dcb8997a60bcf20aaaa5ad 100644 (file)
@@ -241,7 +241,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
         
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
-        colspec = column.name + " " + column.type.get_col_spec()
+        colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
index 727ee30ad2201e34011b4adceba6d64324c03623..97c71076234462febbd00e8e6c3b42f9c4afd086 100644 (file)
@@ -319,7 +319,7 @@ class SQLEngine(schema.SchemaEngine):
             self.positional = True
         else:
             raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
-        
+    
     def type_descriptor(self, typeobj):
         """provides a database-specific TypeEngine object, given the generic object
         which comes from the types module.  Subclasses will usually use the adapt_type()
@@ -808,7 +808,7 @@ class ResultProxy:
             rec = self.props[key.lower()]
         else:
             rec = self.props[key]
-        return rec[0].convert_result_value(row[rec[1]], self.engine)
+        return rec[0].engine_impl(self.engine).convert_result_value(row[rec[1]], self.engine)
     
     def __iter__(self):
         while True:
index 2ca3116c1d5de60931cafda6e6915739afdeb825..38325bea35e38de267bd7185c8cedd5c54269c05 100644 (file)
@@ -36,11 +36,6 @@ class BaseProxyEngine(schema.SchemaEngine):
             return None
         return e.oid_column_name()    
         
-    def type_descriptor(self, typeobj):
-        """Proxy point: return a ProxyTypeEngine 
-        """
-        return ProxyTypeEngine(self, typeobj)
-
     def __getattr__(self, attr):
         # call get_engine() to give subclasses a chance to change
         # connection establishment behavior
@@ -116,37 +111,3 @@ class ProxyEngine(BaseProxyEngine):
         self.storage.engine = engine
         
 
-class ProxyType(object):
-    """ProxyType base class; used by ProxyTypeEngine to construct proxying
-    types    
-    """
-    def __init__(self, engine, typeobj):
-        self._engine = engine
-        self.typeobj = typeobj
-
-    def __getattribute__(self, attr):
-        if attr.startswith('__') and attr.endswith('__'):
-            return object.__getattribute__(self, attr)
-        
-        engine = object.__getattribute__(self, '_engine').engine
-        typeobj = object.__getattribute__(self, 'typeobj')        
-        return getattr(engine.type_descriptor(typeobj), attr)
-
-    def __repr__(self):
-        return '<Proxy %s>' % (object.__getattribute__(self, 'typeobj'))
-    
-class ProxyTypeEngine(object):
-    """Proxy type engine; creates dynamic proxy type subclass that is instance
-    of actual type, but proxies engine-dependant operations through the proxy
-    engine.    
-    """
-    def __new__(cls, engine, typeobj):
-        """Create a new subclass of ProxyType and typeobj
-        so that internal isinstance() calls will get the expected result.
-        """
-        if isinstance(typeobj, type):
-            typeclass = typeobj
-        else:
-            typeclass = typeobj.__class__
-        typed = type('ProxyTypeHelper', (ProxyType, typeclass), {})
-        return typed(engine, typeobj)    
index eabfee9bb7aa25ade57af6bf0a315a559219f0c5..24392b3d973932376f993179896987c3f44427e9 100644 (file)
@@ -163,7 +163,6 @@ class Table(sql.TableClause, SchemaItem):
         if column.primary_key:
             self.primary_key.append(column)
         column.table = self
-        column.type = self.engine.type_descriptor(column.type)
 
     def append_index(self, index):
         self.indexes[index.name] = index
index f0171571d45aafed5cfb995b804dcc20883eee8b..f6e2d03c9ac0bb028097eb89eb0602d02b1b416e 100644 (file)
@@ -139,17 +139,11 @@ def cast(clause, totype, **kwargs):
          or
         cast(table.c.timestamp, DATE)
     """
-    engine = kwargs.get('engine', None)
-    if engine is None:
-        engine = getattr(clause, 'engine', None)
-    if engine is not None:
-        totype_desc = engine.type_descriptor(totype)
-        # handle non-column clauses (e.g. cast(1234, TEXT)
-        if not hasattr(clause, 'label'):
-            clause = literal(clause)
-        return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs)
-    else:
-        raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype))
+    # handle non-column clauses (e.g. cast(1234, TEXT)
+    if not hasattr(clause, 'label'):
+        clause = literal(clause)
+    totype = sqltypes.to_instance(totype)
+    return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs)
         
 def exists(*args, **params):
     params['correlate'] = True
@@ -295,7 +289,8 @@ class ClauseVisitor(object):
     def visit_clauselist(self, list):pass
     def visit_function(self, func):pass
     def visit_label(self, label):pass
-        
+    def visit_typeclause(self, typeclause):pass
+            
 class Compiled(ClauseVisitor):
     """represents a compiled SQL expression.  the __str__ method of the Compiled object
     should produce the actual text of the statement.  Compiled objects are specific to the
@@ -671,13 +666,7 @@ class BindParamClause(ClauseElement, CompareMixin):
         self.key = key
         self.value = value
         self.shortname = shortname
-        self.type = type or sqltypes.NULLTYPE
-    def _get_convert_type(self, engine):
-        try:
-            return self._converted_type
-        except AttributeError:
-            self._converted_type = engine.type_descriptor(self.type)
-            return self._converted_type
+        self.type = sqltypes.to_instance(type)
     def accept_visitor(self, visitor):
         visitor.visit_bindparam(self)
     def _get_from_objects(self):
@@ -685,7 +674,7 @@ class BindParamClause(ClauseElement, CompareMixin):
     def copy_container(self):
         return BindParamClause(self.key, self.value, self.shortname, self.type)
     def typeprocess(self, value, engine):
-        return self._get_convert_type(engine).convert_bind_param(value, engine)
+        return self.type.engine_impl(engine).convert_bind_param(value, engine)
     def compare(self, other):
         """compares this BindParamClause to the given clause.
         
@@ -695,7 +684,14 @@ class BindParamClause(ClauseElement, CompareMixin):
     def _make_proxy(self, selectable, name = None):
         return self
 #        return self.obj._make_proxy(selectable, name=self.name)
-            
+
+class TypeClause(ClauseElement):
+    """handles a type keyword in a SQL statement"""
+    def __init__(self, type):
+        self.type = type
+    def accept_visitor(self, visitor):
+        visitor.visit_typeclause(self)
+               
 class TextClause(ClauseElement):
     """represents literal a SQL text fragment.  public constructor is the 
     text() function.  
@@ -714,7 +710,7 @@ class TextClause(ClauseElement):
         self.typemap = typemap
         if typemap is not None:
             for key in typemap.keys():
-                typemap[key] = engine.type_descriptor(typemap[key])
+                typemap[key] = sqltypes.to_instance(typemap[key])
         def repl(m):
             self.bindparams[m.group(1)] = bindparam(m.group(1))
             return ":%s" % m.group(1)
@@ -820,11 +816,9 @@ class Function(ClauseList, ColumnElement):
     """describes a SQL function. extends ClauseList to provide comparison operators."""
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
-        self.type = kwargs.get('type', sqltypes.NULLTYPE)
+        self.type = sqltypes.to_instance(kwargs.get('type', None))
         self.packagenames = kwargs.get('packagenames', None) or []
         self._engine = kwargs.get('engine', None)
-        if self._engine is not None:
-            self.type = self._engine.type_descriptor(self.type)
         ClauseList.__init__(self, parens=True, *clauses)
     key = property(lambda self:self.name)
     def append(self, clause):
@@ -873,7 +867,7 @@ class BinaryClause(ClauseElement):
         self.left = left
         self.right = right
         self.operator = operator
-        self.type = type
+        self.type = sqltypes.to_instance(type)
         self.parens = False
         if isinstance(self.left, BinaryClause):
             self.left.parens = True
@@ -1028,7 +1022,7 @@ class Label(ColumnElement):
         while isinstance(obj, Label):
             obj = obj.obj
         self.obj = obj
-        self.type = type or sqltypes.NullTypeEngine()
+        self.type = sqltypes.to_instance(type)
         obj.parens=True
     key = property(lambda s: s.name)
     
@@ -1049,7 +1043,7 @@ class ColumnClause(ColumnElement):
     def __init__(self, text, selectable=None, type=None):
         self.key = self.name = self.text = text
         self.table = selectable
-        self.type = type or sqltypes.NullTypeEngine()
+        self.type = sqltypes.to_instance(type)
         self.__label = None
     def _get_label(self):
         if self.__label is None:
index ecf791a378c98d951d8ca8f590a208751152c045..7a3822a6518e4458420e379b3243dfdfcea1c852 100644 (file)
@@ -16,11 +16,22 @@ try:
     import cPickle as pickle
 except:
     import pickle
-    
+
 class TypeEngine(object):
-    basetypes = []
     def __init__(self, *args, **kwargs):
         pass
+    def _get_impl_dict(self):
+        try:
+            return self._impl_dict
+        except AttributeError:
+            self._impl_dict = {}
+            return self._impl_dict
+    impl_dict = property(_get_impl_dict)
+    def engine_impl(self, engine):
+        try:
+            return self.impl_dict[engine]
+        except:
+            return self.impl_dict.setdefault(engine, engine.type_descriptor(self))
     def _get_impl(self):
         if hasattr(self, '_impl'):
             return self._impl
@@ -41,7 +52,14 @@ class TypeEngine(object):
         return {}
     def adapt_args(self):
         return self
-            
+
+def to_instance(typeobj):
+    if typeobj is None:
+        return NULLTYPE
+    elif isinstance(typeobj, type):
+        return typeobj()
+    else:
+        return typeobj
 def adapt_type(typeobj, colspecs):
     if isinstance(typeobj, type):
         typeobj = typeobj()
index 170e526d96d7af107209f4f69a69fa51410422c9..2a2cebc5b90b0ea0a91fe55ac8e6ea029262d345 100644 (file)
@@ -194,7 +194,7 @@ class ProxyEngineTest2(PersistTest):
                 return 'a'
             
             def type_descriptor(self, typeobj):
-                if typeobj == types.Integer:
+                if isinstance(typeobj, types.Integer):
                     return TypeEngineX2()
                 else:
                     return TypeEngineSTR()
@@ -224,16 +224,16 @@ class ProxyEngineTest2(PersistTest):
         engine = ProxyEngine()
         engine.storage.engine = EngineA()
 
-        a = engine.type_descriptor(sqltypes.Integer)
+        a = sqltypes.Integer().engine_impl(engine)
         assert a.convert_bind_param(12, engine) == 24
         assert a.convert_bind_param([1,2,3], engine) == [1, 2, 3, 1, 2, 3]
 
-        a2 = engine.type_descriptor(sqltypes.String)
+        a2 = sqltypes.String().engine_impl(engine)
         assert a2.convert_bind_param(12, engine) == "'12'"
         assert a2.convert_bind_param([1,2,3], engine) == "'[1, 2, 3]'"
         
         engine.storage.engine = EngineB()
-        b = engine.type_descriptor(sqltypes.Integer)
+        b = sqltypes.Integer().engine_impl(engine)
         assert b.convert_bind_param(12, engine) == 'monkey'
         assert b.convert_bind_param([1,2,3], engine) == 'monkey'