]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- a few fixes
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Aug 2010 19:23:37 +0000 (15:23 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Aug 2010 19:23:37 +0000 (15:23 -0400)
- what will be the test suite

lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/properties.py
test/ext/test_hybrid.py [new file with mode: 0644]

index ea04b0910e7b6e1cc13c78c784ca02d7023a9295..01f697edd644845c17ceb32f2d8094419d056d16 100644 (file)
@@ -202,12 +202,12 @@ def create_proxied_attribute(descriptor):
                 return getattr(descriptor, attribute)
             except AttributeError:
                 try:
-                    return getattr(self._comparator, attribute)
+                    return getattr(self.comparator, attribute)
                 except AttributeError:
                     raise AttributeError(
                     'Neither %r object nor %r object has an attribute %r' % (
                     type(descriptor).__name__, 
-                    type(self._comparator).__name__, 
+                    type(self.comparator).__name__, 
                     attribute)
                     )
 
index d0de4bbbb09b14b31db9f48b8a6f7a698288184d..09c5042ff39f7b55b317234eaab8e05913197e91 100644 (file)
@@ -242,16 +242,7 @@ class DescriptorProperty(MapperProperty):
     """:class:`MapperProperty` which proxies access to a 
         user-defined descriptor."""
 
-    def set_parent(self, parent, init):
-        if self.descriptor is None:
-            desc = getattr(parent.class_, self.key, None)
-            if parent._is_userland_descriptor(desc):
-                self.descriptor = desc
-        self.parent = parent
-    
     def instrument_class(self, mapper):
-        class_ = self.parent.class_
-        
         from sqlalchemy.ext import hybrid
 
         # hackety hack hack
@@ -262,6 +253,11 @@ class DescriptorProperty(MapperProperty):
             def __init__(self, key):
                 self.key = key
 
+        if self.descriptor is None:
+            desc = getattr(mapper.class_, self.key, None)
+            if mapper._is_userland_descriptor(desc):
+                self.descriptor = desc
+
         if self.descriptor is None:
             def fset(obj, value):
                 setattr(obj, self.name, value)
@@ -371,8 +367,7 @@ class SynonymProperty(DescriptorProperty):
         util.set_creation_order(self)
 
     def _comparator_factory(self, mapper):
-        class_ = self.parent.class_
-        prop = getattr(class_, self.name).property
+        prop = getattr(mapper.class_, self.name).property
 
         if self.comparator_factory:
             comp = self.comparator_factory(prop, mapper)
@@ -381,10 +376,6 @@ class SynonymProperty(DescriptorProperty):
         return comp
 
     def set_parent(self, parent, init):
-        if self.descriptor is None:
-            desc = getattr(parent.class_, self.key, None)
-            if parent._is_userland_descriptor(desc):
-                self.descriptor = desc
         if self.map_column:
             if self.key not in parent.mapped_table.c:
                 raise sa_exc.ArgumentError(
@@ -414,21 +405,12 @@ class SynonymProperty(DescriptorProperty):
 class ComparableProperty(DescriptorProperty):
     """Instruments a Python property for use in query expressions."""
 
-    extension = None
-    
     def __init__(self, comparator_factory, descriptor=None, doc=None):
         self.descriptor = descriptor
         self.comparator_factory = comparator_factory
         self.doc = doc or (descriptor and descriptor.__doc__) or None
         util.set_creation_order(self)
 
-    def set_parent(self, parent, init):
-        if self.descriptor is None:
-            desc = getattr(parent.class_, self.key, None)
-            if parent._is_userland_descriptor(desc):
-                self.descriptor = desc
-        self.parent = parent
-
     def _comparator_factory(self, mapper):
         return self.comparator_factory(self, mapper)
 
diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py
new file mode 100644 (file)
index 0000000..3dfd4c8
--- /dev/null
@@ -0,0 +1,102 @@
+"""
+
+tests for sqlalchemy.ext.hybrid TODO
+
+
+"""
+
+
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext import hybrid
+from sqlalchemy.orm.interfaces import PropComparator
+
+
+"""
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext import hybrid
+
+Base = declarative_base()
+
+
+class UCComparator(hybrid.Comparator):
+    
+    def __eq__(self, other):
+        if other is None:
+            return self.expression == None
+        else:
+            return func.upper(self.expression) == func.upper(other)
+
+class A(Base):
+    __tablename__ = 'a'
+    id = Column(Integer, primary_key=True)
+    _value = Column("value", String)
+
+    @hybrid.property_
+    def value(self):
+        return int(self._value)
+
+    @value.comparator
+    def value(cls):
+        return UCComparator(cls._value)
+        
+    @value.setter
+    def value(self, v):
+        self.value = v
+print aliased(A).value
+print aliased(A).__tablename__
+
+sess = create_session()
+
+print A.value == "foo"
+print sess.query(A.value)
+print sess.query(aliased(A).value)
+print sess.query(aliased(A)).filter_by(value="foo")
+"""
+
+"""
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext import hybrid
+
+Base = declarative_base()
+
+class A(Base):
+    __tablename__ = 'a'
+    id = Column(Integer, primary_key=True)
+    _value = Column("value", String)
+
+    @hybrid.property
+    def value(self):
+        return int(self._value)
+    
+    @value.expression
+    def value(cls):
+        return func.foo(cls._value) + cls.bar_value
+
+    @value.setter
+    def value(self, v):
+        self.value = v
+
+    @hybrid.property
+    def bar_value(cls):
+        return func.bar(cls._value)
+        
+#print A.value
+#print A.value.__doc__
+
+print aliased(A).value
+print aliased(A).__tablename__
+
+sess = create_session()
+
+print sess.query(A).filter_by(value="foo")
+
+print sess.query(aliased(A)).filter_by(value="foo")
+
+
+"""
\ No newline at end of file