]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
factored oid column into a consistent late-bound pattern, fixing [ticket:146]
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Apr 2006 21:40:18 +0000 (21:40 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Apr 2006 21:40:18 +0000 (21:40 +0000)
lib/sqlalchemy/sql.py

index 532b060a71108e06bfc099d65d21032a80199211..f0171571d45aafed5cfb995b804dcc20883eee8b 100644 (file)
@@ -611,6 +611,13 @@ class FromClause(Selectable):
         return Join(self, right, isouter = True, *args, **kwargs)
     def alias(self, name=None):
         return Alias(self, name)
+    def _locate_oid_column(self):
+        """subclasses override this to return an appropriate OID column"""
+        return None
+    def _get_oid_column(self):
+        if not hasattr(self, '_oid_column'):
+            self._oid_column = self._locate_oid_column()
+        return self._oid_column
     def _get_col_by_original(self, column, raiseerr=True):
         """given a column which is a schema.Column object attached to a schema.Table object
         (i.e. an "original" column), return the Column object from this 
@@ -635,6 +642,7 @@ class FromClause(Selectable):
     primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
     foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys'))
     original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'))
+    oid_column = property(_get_oid_column)
     
     def _export_columns(self):
         if hasattr(self, '_columns'):
@@ -912,7 +920,8 @@ class Join(FromClause):
             self.onclause = onclause
         self.isouter = isouter
 
-    oid_column = property(lambda s:s.left.oid_column)
+    def _locate_oid_column(self):
+        return self.left.oid_column
     
     def _exportable_columns(self):
         return [c for c in self.left.columns] + [c for c in self.right.columns]
@@ -990,11 +999,13 @@ class Alias(FromClause):
                 n = n[0:15]
             alias = n + "_" + hex(random.randint(0, 65535))[2:]
         self.name = alias
+        
+    def _locate_oid_column(self):
         if self.selectable.oid_column is not None:
-            self.oid_column = self.selectable.oid_column._make_proxy(self)
+            return self.selectable.oid_column._make_proxy(self)
         else:
-            self.oid_column = None
-
+            return None
+    
     def _exportable_columns(self):
         return self.selectable.columns
 
@@ -1094,18 +1105,16 @@ class TableClause(FromClause):
     def append_column(self, c):
         self._columns[c.text] = c
         c.table = self
-    def _oid_col(self):
+    def _locate_oid_column(self):
         if self.engine is None:
             return None
-        # OID remains a little hackish so far
-        if not hasattr(self, '_oid_column'):
-            if self.engine.oid_column_name() is not None:
-                self._oid_column = schema.Column(self.engine.oid_column_name(), sqltypes.Integer, hidden=True)
-                self._oid_column._set_parent(self)
-                self._orig_columns()[self._oid_column.original] = self._oid_column
-            else:
-                self._oid_column = None
-        return self._oid_column
+        if self.engine.oid_column_name() is not None:
+            _oid_column = schema.Column(self.engine.oid_column_name(), sqltypes.Integer, hidden=True)
+            _oid_column._set_parent(self)
+            self._orig_columns()[_oid_column.original] = _oid_column
+            return _oid_column
+        else:
+            return None
     def _orig_columns(self):
         try:
             return self._orig_cols
@@ -1119,7 +1128,6 @@ class TableClause(FromClause):
     primary_key = property(lambda s:s._primary_key)
     foreign_keys = property(lambda s:s._foreign_keys)
     original_columns = property(_orig_columns)
-    oid_column = property(_oid_col)
 
     def _clear(self):
         """clears all attributes on this TableClause so that new items can be added again"""
@@ -1193,13 +1201,15 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         self.parens = kwargs.pop('parens', False)
         self.correlate = kwargs.pop('correlate', False)
         self.for_update = kwargs.pop('for_update', False)
-        self.oid_column = selects[0].oid_column
         for s in self.selects:
             s.group_by(None)
             s.order_by(None)
         self.group_by(*kwargs.get('group_by', [None]))
         self.order_by(*kwargs.get('order_by', [None]))
 
+    def _locate_oid_column(self):
+        return self.selects[0].oid_column
+
     def _exportable_columns(self):
         for s in self.selects:
             for c in s.c:
@@ -1235,7 +1245,6 @@ class Select(SelectBaseMixin, FromClause):
         self.whereclause = None
         self.having = None
         self._engine = engine
-        self.oid_column = None
         self.limit = limit
         self.offset = offset
         self.for_update = for_update
@@ -1345,11 +1354,15 @@ class Select(SelectBaseMixin, FromClause):
     def append_from(self, fromclause):
         if type(fromclause) == str:
             fromclause = FromClause(from_name = fromclause)
-        if self.oid_column is None and hasattr(fromclause, 'oid_column'):
-            self.oid_column = fromclause.oid_column
         fromclause.accept_visitor(self._correlator)
         fromclause._process_from_dict(self._froms, True)
-
+    def _locate_oid_column(self):
+        for f in self._froms.values():
+            oid = f.oid_column
+            if oid is not None:
+                return oid
+        else:
+            return None
     def _get_froms(self):
         return [f for f in self._froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))]
     froms = property(lambda s: s._get_froms())