]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improvements to relational algrebra of Alias, Select, Join objects, so that they
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Jan 2006 01:26:47 +0000 (01:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Jan 2006 01:26:47 +0000 (01:26 +0000)
all report their column lists, primary key, foreign key lists consistently
and so that ForeignKey objects can line up tables against relational objects

lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py

index c5054cd500dda74d23166940c6f700bec148cdd2..1d5b504e2efd9dd2dec3c530d8a4194c3e1702e4 100644 (file)
@@ -281,7 +281,11 @@ class Column(SchemaItem):
         c._orig = self.original
         if not c.hidden:
             selectable.columns[c.key] = c
+            if self.primary_key:
+                selectable.primary_key.append(c)
         c._impl = self.engine.columnimpl(c)
+        if fk is not None:
+            c._init_items(fk)
         return c
 
     def accept_visitor(self, visitor):
@@ -325,17 +329,17 @@ class ForeignKey(SchemaItem):
         if isinstance(self._colspec, str):
             return ForeignKey(self._colspec)
         else:
-            return ForeignKey("%s.%s" % (self._colspec.table.name, self._colspec.column.key))
+            if self._colspec.table.schema is not None:
+                return ForeignKey("%s.%s.%s" % (self._colspec.table.schema, self._colspec.table.name, self._colspec.column.key))
+            else:
+                return ForeignKey("%s.%s" % (self._colspec.table.name, self._colspec.column.key))
     
     def references(self, table):
         """returns True if the given table is referenced by this ForeignKey."""
-        return (
-            # simple test
-            self.column.table is table      
-            or
-            # test for an indirect relation via a Selectable
-            table._get_col_by_original(self.column) is not None
-        )
+        try:
+            return table._get_col_by_original(self.column) is not None
+        except:
+            x = self._init_column()
         
     def _init_column(self):
         # ForeignKey inits its remote column as late as possible, so tables can
@@ -347,8 +351,8 @@ class ForeignKey(SchemaItem):
                     raise ValueError("Invalid foreign key column specification: " + self._colspec)
                 if m.group(3) is None:
                     (tname, colname) = m.group(1, 2)
-                    # default to containing table's schema
-                    schema = self.parent.table.schema
+                    # use default schema
+                    schema = None
                 else:
                     (schema,tname,colname) = m.group(1,2,3)
                 table = Table(tname, self.parent.engine, mustexist=True, schema=schema)
index 2a1b6829f8231a86bda2f16d2805a61a89f206cb..e08c644f47c15886f16ccdbd22b422cb24fb5578 100644 (file)
@@ -340,7 +340,6 @@ class ClauseElement(object):
             return None
             
     engine = property(lambda s: s._find_engine())
-
     
     def compile(self, engine = None, parameters = None, typemap=None):
         """compiles this SQL expression using its underlying SQLEngine to produce
@@ -449,16 +448,8 @@ class CompareMixin(object):
 class Selectable(ClauseElement):
     """represents a column list-holding object."""
 
-    def _get_columns(self):
-        try:
-            return self._columns
-        except AttributeError:
-            return [self]
-    columns = property(lambda s: s._get_columns())
-    c = property(lambda self: self.columns)
-
     def accept_visitor(self, visitor):
-        raise NotImplementedError()
+        raise NotImplementedError(repr(self))
 
     def is_selectable(self):
         return True
@@ -466,22 +457,17 @@ class Selectable(ClauseElement):
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
 
-    def _get_col_by_original(self, column):
-        """given a column which is a schema.Column object attached to a schema.Table object
-        (i.e. an "original" column), return the Column object from this 
-        Selectable which corresponds to that original Column, or None if this Selectable
-        does not contain the column."""
-        raise NotImplementedError()
-
     def _group_parenthesized(self):
         """indicates if this Selectable requires parenthesis when grouped into a compound
         statement"""
         return True
 
+
 class ColumnElement(Selectable, CompareMixin):
     """represents a column element within the list of a Selectable's columns."""
-    primary_key = property(lambda s:False)
-    original = property(lambda self:self)
+    primary_key = property(lambda s:getattr(self, '_primary_key', False))
+    original = property(lambda self:getattr(self, '_original', self))
+    columns = property(lambda self:[self])
 
 class FromClause(Selectable):
     """represents an element that can be used within the FROM clause of a SELECT statement."""
@@ -508,7 +494,42 @@ class FromClause(Selectable):
         return Join(self, right, isouter = True, *args, **kwargs)
     def alias(self, name=None):
         return Alias(self, name)
-
+    def _get_col_by_original(self, column):
+        """given a column which is a schema.Column object attached to a schema.Table object
+        (i.e. an "original" column), return the Column object from this 
+        Selectable which corresponds to that original Column, or None if this Selectable
+        does not contain the column."""
+        return self.original_columns[column.original]
+    def _get_exported_attribute(self, name):
+        try:
+            return getattr(self, name)
+        except AttributeError:
+            self._export_columns()
+            return getattr(self, name)
+    columns = property(lambda s:s._get_exported_attribute('_columns'))
+    c = property(lambda s:s._get_exported_attribute('_columns'))
+    primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
+    foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys'))
+    original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'))
+    
+    def _export_columns(self):
+        if hasattr(self, '_columns'):
+            # TODO: put a mutex here ?  this is a key place for threading probs
+            return
+        self._columns = util.OrderedProperties()
+        self._primary_key = []
+        self._foreign_keys = []
+        self._orig_cols = {}
+        export = self._exportable_columns()
+        for column in export:
+            if column.is_selectable():
+                for co in column.columns:
+                    cp = self._proxy_column(co)
+                    self._orig_cols[co.original] = cp
+    def _exportable_columns(self):
+        raise NotImplementedError(repr(self))
+    def _proxy_column(self, column):
+        return column._make_proxy(self)
     
 class BindParamClause(ClauseElement, CompareMixin):
     """represents a bind parameter.  public constructor is the bindparam() function."""
@@ -708,16 +729,10 @@ class BinaryClause(ClauseElement, CompareMixin):
         )
         
 class Join(FromClause):
-    # TODO: put "using" + "natural" concepts in here and make "onclause" optional
-    def __init__(self, left, right, onclause=None, isouter = False, allcols = True):
+    def __init__(self, left, right, onclause=None, isouter = False):
         self.left = left
         self.right = right
         self.id = self.left.id + "_" + self.right.id
-        self.allcols = allcols
-        if allcols:
-            self._columns = [c for c in self.left.columns] + [c for c in self.right.columns]
-        else:
-            self._columns = self.right.columns
 
         # TODO: if no onclause, do NATURAL JOIN
         if onclause is None:
@@ -726,9 +741,15 @@ class Join(FromClause):
             self.onclause = onclause
         self.isouter = isouter
         self.oid_column = self.left.oid_column
-        
-    primary_key = property (lambda self: [c for c in self.left.columns if c.primary_key] + [c for c in self.right.columns if c.primary_key])
-
+    def _exportable_columns(self):
+        return [c for c in self.left.columns] + [c for c in self.right.columns]
+    def _proxy_column(self, column):
+        self._columns[column.table.name + "_" + column.key] = column
+        if column.primary_key:
+            self._primary_key.append(column)
+        if column.foreign_key:
+            self._foreign_keys.append(column)
+        return column
     def _match_primaries(self, primary, secondary):
         crit = []
         for fk in secondary.foreign_keys:
@@ -752,13 +773,6 @@ class Join(FromClause):
         statement"""
         return True
 
-    def _get_col_by_original(self, column):
-        for c in self.columns:
-            if c.original is column.original:
-                return c
-        else:
-            return None
-
     def hash_key(self):
         return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
 
@@ -792,8 +806,9 @@ class Join(FromClause):
         
 class Alias(FromClause):
     def __init__(self, selectable, alias = None):
+        while isinstance(selectable, Alias):
+            selectable = selectable.selectable
         self.selectable = selectable
-        self._columns = util.OrderedProperties()
         if alias is None:
             n = getattr(selectable, 'name')
             if n is None:
@@ -806,22 +821,13 @@ class Alias(FromClause):
             self.oid_column = self.selectable.oid_column._make_proxy(self)
         else:
             self.oid_column = None
-        for co in selectable.columns:
-            co._make_proxy(self)
 
-    primary_key = property (lambda self: [c for c in self.columns if c.primary_key])
-    foreign_keys = property(lambda s:s.selectable.foreign_keys)
+    def _exportable_columns(self):
+        return self.selectable.columns
 
     def hash_key(self):
         return "Alias(%s, %s)" % (repr(self.selectable.hash_key()), repr(self.name))
 
-    def _get_col_by_original(self, column):
-        c = self.columns.get(column.key, None)
-        if c is not None and c.original is column.original:
-            return c
-        else:
-            return None
-
     def accept_visitor(self, visitor):
         self.selectable.accept_visitor(visitor)
         visitor.visit_alias(self)
@@ -850,7 +856,9 @@ class Label(ColumnElement):
     def _get_from_objects(self):
         return self.obj._get_from_objects()
     def _make_proxy(self, selectable, name = None):
-        cc = ColumnClause(self.name)
+        # TODO: this make_proxy needs foreign_key, primary_key support
+        # from its underlying column, if any
+        cc = ColumnClause(self.name, selectable)
         selectable.c[self.name] = cc
         return cc
      
@@ -903,7 +911,6 @@ class ColumnImpl(ColumnElement):
     def __init__(self, column):
         self.column = column
         self.name = column.name
-        self._columns = [self.column]
         
         if column.table.name:
             self._label = column.table.name + "_" + self.column.name
@@ -912,6 +919,8 @@ class ColumnImpl(ColumnElement):
 
     engine = property(lambda s: s.column.engine)
     default_label = property(lambda s:s._label)
+    original = property(lambda self:self.column)
+    columns = property(lambda self:[self.column])
     
     def label(self, name):
         return Label(name, self.column)
@@ -923,12 +932,6 @@ class ColumnImpl(ColumnElement):
         """compares this ColumnImpl's column to the other given Column"""
         return self.column is other
         
-    def _get_col_by_original(self, column):
-        if self.column.original is column.original:
-            return self.column
-        else:
-            return None
-            
     def _group_parenthesized(self):
         return False
         
@@ -940,7 +943,7 @@ class ColumnImpl(ColumnElement):
             return BindParamClause(self.name, obj, shortname = self.name, type = self.column.type)
         else:
             return BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name, type = self.column.type)
-            
+
     def _compare(self, operator, obj):
         if _is_literal(obj):
             if obj is None:
@@ -952,6 +955,13 @@ class ColumnImpl(ColumnElement):
 
         return BinaryClause(self.column, obj, operator, type=self.column.type)
 
+    def compile(self, engine = None, parameters = None, typemap=None):
+        if engine is None:
+            engine = self.engine
+        if engine is None:
+            raise "no SQLEngine could be located within this ClauseElement."
+        return engine.compile(self.column, parameters=parameters, typemap=typemap)
+
 class TableImpl(FromClause):
     """attached to a schema.Table to provide it with a Selectable interface
     as well as other functions
@@ -971,21 +981,25 @@ class TableImpl(FromClause):
                 self._oid_column = None
         return self._oid_column
 
+    def _orig_columns(self):
+        try:
+            return self._orig_cols
+        except AttributeError:
+            self._orig_cols= {}
+            for c in self.columns:
+                self._orig_cols[c.original] = c
+            return self._orig_cols
+            
     oid_column = property(_oid_col)
     engine = property(lambda s: s.table.engine)
     columns = property(lambda self: self.table.columns)
     primary_key = property(lambda self:self.table.primary_key)
-    
-    def _get_col_by_original(self, column):
-        try:
-          col = self.columns[column.key]
-        except KeyError:
-          return None
-        if col.original is column.original:
-          return col
-        else:
-          return None
+    foreign_keys = property(lambda self:self.table.foreign_keys)
+    original_columns = property(_orig_columns)
 
+    def _exportable_columns(self):
+        raise NotImplementedError()
+        
     def _group_parenthesized(self):
         return False
 
@@ -1065,9 +1079,14 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         if group_by:
             self.group_by(*group_by)
 
-    primary_key = property(lambda s:s.selects[0].primary_key)
-    foreign_keys = property(lambda s:s.selects[0].foreign_keys)
-    columns = property(lambda s:s.selects[0].columns)
+    def _exportable_columns(self):
+        return self.selects[0].columns
+    def _proxy_column(self, column):
+        self._columns[column.key] = column
+        if column.primary_key:
+            self._primary_key.append(column)
+        if column.foreign_key:
+            self._foreign_keys.append(column)
     def accept_visitor(self, visitor):
         for tup in self.clauses:
             tup[1].accept_visitor(visitor)
@@ -1086,7 +1105,6 @@ class Select(SelectBaseMixin, FromClause):
     """represents a SELECT statement, with appendable clauses, as well as 
     the ability to execute itself and return a result set."""
     def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None):
-        self._columns = util.OrderedProperties()
         self._froms = util.OrderedDict()
         self.use_labels = use_labels
         self.id = "Select(%d)" % id(self)
@@ -1155,45 +1173,14 @@ class Select(SelectBaseMixin, FromClause):
             f.accept_visitor(self._correlator)
         column._process_from_dict(self._froms, False)
 
-        if column.is_selectable():
-            # if its a column unit, add it to our exported 
-            # list of columns.  this is where "columns" 
-            # attribute of the select object gets populated.
-            # notice we are overriding the names of the column
-            # with either its label or its key, since one or the other
-            # is used when selecting from a select statement (i.e. a subquery)
-            for co in column.columns:
-                if self.use_labels:
-                    co._make_proxy(self, name=co._label)
-                else:
-                    co._make_proxy(self, name=co.key)
-            
-    def _get_col_by_original(self, column):
+    def _exportable_columns(self):
+        return self._raw_columns
+    def _proxy_column(self, column):
         if self.use_labels:
-            c = self.columns.get(column._label,None)
-        else:
-            c = self.columns.get(column.key,None)
-        if c is not None and c.original is column.original:
-            return c
+            return column._make_proxy(self, name=column._label)
         else:
-            return None
-
-    def _pks(self):
-        ret = {}
-        for from_obj in self._get_froms():
-            for c in from_obj.c:
-                if c.primary_key:
-                    ret[c] = c
-        return ret.keys()
-    def _fks(self):
-        ret = []
-        for from_obj in self._get_froms():
-            for fk in from_obj.foreign_keys:
-                ret.append(fk)
-        return ret
-    primary_key = property (_pks)
-    foreign_keys = property(_fks)
-    
+            return column._make_proxy(self, name=column.key)
+            
     def append_whereclause(self, whereclause):
         self._append_condition('whereclause', whereclause)
     def append_having(self, having):