]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
factored down exportable_columns/flatten_cols/proxy_column/oid_etc_yada down to a...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Apr 2008 18:41:08 +0000 (18:41 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Apr 2008 18:41:08 +0000 (18:41 +0000)
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/util.py
test/sql/select.py

index a4028c1efa05fab2c39ea7ed9e24510a7d10de00..8a44a718219a3f34938272840e96e2e53fa8a7c9 100644 (file)
@@ -276,11 +276,6 @@ class Table(SchemaItem, expression.TableClause):
         return _get_table_key(self.name, self.schema)
     key = property(key)
 
-    def _export_columns(self, columns=None):
-        # override FromClause's collection initialization logic; Table
-        # implements it differently
-        pass
-
     def _set_primary_key(self, pk):
         if getattr(self, '_primary_key', None) in self.constraints:
             self.constraints.remove(self._primary_key)
index 2ed3b372f2843864b09ed46350f5be9900c8cf05..b45fa4035ecfd93f1687b52d6d3b40063a70b2e9 100644 (file)
@@ -1611,9 +1611,6 @@ class FromClause(Selectable):
     named_with_column=False
     _hide_froms = []
 
-    def __init__(self):
-        self.oid_column = None
-
     def _get_from_objects(self, **modifiers):
         return []
 
@@ -1723,56 +1720,39 @@ class FromClause(Selectable):
                 delattr(self, attr)
 
     def _expr_attr_func(name):
+        get = util.attrgetter(name)
         def attr(self):
             try:
-                return getattr(self, name)
+                return get(self)
             except AttributeError:
                 self._export_columns()
-                return getattr(self, name)
+                return get(self)
         return property(attr)
-
+    
     columns = c = _expr_attr_func('_columns')
     primary_key = _expr_attr_func('_primary_key')
     foreign_keys = _expr_attr_func('_foreign_keys')
+    oid_column = _expr_attr_func('_oid_column')
 
-    def _export_columns(self, columns=None):
+    def _export_columns(self):
         """Initialize column collections."""
 
-        if hasattr(self, '_columns') and columns is None:
+        if hasattr(self, '_columns'):
             return
         self._columns = ColumnCollection()
         self._primary_key = ColumnSet()
         self._foreign_keys = util.Set()
+        self._oid_column = None
+        self._populate_column_collection()
 
-        if columns is None:
-            columns = self._flatten_exportable_columns()
-        for co in columns:
-            cp = self._proxy_column(co)
-
-    def _flatten_exportable_columns(self):
-        """Return the list of ColumnElements represented within this FromClause's _exportable_columns"""
-        export = self._exportable_columns()
-        for column in export:
-            if isinstance(column, Selectable):
-                for co in column.columns:
-                    yield co
-            elif isinstance(column, ColumnElement):
-                yield column
-            else:
-                continue
-
-    def _exportable_columns(self):
-        return []
-
-    def _proxy_column(self, column):
-        return column._make_proxy(self)
+    def _populate_column_collection(self):
+        pass
 
 class _TextFromClause(FromClause):
     __visit_name__ = 'fromclause'
 
     def __init__(self, text):
         self.name = text
-        self.oid_column = None
 
 class _BindParamClause(ClauseElement, _CompareMixin):
     """Represent a bind parameter.
@@ -2079,7 +2059,6 @@ class _Function(_CalculatedClause, FromClause):
 
     def __init__(self, name, *clauses, **kwargs):
         self.packagenames = kwargs.get('packagenames', None) or []
-        self.oid_column = None
         self.name = name
         self._bind = kwargs.get('bind', None)
         args = [_literal_as_binds(c, self.name) for c in clauses]
@@ -2255,7 +2234,6 @@ class Join(FromClause):
         self.left = _selectable(left)
         self.right = _selectable(right).self_group()
 
-        self.oid_column = self.left.oid_column
         if onclause is None:
             self.onclause = self._match_primaries(self.left, self.right)
         else:
@@ -2263,22 +2241,6 @@ class Join(FromClause):
         self.isouter = isouter
         self.__folded_equivalents = None
 
-    def _export_columns(self):
-        if hasattr(self, '_columns'):
-            return
-        self._columns = ColumnCollection()
-        self._foreign_keys = util.Set()
-
-        columns = list(self._flatten_exportable_columns())
-
-        global sql_util
-        if not sql_util:
-            from sqlalchemy.sql import util as sql_util
-        self._primary_key = sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause)
-
-        for co in columns:
-            cp = self._proxy_column(co)
-
     def description(self):
         return "Join object on %s(%d) and %s(%d)" % (self.left.description, id(self.left), self.right.description, id(self.right))
     description = property(description)
@@ -2289,14 +2251,16 @@ class Join(FromClause):
     def self_group(self, against=None):
         return _FromGrouping(self)
 
-    def _exportable_columns(self):
-        return [c for c in self.left.columns] + [c for c in self.right.columns]
+    def _populate_column_collection(self):
+        columns = [c for c in self.left.columns] + [c for c in self.right.columns]
 
-    def _proxy_column(self, column):
-        self._columns[column._label] = column
-        for f in column.foreign_keys:
-            self._foreign_keys.add(f)
-        return column
+        global sql_util
+        if not sql_util:
+            from sqlalchemy.sql import util as sql_util
+        self._primary_key.extend(sql_util.reduce_columns([c for c in columns if c.primary_key], self.onclause))
+        self._columns.update([(col._label, col) for col in columns])
+        self._foreign_keys.update(itertools.chain(*[col.foreign_keys for col in columns]))    
+        self._oid_column = self.left.oid_column
 
     def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
@@ -2452,10 +2416,6 @@ class Alias(FromClause):
                 alias = getattr(self.original, 'name', None)
             alias = '{ANON %d %s}' % (id(self), alias or 'anon')
         self.name = alias
-        if self.selectable.oid_column is not None:
-            self.oid_column = self.selectable.oid_column._make_proxy(self)
-        else:
-            self.oid_column = None
 
     def description(self):
         return self.name.encode('ascii', 'backslashreplace')
@@ -2472,9 +2432,11 @@ class Alias(FromClause):
     def _table_iterator(self):
         return self.original._table_iterator()
 
-    def _exportable_columns(self):
-        #return self.selectable._exportable_columns()
-        return self.selectable.columns
+    def _populate_column_collection(self):
+        for col in self.selectable.columns:
+            col._make_proxy(self)
+        if self.selectable.oid_column is not None:
+            self._oid_column = self.selectable.oid_column._make_proxy(self)
 
     def _copy_internals(self, clone=_clone):
        self._clone_from_clause()
@@ -2736,8 +2698,15 @@ class TableClause(FromClause):
     def __init__(self, name, *columns):
         super(TableClause, self).__init__()
         self.name = self.fullname = name
-        self.oid_column = _ColumnClause('oid', self, _is_oid=True)
-        self._export_columns(columns)
+        self._oid_column = _ColumnClause('oid', self, _is_oid=True)
+        self._columns = ColumnCollection()
+        self._primary_key = ColumnSet()
+        self._foreign_keys = util.Set()
+        for c in columns:
+            self.append_column(c)
+        
+    def _export_columns(self):
+        raise NotImplementedError()
 
     def description(self):
         return self.name.encode('ascii', 'backslashreplace')
@@ -2751,19 +2720,12 @@ class TableClause(FromClause):
         self._columns[c.name] = c
         c.table = self
 
-    def _proxy_column(self, c):
-        self.append_column(c)
-        return c
-
     def get_children(self, column_collections=True, **kwargs):
         if column_collections:
             return [c for c in self.c]
         else:
             return []
 
-    def _exportable_columns(self):
-        raise NotImplementedError()
-
     def count(self, whereclause=None, **params):
         if self.primary_key:
             col = list(self.primary_key)[0]
@@ -2994,43 +2956,22 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
 
         _SelectBaseMixin.__init__(self, **kwargs)
         
-        self.oid_column = None
-        for s in self.selects:
-            # TODO: need to repair proxy_column here to 
-            # not require full traversal
-            if s.oid_column:
-                self.oid_column = self._proxy_column(s.oid_column)
-    
     def self_group(self, against=None):
         return _FromGrouping(self)
 
-    def _exportable_columns(self):
-        for s in self.selects:
-            for c in s.c:
-                yield c
+    def _populate_column_collection(self):
+        for cols in zip(*[s.c for s in self.selects]):
+            proxy = cols[0]._make_proxy(self, name=self.use_labels and cols[0]._label or None)
+            proxy.proxies = cols
 
-    def _proxy_column(self, column):
-        if not hasattr(self, '_col_map'):
-            self._col_map = dict([(s, []) for s in self.selects])
-            for s in self.selects:
-                for c in s.c + [s.oid_column]:
-                    self._col_map[c] = s
-        
-        selectable = self._col_map[column]
-        col_ordering = self._col_map[selectable]
-        
-        if selectable is self.selects[0]:
-            if self.use_labels:
-                col = column._make_proxy(self, name=column._label)
-            else:
-                col = column._make_proxy(self)
-            col_ordering.append(col)
-            return col
-        else:
-            col_ordering.append(column)
-            existing = self._col_map[self.selects[0]][len(col_ordering) - 1]
-            existing.proxies.append(column)
-            return existing
+        oid_proxies = [
+            c for c in [f.oid_column for f in self.selects] if c is not None
+        ]
+
+        if oid_proxies:
+            col = oid_proxies[0]._make_proxy(self)
+            col.proxies = oid_proxies
+            self._oid_column = col
 
     def _copy_internals(self, clone=_clone):
         self._clone_from_clause()
@@ -3410,15 +3351,30 @@ class Select(_SelectBaseMixin, FromClause):
         else:
             self._froms = util.Set(list(self._froms) + [fromclause])
 
-    def _exportable_columns(self):
-        return [c for c in self._raw_columns if isinstance(c, (Selectable, ColumnElement))]
+    def __exportable_columns(self):
+        for column in self._raw_columns:
+            if isinstance(column, Selectable):
+                for co in column.columns:
+                    yield co
+            elif isinstance(column, ColumnElement):
+                yield column
+            else:
+                continue
 
-    def _proxy_column(self, column):
-        if self.use_labels:
-            return column._make_proxy(self, name=column._label)
-        else:
-            return column._make_proxy(self)
+    def _populate_column_collection(self):
+        for c in self.__exportable_columns():
+            c._make_proxy(self, name=self.use_labels and c._label or None)
 
+        oid_proxies = [c for c in 
+            [f.oid_column for f in self.locate_all_froms()
+            if f is not self] if c is not None
+        ]
+
+        if oid_proxies:
+            col = oid_proxies[0]._make_proxy(self)
+            col.proxies = oid_proxies
+            self._oid_column = col
+    
     def self_group(self, against=None):
         """return a 'grouping' construct as per the ClauseElement specification.
 
@@ -3430,30 +3386,6 @@ class Select(_SelectBaseMixin, FromClause):
             return self
         return _FromGrouping(self)
 
-    def oid_column(self):
-        if hasattr(self, '_oid_column'):
-            return self._oid_column
-
-        proxies = []
-        for f in self.locate_all_froms():
-            if f is self:
-                continue
-            oid = f.oid_column
-            if oid is not None:
-                proxies.append(oid)
-
-        if proxies:
-            # create a proxied column which will act as a proxy
-            # for every OID we've located...
-            col = self._proxy_column(proxies[0])
-            col.proxies = proxies
-            self._oid_column = col
-            return col
-        else:
-            self._oid_column = None
-            return self._oid_column
-    oid_column = property(oid_column)
-
     def union(self, other, **kwargs):
         """return a SQL UNION of this select() construct against the given selectable."""
 
@@ -3501,7 +3433,7 @@ class Select(_SelectBaseMixin, FromClause):
                 return e
         # look through the columns (largely synomous with looking
         # through the FROMs except in the case of _CalculatedClause/_Function)
-        for c in self._exportable_columns():
+        for c in self._raw_columns:
             if getattr(c, 'table', None) is self:
                 continue
             e = c.bind
index 66954168c501ff19b9677c5a057a2b5985f5ddb7..9e9a45976503d660304982d123a01fe707d1b106 100644 (file)
@@ -18,7 +18,6 @@ class GenericFunction(_Function):
 
     def __init__(self, type_=None, group=True, args=(), **kwargs):
         self.packagenames = []
-        self.oid_column = None
         self.name = self.__class__.__name__
         self._bind = kwargs.get('bind', None)
         if group:
index 101ef1462c9c0d792508e03ae3256746af00bdd1..36b40a04de2fdf420b755e595c297514f3dc2c04 100644 (file)
@@ -495,7 +495,10 @@ class OrderedProperties(object):
 
     def __contains__(self, key):
         return key in self._data
-
+    
+    def update(self, value):
+        self._data.update(value)
+        
     def get(self, key, default=None):
         if key in self:
             return self[key]
index 77926b421bed42ea8b0b27db8ea635d075d6ff23..e6c186bcb22f3544365a31925dfab646c3739206 100644 (file)
@@ -829,6 +829,9 @@ FROM mytable UNION SELECT myothertable.otherid, myothertable.othername \
 FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thirdtable")
 
         assert u1.corresponding_column(table2.c.otherid) is u1.c.myid
+        
+        assert u1.corresponding_column(table1.oid_column) is u1.oid_column
+        assert u1.corresponding_column(table2.oid_column) is u1.oid_column
 
         # TODO - why is there an extra space before the LIMIT ?
         self.assert_compile(