]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
changes related to mapping against arbitrary selects, selects with labels or functions:
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Dec 2005 05:58:45 +0000 (05:58 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Dec 2005 05:58:45 +0000 (05:58 +0000)
testfunction has a more complete test (needs an assert tho);
added new labels, synonymous with column key, to "select" statements that are subqueries with use_labels=False, since SQLite wants them -
this also impacts the names of the columns attached to the select object in the case that the key and name dont match, since
it is now the key, not the name;
aliases generate random names if name is None (need some way to make them more predictable to help plan caching);
select statements have a rowid column of None, since there isnt really a "rowid"...at least cant figure out what it would be yet;
mapper creates an alias if given a select to map against, since Postgres wants it;
mapper checks if it has pks for a given table before saving/deleting, skips it otherwise;
mapper will not try to order by rowid if table doesnt have a rowid (since select statements dont have rowids...)

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/sql.py
test/mapper.py
test/select.py

index 8df5e535231f86f36f7abb09853aac9b3060f34b..abbc067515f19efe78122952e85e1ac5a51420d0 100644 (file)
@@ -262,6 +262,13 @@ class ANSICompiler(sql.Compiled):
                         l = co.label(co._label)
                         l.accept_visitor(self)
                         inner_columns[co._label] = l
+                    elif select.issubquery and isinstance(co, Column):
+                        # SQLite doesnt like selecting from a subquery where the column
+                        # names look like table.colname, so add a label synonomous with
+                        # the column name
+                        l = co.label(co.key)
+                        l.accept_visitor(self)
+                        inner_columns[self.get_str(l.obj)] = l
                     else:
                         co.accept_visitor(self)
                         inner_columns[self.get_str(co)] = co
index 41616dceb83b8a94cc9332e10c6d835d2dbcd87b..4c63f5c0cccc30ec413a6413621ebe4732defbae 100644 (file)
@@ -49,7 +49,7 @@ class Mapper(object):
             'primarytable':primarytable,
             'properties':properties or {},
             'primary_key':primary_key,
-            'is_primary':False,
+            'is_primary':None,
             'inherits':inherits,
             'inherit_condition':inherit_condition,
             'extension':extension,
@@ -72,8 +72,13 @@ class Mapper(object):
             primarytable = inherits.primarytable
             # inherit_condition is optional since the join can figure it out
             table = sql.join(table, inherits.table, inherit_condition)
-            
-        self.table = table
+        
+        if isinstance(table, sql.Select):
+            # some db's, noteably postgres, dont want to select from a select
+            # without an alias
+            self.table = table.alias(None)
+        else:
+            self.table = table
         
         # locate all tables contained within the "table" passed in, which
         # may be a join or other construct
@@ -93,9 +98,10 @@ class Mapper(object):
         self.pks_by_table = {}
         if primary_key is not None:
             for k in primary_key:
-                self.pks_by_table.setdefault(k.table, []).append(k)
+                self.pks_by_table.setdefault(k.table, util.HashSet()).append(k)
                 if k.table != self.table:
-                    self.pks_by_table.setdefault(self.table, []).append(k)
+                    # associate pk cols from subtables to the "main" table
+                    self.pks_by_table.setdefault(self.table, util.HashSet()).append(k)
         else:
             for t in self.tables + [self.table]:
                 try:
@@ -122,10 +128,10 @@ class Mapper(object):
         # load custom properties 
         if properties is not None:
             for key, prop in properties.iteritems():
-                if isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement):
+                if is_column(prop):
                     self.columns[key] = prop
                     prop = ColumnProperty(prop)
-                elif isinstance(prop, list) and (isinstance(prop[0], schema.Column) or isinstance(prop[0], sql.ColumnElement)) :
+                elif isinstance(prop, list) and is_column(prop[0]):
                     self.columns[key] = prop[0]
                     prop = ColumnProperty(*prop)
                 self.props[key] = prop
@@ -158,7 +164,11 @@ class Mapper(object):
             proplist = self.columntoproperty.setdefault(column.original, [])
             proplist.append(prop)
 
-        if not hasattr(self.class_, '_mapper') or self.is_primary or not mapper_registry.has_key(self.class_._mapper) or (inherits is not None and inherits._is_primary_mapper()):
+        if (
+                (not hasattr(self.class_, '_mapper') or not mapper_registry.has_key(self.class_._mapper))
+                or self.is_primary 
+                or (inherits is not None and inherits._is_primary_mapper())
+            ):
             objectstore.global_attributes.reset_class_managed(self.class_)
             self._init_class()
             
@@ -166,13 +176,12 @@ class Mapper(object):
             for key, prop in inherits.props.iteritems():
                 if not self.props.has_key(key):
                     self.props[key] = prop._copy()
-                
 
     engines = property(lambda s: [t.engine for t in s.tables])
 
     def add_property(self, key, prop):
         self.copyargs['properties'][key] = prop
-        if (isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement)):
+        if is_column(prop):
             self.columns[key] = prop
             prop = ColumnProperty(prop)
         self.props[key] = prop
@@ -194,7 +203,7 @@ class Mapper(object):
         return self.hashkey
 
     def _is_primary_mapper(self):
-        return getattr(self.class_, '_mapper') == self.hashkey
+        return getattr(self.class_, '_mapper', None) == self.hashkey
         
     def _init_class(self):
         """sets up our classes' overridden __init__ method, this mappers hash key as its
@@ -447,6 +456,9 @@ class Mapper(object):
         list."""
           
         for table in self.tables:
+            if not self._has_pks(table):
+                continue
+
             # loop thru tables in the outer loop, objects on the inner loop.
             # this is important for an object represented across two tables
             # so that it gets its primary key columns populated for the benefit of the
@@ -457,9 +469,8 @@ class Mapper(object):
             # we have our own idea of the primary key columns 
             # for this table, in the case that the user
             # specified custom primary key cols.
-            pk = {}
-            for k in self.pks_by_table[table]:
-                pk[k] = k
+            # also, if we are missing a primary key for this table, then
+            # just skip inserting/updating the table
             for obj in objects:
                 
 #                print "SAVE_OBJ we are " + hash_key(self) + " obj: " +  obj.__class__.__name__ + repr(id(obj))
@@ -471,8 +482,7 @@ class Mapper(object):
 
                 hasdata = False
                 for col in table.columns:
-                    #if col.primary_key:
-                    if pk.has_key(col):
+                    if self.pks_by_table[table].contains(col):
                         if hasattr(obj, "_instance_key"):
                             params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col)
                         else:
@@ -536,6 +546,8 @@ class Mapper(object):
         """called by a UnitOfWork object to delete objects, which involves a
         DELETE statement for each table used by this mapper, for each object in the list."""
         for table in self.tables:
+            if not self._has_pks(table):
+                continue
             delete = []
             for obj in objects:
                 params = {}
@@ -556,6 +568,16 @@ class Mapper(object):
                 if table.engine.supports_sane_rowcount() and c.rowcount != len(delete):
                     raise "ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete))
 
+    def _has_pks(self, table):
+        try:
+            for k in self.pks_by_table[table]:
+                if not self.columntoproperty.has_key(k.original):
+                    return False
+            else:
+                return True
+        except KeyError:
+            return False
+            
     def register_dependencies(self, *args, **kwargs):
         """called by an instance of objectstore.UOWTransaction to register 
         which mappers are dependent on which, as well as DependencyProcessor 
@@ -581,12 +603,10 @@ class Mapper(object):
         if not no_sort:
             if self.order_by:
                 order_by = self.order_by
-#            elif self.table.rowid_column is not None:
- #               order_by = self.table.rowid_column
-  #          else:
-  #              order_by = None
-            else:
+            elif self.table.rowid_column is not None:
                 order_by = self.table.rowid_column
+            else:
+                order_by = None
         else:
             order_by = None
             
@@ -779,6 +799,9 @@ def hash_key(obj):
     else:
         return repr(obj)
 
+def is_column(col):
+    return isinstance(col, schema.Column) or isinstance(col, sql.ColumnElement)
+    
 def mapper_hash_key(class_, table, primarytable = None, properties = None, **kwargs):
     if properties is None:
         properties = {}
index ba7312c12b74efece55ce69cc5f26fd31e43d4a4..e53ee644cb6e70b2a3e23a6208d2bd016826ce5e 100644 (file)
@@ -24,7 +24,6 @@ import sqlalchemy.util as util
 import sqlalchemy.attributes as attributes
 import mapper
 import objectstore
-import random
 
 class ColumnProperty(MapperProperty):
     """describes an object attribute that corresponds to a table column."""
@@ -856,8 +855,7 @@ class Aliasizer(sql.ClauseVisitor):
         try:
             return self.aliases[table]
         except:
-            aliasname = table.name + "_" + hex(random.randint(0, 65535))[2:]
-            return self.aliases.setdefault(table, sql.alias(table, aliasname))
+            return self.aliases.setdefault(table, sql.alias(table))
 
     def visit_compound(self, compound):
         for i in range(0, len(compound.clauses)):
index b0e86259a1e8cca578840c25519240ad150ee1e4..7db60ffb95e4765d86ec0524d946498e7a5d9ea8 100644 (file)
@@ -20,7 +20,7 @@
 import sqlalchemy.schema as schema
 import sqlalchemy.util as util
 import sqlalchemy.types as types
-import string, re
+import string, re, random
 
 __all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
 
@@ -497,7 +497,7 @@ class FromClause(Selectable):
         return Join(self, right, *args, **kwargs)
     def outerjoin(self, right, *args, **kwargs):
         return Join(self, right, isouter = True, *args, **kwargs)
-    def alias(self, name):
+    def alias(self, name=None):
         return Alias(self, name)
 
     
@@ -751,11 +751,17 @@ class Alias(FromClause):
         self._columns = util.OrderedProperties()
         self.foreign_keys = []
         if alias is None:
-            alias = id(self)
+            n = getattr(selectable, 'name')
+            if n is None:
+                n = 'anon'
+            alias = n + "_" + hex(random.randint(0, 65535))[2:]
         self.name = alias
         self.id = self.name
         self.count = 0
-        self.rowid_column = self.selectable.rowid_column._make_proxy(self)
+        if self.selectable.rowid_column is not None:
+            self.rowid_column = self.selectable.rowid_column._make_proxy(self)
+        else:
+            self.rowid_column = None
         for co in selectable.columns:
             co._make_proxy(self)
 
@@ -930,7 +936,7 @@ class TableImpl(FromClause):
         return Join(self.table, right, *args, **kwargs)
     def outerjoin(self, right, *args, **kwargs):
         return Join(self.table, right, isouter = True, *args, **kwargs)
-    def alias(self, name):
+    def alias(self, name=None):
         return Alias(self.table, name)
     def select(self, whereclause = None, **params):
         return select([self.table], whereclause, **params)
@@ -1082,16 +1088,20 @@ class Select(SelectBaseMixin, FromClause):
 
         for f in column._get_from_objects():
             f.accept_visitor(self._correlator)
-            if self.rowid_column is None and hasattr(f, 'rowid_column') and f.rowid_column is not None:
-                self.rowid_column = f.rowid_column._make_proxy(self)
         column._process_from_dict(self._froms, False)
 
         if column.is_selectable():
+            # if its a column unit, add it to our exported 
+            # list of columns.  this is where "columns" 
+            # attribute of the select object gets populated.
+            # notice we are overriding the names of the column
+            # with either its label or its key, since one or the other
+            # is used when selecting from a select statement (i.e. a subquery)
             for co in column.columns:
                 if self.use_labels:
-                    co._make_proxy(self, name = co._label)
+                    co._make_proxy(self, name=co._label)
                 else:
-                    co._make_proxy(self)
+                    co._make_proxy(self, name=co.key)
             
     def _get_col_by_original(self, column):
         if self.use_labels:
index 90d182b6a7338fc7be06d5aaecec80cef5682aae..cc792109da6b901179d32436f7d373d11e7f2414 100644 (file)
@@ -120,11 +120,13 @@ class MapperTest(MapperSuperTest):
         
         
     def testfunction(self):
-        s = select([users, (users.c.user_id * 2).label('concat'), func.count(users.c.user_id).label('count')], group_by=[c for c in users.c], use_labels=True)
-        m = mapper(User, s.alias('test'))
+        s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')],
+        users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c])
+        m = mapper(User, s, primarytable=users)
+        print [c.key for c in m.c]
         l = m.select()
-        print [repr(x.__dict__) for x in l]
-        
+        for u in l:
+            print "User", u.user_id, u.user_name, u.concat, u.count
         
     def testmultitable(self):
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
@@ -363,6 +365,7 @@ class LazyTest(MapperSuperTest):
         # use a union all to get a lot of rows to join against
         u2 = users.alias('u2')
         s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
+        print [key for key in s.c.keys()]
         l = m.select(s.c.u2_user_id==User.c.user_id, distinct=True)
         self.assert_result(l, User, *user_all_result)
         
index cba332578828b079c2f764c0a2b015ec478b4191..1fa2fd456b1b6ba4c67fe3a913898f5fa4629d7c 100644 (file)
@@ -79,18 +79,19 @@ myothertable.othername FROM mytable, myothertable")
         #)
 
         s = select([table], table.c.name == 'jack')
+        print [key for key in s.c.keys()]
         self.runtest(
             select(
                 [s],
                 s.c.id == 7
             )
             ,
-        "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name) WHERE myid = :myid")
+        "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable WHERE mytable.name = :mytable_name) WHERE id = :id")
         
         sq = select([table])
         self.runtest(
             sq.select(),
-            "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable)"
+            "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable)"
         )
         
         sq = subquery(
@@ -100,8 +101,8 @@ myothertable.othername FROM mytable, myothertable")
 
         self.runtest(
             sq.select(sq.c.id == 7), 
-            "SELECT sq.myid, sq.name, sq.description FROM \
-(SELECT mytable.myid, mytable.name, mytable.description FROM mytable) AS sq WHERE sq.myid = :sq_myid"
+            "SELECT sq.id, sq.name, sq.description FROM \
+(SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.id = :sq_id"
         )
         
         sq = subquery(
@@ -368,7 +369,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
     def testcorrelatedsubquery(self):
         self.runtest(
             table.select(table.c.id == select([table2.c.id], table.c.name == table2.c.name)),
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid FROM myothertable WHERE mytable.name = myothertable.othername)"
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid AS id FROM myothertable WHERE mytable.name = myothertable.othername)"
         )
 
         self.runtest(
@@ -380,19 +381,19 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
         s = subquery('sq2', [talias], exists([1], table2.c.id == talias.c.id))
         self.runtest(
             select([s, table])
-            ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid, ta.name, ta.description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable")
+            ,"SELECT sq2.id, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS id, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable")
 
         s = select([addresses.c.street], addresses.c.user_id==users.c.user_id).alias('s')
         self.runtest(
             select([users, s.c.street], from_obj=[s]),
-            """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
+            """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
 
     def testin(self):
         self.runtest(select([table], table.c.id.in_(1, 2, 3)),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_myid_1, :mytable_myid_2)")
 
         self.runtest(select([table], table.c.id.in_(select([table2.c.id]))),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid FROM myothertable)")
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid AS id FROM myothertable)")
     
     def testlateargs(self):
         """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments