git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- system to cache the bind/result processors in a dialect-wide registry.
author    Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Dec 2010 17:53:56 +0000 (12:53 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Dec 2010 17:53:56 +0000 (12:53 -0500)
it's an idea with pointy edges.

lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util/_collections.py
test/aaa_profiling/test_memusage.py
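
The heart of the change is in types.py and engine/default.py below: each dialect grows a _type_memos WeakKeyDictionary keyed by type instance, and the type object memoizes its dialect-specific impl, bind processor, and per-coltype result processors in that per-dialect dict. A minimal standalone sketch of the pattern (the FakeDialect/FakeType names are illustrative only, not SQLAlchemy API):

    import weakref

    class FakeDialect(object):
        """Stand-in for a Dialect: owns the weak-keyed memo registry."""
        def __init__(self):
            # keyed by type *instance*; an entry vanishes when the type
            # instance is garbage collected, so nothing leaks per-dialect.
            self._type_memos = weakref.WeakKeyDictionary()

    class FakeType(object):
        """Stand-in for a TypeEngine: caches processors per dialect."""

        def _dialect_info(self, dialect):
            # fetch-or-create the memo dict for this (type, dialect) pair
            if self in dialect._type_memos:
                return dialect._type_memos[self]
            d = {'impl': self}  # the real code stores an adapted copy here
            dialect._type_memos[self] = d
            return d

        def _cached_bind_processor(self, dialect):
            try:
                return dialect._type_memos[self]['bind']
            except KeyError:
                d = self._dialect_info(dialect)
                d['bind'] = bp = (lambda value: value)  # placeholder processor
                return bp

    dialect = FakeDialect()
    t = FakeType()
    # second call is a plain dict hit and returns the same processor object
    assert t._cached_bind_processor(dialect) is t._cached_bind_processor(dialect)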

diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 42072699e0befd84a35d6d89aa68af72b025aa4e..deeebf0f906c247d54f126da6f6a75c902e970c5 100644
@@ -237,6 +237,12 @@ class _NumericType(object):
         self.unsigned = kw.pop('unsigned', False)
         self.zerofill = kw.pop('zerofill', False)
         super(_NumericType, self).__init__(**kw)
+    
+    def adapt(self, typeimpl, **kw):
+        return super(_NumericType, self).adapt(
+                        typeimpl, 
+                        unsigned=self.unsigned, 
+                        zerofill=self.zerofill)
         
 class _FloatType(_NumericType, sqltypes.Float):
     def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
@@ -257,6 +263,11 @@ class _IntegerType(_NumericType, sqltypes.Integer):
         self.display_width = display_width
         super(_IntegerType, self).__init__(**kw)
 
+    def adapt(self, typeimpl, **kw):
+        return super(_IntegerType, self).adapt(
+                        typeimpl, 
+                        display_width=self.display_width)
+
 class _StringType(sqltypes.String):
     """Base for MySQL string types."""
 
@@ -276,6 +287,17 @@ class _StringType(sqltypes.String):
         self.binary = binary
         self.national = national
         super(_StringType, self).__init__(**kw)
+    
+    def adapt(self, typeimpl, **kw):
+        return super(_StringType, self).adapt(
+            typeimpl,
+            charset=self.charset,
+            collation=self.collation,
+            ascii=self.ascii,
+            binary=self.binary,
+            national=self.national,
+            **kw
+        )
         
     def __repr__(self):
         attributes = inspect.getargspec(self.__init__)[0][1:]
@@ -990,8 +1012,8 @@ class SET(_StringType):
             strip_values.append(a)
 
         self.values = strip_values
-        length = max([len(v) for v in strip_values] + [0])
-        super(SET, self).__init__(length=length, **kw)
+        kw.setdefault('length', max([len(v) for v in strip_values] + [0]))
+        super(SET, self).__init__(**kw)
 
     def result_processor(self, dialect, coltype):
         def process(value):
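
The adapt() overrides added above exist because the adapted, cached copy of a type must keep its MySQL-specific state: each class re-injects its own attributes (unsigned, zerofill, display_width, charset and friends) before deferring upward. A rough sketch of that cooperative pattern, using made-up class names rather than the real MySQL types:

    class BaseType(object):
        """Stand-in for TypeEngine: adapt() builds a fresh instance."""
        def adapt(self, cls, **kw):
            return cls(**kw)

    class UnsignedMixin(BaseType):
        """Stand-in for _NumericType: carries an 'unsigned' flag."""
        def __init__(self, unsigned=False, **kw):
            self.unsigned = unsigned
            super(UnsignedMixin, self).__init__(**kw)

        def adapt(self, cls, **kw):
            # re-inject our own attribute before deferring to the base
            return super(UnsignedMixin, self).adapt(
                            cls, unsigned=self.unsigned, **kw)

    class MyInteger(UnsignedMixin):
        def __init__(self, display_width=None, **kw):
            self.display_width = display_width
            super(MyInteger, self).__init__(**kw)

        def adapt(self, cls, **kw):
            return super(MyInteger, self).adapt(
                            cls, display_width=self.display_width, **kw)

    original = MyInteger(display_width=11, unsigned=True)
    copy = original.adapt(MyInteger)
    assert (copy.display_width, copy.unsigned) == (11, True)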
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index a74ea0c3c93eb170ffcffd8ca79667393a42fa50..47a43bdf115641b4b21eb387849f52b7df5a7c8e 100644
@@ -72,7 +72,8 @@ class _DateTimeMixin(object):
     _reg = None
     _storage_format = None
     
-    def __init__(self, storage_format=None, regexp=None, **kwargs):
+    def __init__(self, storage_format=None, regexp=None, **kw):
+        super(_DateTimeMixin, self).__init__(**kw)
         if regexp is not None:
             self._reg = re.compile(regexp)
         if storage_format is not None:
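
The one-line change above makes _DateTimeMixin.__init__ forward its remaining keyword arguments up the MRO instead of swallowing them, so arguments meant for the sibling date/time base classes still arrive there. The general shape of the fix, with illustrative class names only:

    class GenericDate(object):
        def __init__(self, timezone=False, **kw):
            self.timezone = timezone
            super(GenericDate, self).__init__(**kw)

    class StorageFormatMixin(object):
        """Like _DateTimeMixin: consumes its own kwargs, forwards the rest."""
        def __init__(self, storage_format=None, **kw):
            super(StorageFormatMixin, self).__init__(**kw)  # keep the chain going
            self.storage_format = storage_format

    class SQLiteDate(StorageFormatMixin, GenericDate):
        pass

    d = SQLiteDate(storage_format="%(year)04d-%(month)02d-%(day)02d",
                   timezone=True)
    assert d.timezone is True   # reached GenericDate via the forwarded **kw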
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 8647ba385f2ca6df811f4ffc11f48f89b7ddb562..b432c351d648d551a55c5665a8d4b5f6b85608fd 100644
@@ -16,6 +16,7 @@ import re, random
 from sqlalchemy.engine import base, reflection
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy import exc, types as sqltypes, util, pool
+import weakref
 
 AUTOCOMMIT_REGEXP = re.compile(
             r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
@@ -133,7 +134,9 @@ class DefaultDialect(base.Dialect):
                     " maximum identifier length of %d" %
                     (label_length, self.max_identifier_length))
         self.label_length = label_length
-
+        
+        self._type_memos = weakref.WeakKeyDictionary()
+        
         if not hasattr(self, 'description_encoding'):
             self.description_encoding = getattr(
                                             self, 
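
_type_memos is a weakref.WeakKeyDictionary, so the dialect only borrows references to the type objects it caches processors for: when user code drops the last strong reference to a type instance, its entry disappears from every dialect's registry automatically. A quick demonstration of that behaviour (plain Python, nothing SQLAlchemy-specific):

    import gc
    import weakref

    class SomeType(object):
        pass

    memos = weakref.WeakKeyDictionary()
    t = SomeType()
    memos[t] = {'impl': 'cached stuff'}
    assert len(memos) == 1

    del t          # drop the only strong reference to the key
    gc.collect()   # CPython frees it immediately; other VMs may need the collect
    assert len(memos) == 0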
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 1fa93368306ecf3400a95f7ddee0ac9a1debaed1..85ac3192ff65a8bc346f276766728399d6a70dc2 100644
@@ -131,57 +131,57 @@ class TypeEngine(AbstractType):
         else:
             return self.__class__
 
-    @util.memoized_property
-    def _impl_dict(self):
-        return {}
-
-    def __getstate__(self):
-        d = self.__dict__.copy()
-        d.pop('_impl_dict', None)
-        return d
-
     def dialect_impl(self, dialect, **kwargs):
-        key = dialect.__class__, dialect.server_version_info
+        """Return a dialect-specific implementation for this type."""
+        
         try:
-            return self._impl_dict[key]
+            return dialect._type_memos[self]['impl']
         except KeyError:
-            return self._impl_dict.setdefault(key,
-                    dialect.type_descriptor(self))
+            return self._dialect_info(dialect)['impl']
     
     def _cached_bind_processor(self, dialect):
-        return self.dialect_impl(dialect).bind_processor(dialect)
-        
-        # TODO: can't do this until we find a way to link with the
-        # specific attributes of the dialect, i.e. convert_unicode,
-        # etc.  might need to do a weakmap again.  needs tests
-        # to ensure two dialects with different flags.  use a mock
-        # dialect.
-        #key = "bind", dialect.__class__, dialect.server_version_info
-        #try:
-        #    return self._impl_dict[key]
-        #except KeyError:
-        #    self._impl_dict[key] = bp = \
-        #                self.dialect_impl(dialect).bind_processor(dialect)
-        #    return bp
+        """Return a dialect-specific bind processor for this type."""
 
+        try:
+            return dialect._type_memos[self]['bind']
+        except KeyError:
+            d = self._dialect_info(dialect)
+            d['bind'] = bp = d['impl'].bind_processor(dialect)
+            return bp
+    
+    def _dialect_info(self, dialect):
+        """Return a dialect-specific registry containing bind/result processors."""
+        
+        if self in dialect._type_memos:
+            return dialect._type_memos[self]
+        else:
+            impl = self._gen_dialect_impl(dialect)
+            # the impl we put in here
+            # must not have any references to self.
+            if impl is self:
+                impl = self.adapt(type(self))
+            dialect._type_memos[self] = d = {
+                'impl':impl,
+            }
+            return d
+        
     def _cached_result_processor(self, dialect, coltype):
-        return self.dialect_impl(dialect).result_processor(dialect, coltype)
+        """Return a dialect-specific result processor for this type."""
+
+        try:
+            return dialect._type_memos[self][coltype]
+        except KeyError:
+            d = self._dialect_info(dialect)
+            # another key assumption.  DBAPI type codes are
+            # constants.   
+            d[coltype] = rp = d['impl'].result_processor(dialect, coltype)
+            return rp
+
+    def _gen_dialect_impl(self, dialect):
+        return dialect.type_descriptor(self)
         
-        # TODO: can't do this until we find a way to link with the
-        # specific attributes of the dialect, i.e. convert_unicode,
-        # etc.  might need to do a weakmap again.   needs tests
-        # to ensure two dialects with different flags.  use a mock
-        # dialect.
-        #key = "result", dialect.__class__, dialect.server_version_info, coltype
-        #try:
-        #    return self._impl_dict[key]
-        #except KeyError:
-        #    self._impl_dict[key] = rp = self.dialect_impl(dialect).\
-        #                result_processor(dialect, coltype)
-        #    return rp
-
-    def adapt(self, cls):
-        return cls()
+    def adapt(self, cls, **kw):
+        return cls(**kw)
     
     def _coerce_compared_value(self, op, value):
         _coerced_type = _type_map.get(type(value), NULLTYPE)
@@ -376,17 +376,10 @@ class TypeDecorator(TypeEngine):
                                  "type being decorated")
         self.impl = to_instance(self.__class__.impl, *args, **kwargs)
     
-    def dialect_impl(self, dialect):
-        key = (dialect.__class__, dialect.server_version_info)
-
-        try:
-            return self._impl_dict[key]
-        except KeyError:
-            pass
-
+    
+    def _gen_dialect_impl(self, dialect):
         adapted = dialect.type_descriptor(self)
         if adapted is not self:
-            self._impl_dict[key] = adapted
             return adapted
 
         # otherwise adapt the impl type, link
@@ -400,7 +393,6 @@ class TypeDecorator(TypeEngine):
                                  'return an object of type %s' % (self,
                                  self.__class__))
         tt.impl = typedesc
-        self._impl_dict[key] = tt
         return tt
 
     @util.memoized_property
@@ -499,7 +491,6 @@ class TypeDecorator(TypeEngine):
     def copy(self):
         instance = self.__class__.__new__(self.__class__)
         instance.__dict__.update(self.__dict__)
-        instance._impl_dict = {}
         return instance
 
     def get_dbapi_type(self, dbapi):
@@ -796,12 +787,13 @@ class String(Concatenable, TypeEngine):
         self.unicode_error = unicode_error
         self._warn_on_bytestring = _warn_on_bytestring
         
-    def adapt(self, impltype):
+    def adapt(self, impltype, **kw):
         return impltype(
                     length=self.length,
                     convert_unicode=self.convert_unicode,
                     unicode_error=self.unicode_error,
                     _warn_on_bytestring=True,
+                    **kw
                     )
 
     def bind_processor(self, dialect):
@@ -1099,11 +1091,12 @@ class Numeric(_DateAffinity, TypeEngine):
         self.scale = scale
         self.asdecimal = asdecimal
 
-    def adapt(self, impltype):
+    def adapt(self, impltype, **kw):
         return impltype(
                 precision=self.precision, 
                 scale=self.scale, 
-                asdecimal=self.asdecimal)
+                asdecimal=self.asdecimal,
+                **kw)
 
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
@@ -1194,8 +1187,9 @@ class Float(Numeric):
         self.precision = precision
         self.asdecimal = asdecimal
 
-    def adapt(self, impltype):
-        return impltype(precision=self.precision, asdecimal=self.asdecimal)
+    def adapt(self, impltype, **kw):
+        return impltype(precision=self.precision, 
+                        asdecimal=self.asdecimal, **kw)
 
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
@@ -1243,8 +1237,8 @@ class DateTime(_DateAffinity, TypeEngine):
     def __init__(self, timezone=False):
         self.timezone = timezone
 
-    def adapt(self, impltype):
-        return impltype(timezone=self.timezone)
+    def adapt(self, impltype, **kw):
+        return impltype(timezone=self.timezone, **kw)
 
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
@@ -1303,8 +1297,8 @@ class Time(_DateAffinity,TypeEngine):
     def __init__(self, timezone=False):
         self.timezone = timezone
 
-    def adapt(self, impltype):
-        return impltype(timezone=self.timezone)
+    def adapt(self, impltype, **kw):
+        return impltype(timezone=self.timezone, **kw)
 
     def get_dbapi_type(self, dbapi):
         return dbapi.DATETIME
@@ -1365,8 +1359,8 @@ class _Binary(TypeEngine):
         else:
             return super(_Binary, self)._coerce_compared_value(op, value)
     
-    def adapt(self, impltype):
-        return impltype(length=self.length)
+    def adapt(self, impltype, **kw):
+        return impltype(length=self.length, **kw)
 
     def get_dbapi_type(self, dbapi):
         return dbapi.BINARY
@@ -1452,7 +1446,7 @@ class SchemaType(object):
         if bind is None:
             bind = schema._bind_or_error(self)
         t = self.dialect_impl(bind.dialect)
-        if t is not self and isinstance(t, SchemaType):
+        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t.create(bind=bind, checkfirst=checkfirst)
 
     def drop(self, bind=None, checkfirst=False):
@@ -1461,27 +1455,27 @@ class SchemaType(object):
         if bind is None:
             bind = schema._bind_or_error(self)
         t = self.dialect_impl(bind.dialect)
-        if t is not self and isinstance(t, SchemaType):
+        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t.drop(bind=bind, checkfirst=checkfirst)
         
     def _on_table_create(self, event, target, bind, **kw):
         t = self.dialect_impl(bind.dialect)
-        if t is not self and isinstance(t, SchemaType):
+        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_table_create(event, target, bind, **kw)
 
     def _on_table_drop(self, event, target, bind, **kw):
         t = self.dialect_impl(bind.dialect)
-        if t is not self and isinstance(t, SchemaType):
+        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_table_drop(event, target, bind, **kw)
 
     def _on_metadata_create(self, event, target, bind, **kw):
         t = self.dialect_impl(bind.dialect)
-        if t is not self and isinstance(t, SchemaType):
+        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_metadata_create(event, target, bind, **kw)
 
     def _on_metadata_drop(self, event, target, bind, **kw):
         t = self.dialect_impl(bind.dialect)
-        if t is not self and isinstance(t, SchemaType):
+        if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
             t._on_metadata_drop(event, target, bind, **kw)
     
 class Enum(String, SchemaType):
@@ -1578,7 +1572,7 @@ class Enum(String, SchemaType):
                     )
         table.append_constraint(e)
         
-    def adapt(self, impltype):
+    def adapt(self, impltype, **kw):
         if issubclass(impltype, Enum):
             return impltype(name=self.name, 
                         quote=self.quote, 
@@ -1586,10 +1580,11 @@ class Enum(String, SchemaType):
                         metadata=self.metadata,
                         convert_unicode=self.convert_unicode,
                         native_enum=self.native_enum,
-                        *self.enums
+                        *self.enums,
+                        **kw
                         )
         else:
-            return super(Enum, self).adapt(impltype)
+            return super(Enum, self).adapt(impltype, **kw)
 
 class PickleType(MutableType, TypeDecorator):
     """Holds Python objects, which are serialized using pickle.
@@ -1791,11 +1786,11 @@ class Interval(_DateAffinity, TypeDecorator):
         self.second_precision = second_precision
         self.day_precision = day_precision
 
-    def adapt(self, cls):
+    def adapt(self, cls, **kw):
         if self.native:
-            return cls._adapt_from_generic_interval(self)
+            return cls._adapt_from_generic_interval(self, **kw)
         else:
-            return self
+            return cls(**kw)
     
     def bind_processor(self, dialect):
         impl_processor = self.impl.bind_processor(dialect)
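
Taken together with the engine/default.py hunk, the flow is: the first _cached_bind_processor() or _cached_result_processor() call for a given (type instance, dialect) pair builds the adapted impl via _dialect_info() and stores the processor in dialect._type_memos[self]; later calls are a plain dict hit. Note that _dialect_info() is careful never to store self as the impl, which is presumably why the SchemaType hooks above switch from an "is not self" identity test to comparing classes. A hedged usage sketch (these are private internals, shown only to illustrate the cache behaviour, mirroring what test_ad_hoc_types does below):

    from sqlalchemy import types
    from sqlalchemy.dialects import sqlite

    dialect = sqlite.dialect()
    t = types.String(50)

    bp1 = t._cached_bind_processor(dialect)   # miss: builds impl + processor
    bp2 = t._cached_bind_processor(dialect)   # hit: straight dict lookup
    assert bp1 is bp2
    assert t in dialect._type_memos           # keyed weakly by the type instance

    # result processors are keyed by DBAPI type code, assumed constant
    rp = t._cached_result_processor(dialect, None)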
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
index 98c894e5b80147882aee4d5090523998e0607240..fd5e1449dae570c1aa3dcfb1c14b425792f96971 100644
@@ -876,7 +876,6 @@ class ThreadLocalRegistry(ScopedRegistry):
         except AttributeError:
             pass
 
-
 def _iter_id(iterable):
     """Generator: ((id(o), o) for o in iterable)."""
 
diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py
index 26b6c7df4b34fd71c856507e741ba3d71bd59e86..e3ed80a74aaa89e15311112381a8ddd69b045acb 100644
@@ -210,7 +210,33 @@ class MemUsageTest(EnsureZeroed):
         metadata.drop_all()
         del m1, m2, m3
         assert_no_mappers()
-
+    
+    def test_ad_hoc_types(self):
+        """test storage of bind processors, result processors
+        in dialect-wide registry."""
+        
+        from sqlalchemy.dialects import mysql, postgresql, sqlite
+        from sqlalchemy import types
+        
+        for args in (
+            (types.Integer, ),
+            (types.String, ),
+            (types.PickleType, ),
+            (types.Enum, 'a', 'b', 'c'),
+            (sqlite.DATETIME, ),
+            (postgresql.ENUM, 'a', 'b', 'c'),
+            (types.Interval, ),
+            (postgresql.INTERVAL, ),
+            (mysql.VARCHAR, ),
+        ):
+            @profile_memory
+            def go():
+                type_ = args[0](*args[1:])
+                bp = type_._cached_bind_processor(testing.db.dialect)
+                rp = type_._cached_result_processor(testing.db.dialect, 0)
+            go()
+            
+            
     def test_many_updates(self):
         metadata = MetaData(testing.db)