]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
_selectable interface; allows sqlsoup to pass its classes to Join and have the underl...
authorJonathan Ellis <jbellis@gmail.com>
Fri, 21 Jul 2006 16:53:05 +0000 (16:53 +0000)
committerJonathan Ellis <jbellis@gmail.com>
Fri, 21 Jul 2006 16:53:05 +0000 (16:53 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/ext/proxy.py
lib/sqlalchemy/ext/selectresults.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql.py

index 3ba2965769280d922e440f1c7cb94fb5eb8a3b94..7a76eddbe2e58b14e0a9b809d0a04f9eca1a088c 100644 (file)
@@ -309,27 +309,30 @@ class ANSICompiler(sql.Compiled):
             if isinstance(c, sql.Select) and c._scalar:
                 c.accept_visitor(self)
                 inner_columns[self.get_str(c)] = c
-            elif c.is_selectable():
-                for co in c.columns:
-                    if select.use_labels:
-                        l = co.label(co._label)
-                        l.accept_visitor(self)
-                        inner_columns[co._label] = l
-                    # TODO: figure this out, a ColumnClause with a select as a parent
-                    # is different from any other kind of parent
-                    elif select.issubquery and isinstance(co, sql.ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select):
-                        # SQLite doesnt like selecting from a subquery where the column
-                        # names look like table.colname, so add a label synonomous with
-                        # the column name
-                        l = co.label(co.name)
-                        l.accept_visitor(self)
-                        inner_columns[self.get_str(l.obj)] = l
-                    else:
-                        co.accept_visitor(self)
-                        inner_columns[self.get_str(co)] = co
-            else:
+                continue
+            try:
+                s = c._selectable()
+            except AttributeError:
                 c.accept_visitor(self)
                 inner_columns[self.get_str(c)] = c
+                continue
+            for co in s.columns:
+                if select.use_labels:
+                    l = co.label(co._label)
+                    l.accept_visitor(self)
+                    inner_columns[co._label] = l
+                # TODO: figure this out, a ColumnClause with a select as a parent
+                # is different from any other kind of parent
+                elif select.issubquery and isinstance(co, sql.ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select):
+                    # SQLite doesnt like selecting from a subquery where the column
+                    # names look like table.colname, so add a label synonomous with
+                    # the column name
+                    l = co.label(co.name)
+                    l.accept_visitor(self)
+                    inner_columns[self.get_str(l.obj)] = l
+                else:
+                    co.accept_visitor(self)
+                    inner_columns[self.get_str(co)] = co
         self.select_stack.pop(-1)
         
         collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
index deced55b4902b39cdee7aa2d5952e12d56e6cb92..60972a6d583f3721d72f5051da0472614840c909 100644 (file)
@@ -68,8 +68,6 @@ class ProxyEngine(BaseProxyEngine):
         BaseProxyEngine.__init__(self)
         # create the local storage for uri->engine map and current engine
         self.storage = local()
-        self.storage.connection = {}
-        self.storage.engine = None
         self.kwargs = kwargs
 
     def connect(self, *args, **kwargs):
@@ -87,13 +85,13 @@ class ProxyEngine(BaseProxyEngine):
             self.storage.engine = None
             map = self.storage.connection
         try:
-            self.engine = map[key]
+            self.storage.engine = map[key]
         except KeyError:
             map[key] = create_engine(*args, **kwargs)
             self.storage.engine = map[key]
             
     def get_engine(self):
-        if self.storage.engine is None:
+        if not hasattr(self.storage, 'engine') or self.storage.engine is None:
             raise AttributeError("No connection established")
         return self.storage.engine
 
index e4ae62162b51b5fdcf3497373d596c8095c74dac..2ad52c8f0a91b35162710322187b2a3036d51e33 100644 (file)
@@ -7,7 +7,7 @@ class SelectResultsExt(orm.MapperExtension):
     def select_by(self, query, *args, **params):
         return SelectResults(query, query.join_by(*args, **params))
     def select(self, query, arg=None, **kwargs):
-        if arg is not None and isinstance(arg, sql.Selectable):
+        if hasattr(arg, '_selectable'):
             return orm.EXT_PASS
         else:
             return SelectResults(query, arg, ops=kwargs)
index 2682739908615e4c2734f221a976de4db87d4dc5..f1d4d4ab2f118d6fc12b4e159212402d2c2785af 100644 (file)
@@ -220,10 +220,12 @@ class Query(object):
         ret = self.extension.select(self, arg=arg, **kwargs)
         if ret is not mapper.EXT_PASS:
             return ret
-        elif arg is not None and isinstance(arg, sql.Selectable):
-            return self.select_statement(arg, **kwargs)
-        else:
+        try:
+            s = arg._selectable_()
+        except AttributeError:
             return self.select_whereclause(whereclause=arg, **kwargs)
+        else:
+            return self.select_statement(s, **kwargs)
 
     def select_whereclause(self, whereclause=None, params=None, **kwargs):
         statement = self.compile(whereclause, **kwargs)
index e0b8866f62a0c2a95bfca2271ea4896ed035896a..3650dd8b831273a382401a7857301e4380cc6a74 100644 (file)
@@ -435,11 +435,6 @@ class ClauseElement(object):
         new structure can then be restructured without affecting the original."""
         return self
 
-    def is_selectable(self):
-        """returns True if this ClauseElement is Selectable, i.e. it contains a list of Column
-        objects and can be used as the target of a select statement."""
-        return False
-
     def _find_engine(self):
         """default strategy for locating an engine within the clause element.
         relies upon a local engine property, or looks in the "from" objects which 
@@ -542,7 +537,7 @@ class CompareMixin(object):
     def in_(self, *other):
         if len(other) == 0:
             return self.__eq__(None)
-        elif len(other) == 1 and not isinstance(other[0], Selectable):
+        elif len(other) == 1 and not hasattr(other[0], '_selectable'):
             return self.__eq__(other[0])
         elif _is_literal(other[0]):
             return self._compare('IN', ClauseList(parens=True, *[self._bind_param(o) for o in other]))
@@ -611,10 +606,10 @@ class CompareMixin(object):
 class Selectable(ClauseElement):
     """represents a column list-holding object."""
 
+    def _selectable(self):
+        return self
     def accept_visitor(self, visitor):
         raise NotImplementedError(repr(self))
-    def is_selectable(self):
-        return True
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
     def _group_parenthesized(self):
@@ -748,11 +743,14 @@ class FromClause(Selectable):
         self._orig_cols = {}
         export = self._exportable_columns()
         for column in export:
-            if column.is_selectable():
-                for co in column.columns:
-                    cp = self._proxy_column(co)
-                    for ci in cp.orig_set:
-                        self._orig_cols[ci] = cp
+            try:
+                s = column._selectable()
+            except AttributeError:
+                continue
+            for co in s.columns:
+                cp = self._proxy_column(co)
+                for ci in cp.orig_set:
+                    self._orig_cols[ci] = cp
         if self.oid_column is not None:
             for ci in self.oid_column.orig_set:
                 self._orig_cols[ci] = self.oid_column
@@ -1014,9 +1012,9 @@ class BinaryClause(ClauseElement):
         self.operator = operator
         self.type = sqltypes.to_instance(type)
         self.parens = False
-        if isinstance(self.left, BinaryClause) or isinstance(self.left, Selectable):
+        if isinstance(self.left, BinaryClause) or hasattr(self.left, '_selectable'):
             self.left.parens = True
-        if isinstance(self.right, BinaryClause) or isinstance(self.right, Selectable):
+        if isinstance(self.right, BinaryClause) or hasattr(self.right, '_selectable'):
             self.right.parens = True
     def copy_container(self):
         return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator)
@@ -1049,10 +1047,10 @@ class BinaryExpression(BinaryClause, ColumnElement):
         
 class Join(FromClause):
     def __init__(self, left, right, onclause=None, isouter = False):
-        self.left = left
-        self.right = right
+        self.left = left._selectable()
+        self.right = right._selectable()
         if onclause is None:
-            self.onclause = self._match_primaries(left, right)
+            self.onclause = self._match_primaries(self.left, self.right)
         else:
             self.onclause = onclause
         self.isouter = isouter