]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- use MutableMapping to make this more succinct, complete
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Feb 2014 05:27:54 +0000 (00:27 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Feb 2014 05:27:54 +0000 (00:27 -0500)
lib/sqlalchemy/sql/base.py

index 818a07deb86b073e785fb121231126d97034d151..260cdab660a180904dde952173b13a4228ec8a84 100644 (file)
@@ -44,40 +44,7 @@ def _generative(fn, *args, **kw):
     return self
 
 
-
-class _DialectArgDictBase(object):
-    """base for dynamic dictionaries that handle dialect-level keyword
-    arguments."""
-
-    def _keys_iter(self):
-        raise NotImplementedError()
-    if util.py2k:
-        def keys(self):
-            return list(self._keys_iter())
-        def items(self):
-            return [(key, self[key]) for key in self._keys_iter()]
-    else:
-        def keys(self):
-            return self._keys_iter()
-        def items(self):
-            return ((key, self[key]) for key in self._keys_iter())
-
-    def get(self, key, default=None):
-        if key in self:
-            return self[key]
-        else:
-            return default
-
-    def __iter__(self):
-        return self._keys_iter()
-
-    def __eq__(self, other):
-        return dict(self) == dict(other)
-
-    def __repr__(self):
-        return repr(dict(self))
-
-class _DialectArgView(_DialectArgDictBase):
+class _DialectArgView(collections.MutableMapping):
     """A dictionary view of dialect-level arguments in the form
     <dialectname>_<argument_name>.
 
@@ -85,10 +52,16 @@ class _DialectArgView(_DialectArgDictBase):
     def __init__(self, obj):
         self.obj = obj
 
-    def __getitem__(self, key):
-        if "_" not in key:
+    def _key(self, key):
+        try:
+            dialect, value_key = key.split("_", 1)
+        except ValueError:
             raise KeyError(key)
-        dialect, value_key = key.split("_", 1)
+        else:
+            return dialect, value_key
+
+    def __getitem__(self, key):
+        dialect, value_key = self._key(key)
 
         try:
             opt = self.obj.dialect_options[dialect]
@@ -98,21 +71,30 @@ class _DialectArgView(_DialectArgDictBase):
             return opt[value_key]
 
     def __setitem__(self, key, value):
-        if "_" not in key:
+        try:
+            dialect, value_key = self._key(key)
+        except KeyError:
             raise exc.ArgumentError(
                             "Keys must be of the form <dialectname>_<argname>")
+        else:
+            self.obj.dialect_options[dialect][value_key] = value
+
+    def __delitem__(self, key):
+        dialect, value_key = self._key(key)
+        del self.obj.dialect_options[dialect][value_key]
 
-        dialect, value_key = key.split("_", 1)
-        self.obj.dialect_options[dialect][value_key] = value
+    def __len__(self):
+        return sum(len(args._non_defaults) for args in
+                            self.obj.dialect_options.values())
 
-    def _keys_iter(self):
+    def __iter__(self):
         return (
             "%s_%s" % (dialect_name, value_name)
             for dialect_name in self.obj.dialect_options
             for value_name in self.obj.dialect_options[dialect_name]._non_defaults
         )
 
-class _DialectArgDict(_DialectArgDictBase):
+class _DialectArgDict(collections.MutableMapping):
     """A dictionary view of dialect-level arguments for a specific
     dialect.
 
@@ -120,11 +102,14 @@ class _DialectArgDict(_DialectArgDictBase):
     and dialect-specified default arguments.
 
     """
-    def __init__(self, obj, dialect_name):
+    def __init__(self):
         self._non_defaults = {}
         self._defaults = {}
 
-    def _keys_iter(self):
+    def __len__(self):
+        return len(set(self._non_defaults).union(self._defaults))
+
+    def __iter__(self):
         return iter(set(self._non_defaults).union(self._defaults))
 
     def __getitem__(self, key):
@@ -136,6 +121,10 @@ class _DialectArgDict(_DialectArgDictBase):
     def __setitem__(self, key, value):
         self._non_defaults[key] = value
 
+    def __delitem__(self, key):
+        del self._non_defaults[key]
+
+
 class DialectKWArgs(object):
     """Establish the ability for a class to have dialect-specific arguments
     with defaults and constructor validation.
@@ -235,7 +224,7 @@ class DialectKWArgs(object):
 
     def _kw_reg_for_dialect_cls(self, dialect_name):
         construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
-        d = _DialectArgDict(self, dialect_name)
+        d = _DialectArgDict()
 
         if construct_arg_dictionary is None:
             d._defaults.update({"*": None})
@@ -288,8 +277,7 @@ class DialectKWArgs(object):
                         "Can't validate argument %r; can't "
                         "locate any SQLAlchemy dialect named %r" %
                         (k, dialect_name))
-                self.dialect_options[dialect_name] = d = \
-                                    _DialectArgDict(self, dialect_name)
+                self.dialect_options[dialect_name] = d = _DialectArgDict()
                 d._defaults.update({"*": None})
                 d._non_defaults[arg_name] = kwargs[k]
             else: