From: Mike Bayer Date: Fri, 30 Dec 2005 05:58:45 +0000 (+0000) Subject: changes related to mapping against arbitrary selects, selects with labels or functions: X-Git-Tag: rel_0_1_0~195 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5ceef4809d2eeb5030eb25668064b0a4a6262eba;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git changes related to mapping against arbitrary selects, selects with labels or functions: testfunction has a more complete test (needs an assert tho); added new labels, synonymous with column key, to "select" statements that are subqueries with use_labels=False, since SQLite wants them - this also impacts the names of the columns attached to the select object in the case that the key and name dont match, since it is now the key, not the name; aliases generate random names if name is None (need some way to make them more predictable to help plan caching); select statements have a rowid column of None, since there isnt really a "rowid"...at least cant figure out what it would be yet; mapper creates an alias if given a select to map against, since Postgres wants it; mapper checks if it has pks for a given table before saving/deleting, skips it otherwise; mapper will not try to order by rowid if table doesnt have a rowid (since select statements dont have rowids...) --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 8df5e53523..abbc067515 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -262,6 +262,13 @@ class ANSICompiler(sql.Compiled): l = co.label(co._label) l.accept_visitor(self) inner_columns[co._label] = l + elif select.issubquery and isinstance(co, Column): + # SQLite doesnt like selecting from a subquery where the column + # names look like table.colname, so add a label synonomous with + # the column name + l = co.label(co.key) + l.accept_visitor(self) + inner_columns[self.get_str(l.obj)] = l else: co.accept_visitor(self) inner_columns[self.get_str(co)] = co diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 41616dceb8..4c63f5c0cc 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -49,7 +49,7 @@ class Mapper(object): 'primarytable':primarytable, 'properties':properties or {}, 'primary_key':primary_key, - 'is_primary':False, + 'is_primary':None, 'inherits':inherits, 'inherit_condition':inherit_condition, 'extension':extension, @@ -72,8 +72,13 @@ class Mapper(object): primarytable = inherits.primarytable # inherit_condition is optional since the join can figure it out table = sql.join(table, inherits.table, inherit_condition) - - self.table = table + + if isinstance(table, sql.Select): + # some db's, noteably postgres, dont want to select from a select + # without an alias + self.table = table.alias(None) + else: + self.table = table # locate all tables contained within the "table" passed in, which # may be a join or other construct @@ -93,9 +98,10 @@ class Mapper(object): self.pks_by_table = {} if primary_key is not None: for k in primary_key: - self.pks_by_table.setdefault(k.table, []).append(k) + self.pks_by_table.setdefault(k.table, util.HashSet()).append(k) if k.table != self.table: - self.pks_by_table.setdefault(self.table, []).append(k) + # associate pk cols from subtables to the "main" table + self.pks_by_table.setdefault(self.table, util.HashSet()).append(k) else: for t in self.tables + [self.table]: try: @@ -122,10 +128,10 @@ class Mapper(object): # load custom properties if properties is not None: for key, prop in properties.iteritems(): - if isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement): + if is_column(prop): self.columns[key] = prop prop = ColumnProperty(prop) - elif isinstance(prop, list) and (isinstance(prop[0], schema.Column) or isinstance(prop[0], sql.ColumnElement)) : + elif isinstance(prop, list) and is_column(prop[0]): self.columns[key] = prop[0] prop = ColumnProperty(*prop) self.props[key] = prop @@ -158,7 +164,11 @@ class Mapper(object): proplist = self.columntoproperty.setdefault(column.original, []) proplist.append(prop) - if not hasattr(self.class_, '_mapper') or self.is_primary or not mapper_registry.has_key(self.class_._mapper) or (inherits is not None and inherits._is_primary_mapper()): + if ( + (not hasattr(self.class_, '_mapper') or not mapper_registry.has_key(self.class_._mapper)) + or self.is_primary + or (inherits is not None and inherits._is_primary_mapper()) + ): objectstore.global_attributes.reset_class_managed(self.class_) self._init_class() @@ -166,13 +176,12 @@ class Mapper(object): for key, prop in inherits.props.iteritems(): if not self.props.has_key(key): self.props[key] = prop._copy() - engines = property(lambda s: [t.engine for t in s.tables]) def add_property(self, key, prop): self.copyargs['properties'][key] = prop - if (isinstance(prop, schema.Column) or isinstance(prop, sql.ColumnElement)): + if is_column(prop): self.columns[key] = prop prop = ColumnProperty(prop) self.props[key] = prop @@ -194,7 +203,7 @@ class Mapper(object): return self.hashkey def _is_primary_mapper(self): - return getattr(self.class_, '_mapper') == self.hashkey + return getattr(self.class_, '_mapper', None) == self.hashkey def _init_class(self): """sets up our classes' overridden __init__ method, this mappers hash key as its @@ -447,6 +456,9 @@ class Mapper(object): list.""" for table in self.tables: + if not self._has_pks(table): + continue + # loop thru tables in the outer loop, objects on the inner loop. # this is important for an object represented across two tables # so that it gets its primary key columns populated for the benefit of the @@ -457,9 +469,8 @@ class Mapper(object): # we have our own idea of the primary key columns # for this table, in the case that the user # specified custom primary key cols. - pk = {} - for k in self.pks_by_table[table]: - pk[k] = k + # also, if we are missing a primary key for this table, then + # just skip inserting/updating the table for obj in objects: # print "SAVE_OBJ we are " + hash_key(self) + " obj: " + obj.__class__.__name__ + repr(id(obj)) @@ -471,8 +482,7 @@ class Mapper(object): hasdata = False for col in table.columns: - #if col.primary_key: - if pk.has_key(col): + if self.pks_by_table[table].contains(col): if hasattr(obj, "_instance_key"): params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col) else: @@ -536,6 +546,8 @@ class Mapper(object): """called by a UnitOfWork object to delete objects, which involves a DELETE statement for each table used by this mapper, for each object in the list.""" for table in self.tables: + if not self._has_pks(table): + continue delete = [] for obj in objects: params = {} @@ -556,6 +568,16 @@ class Mapper(object): if table.engine.supports_sane_rowcount() and c.rowcount != len(delete): raise "ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)) + def _has_pks(self, table): + try: + for k in self.pks_by_table[table]: + if not self.columntoproperty.has_key(k.original): + return False + else: + return True + except KeyError: + return False + def register_dependencies(self, *args, **kwargs): """called by an instance of objectstore.UOWTransaction to register which mappers are dependent on which, as well as DependencyProcessor @@ -581,12 +603,10 @@ class Mapper(object): if not no_sort: if self.order_by: order_by = self.order_by -# elif self.table.rowid_column is not None: - # order_by = self.table.rowid_column - # else: - # order_by = None - else: + elif self.table.rowid_column is not None: order_by = self.table.rowid_column + else: + order_by = None else: order_by = None @@ -779,6 +799,9 @@ def hash_key(obj): else: return repr(obj) +def is_column(col): + return isinstance(col, schema.Column) or isinstance(col, sql.ColumnElement) + def mapper_hash_key(class_, table, primarytable = None, properties = None, **kwargs): if properties is None: properties = {} diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index ba7312c12b..e53ee644cb 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -24,7 +24,6 @@ import sqlalchemy.util as util import sqlalchemy.attributes as attributes import mapper import objectstore -import random class ColumnProperty(MapperProperty): """describes an object attribute that corresponds to a table column.""" @@ -856,8 +855,7 @@ class Aliasizer(sql.ClauseVisitor): try: return self.aliases[table] except: - aliasname = table.name + "_" + hex(random.randint(0, 65535))[2:] - return self.aliases.setdefault(table, sql.alias(table, aliasname)) + return self.aliases.setdefault(table, sql.alias(table)) def visit_compound(self, compound): for i in range(0, len(compound.clauses)): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index b0e86259a1..7db60ffb95 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -20,7 +20,7 @@ import sqlalchemy.schema as schema import sqlalchemy.util as util import sqlalchemy.types as types -import string, re +import string, re, random __all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] @@ -497,7 +497,7 @@ class FromClause(Selectable): return Join(self, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): return Join(self, right, isouter = True, *args, **kwargs) - def alias(self, name): + def alias(self, name=None): return Alias(self, name) @@ -751,11 +751,17 @@ class Alias(FromClause): self._columns = util.OrderedProperties() self.foreign_keys = [] if alias is None: - alias = id(self) + n = getattr(selectable, 'name') + if n is None: + n = 'anon' + alias = n + "_" + hex(random.randint(0, 65535))[2:] self.name = alias self.id = self.name self.count = 0 - self.rowid_column = self.selectable.rowid_column._make_proxy(self) + if self.selectable.rowid_column is not None: + self.rowid_column = self.selectable.rowid_column._make_proxy(self) + else: + self.rowid_column = None for co in selectable.columns: co._make_proxy(self) @@ -930,7 +936,7 @@ class TableImpl(FromClause): return Join(self.table, right, *args, **kwargs) def outerjoin(self, right, *args, **kwargs): return Join(self.table, right, isouter = True, *args, **kwargs) - def alias(self, name): + def alias(self, name=None): return Alias(self.table, name) def select(self, whereclause = None, **params): return select([self.table], whereclause, **params) @@ -1082,16 +1088,20 @@ class Select(SelectBaseMixin, FromClause): for f in column._get_from_objects(): f.accept_visitor(self._correlator) - if self.rowid_column is None and hasattr(f, 'rowid_column') and f.rowid_column is not None: - self.rowid_column = f.rowid_column._make_proxy(self) column._process_from_dict(self._froms, False) if column.is_selectable(): + # if its a column unit, add it to our exported + # list of columns. this is where "columns" + # attribute of the select object gets populated. + # notice we are overriding the names of the column + # with either its label or its key, since one or the other + # is used when selecting from a select statement (i.e. a subquery) for co in column.columns: if self.use_labels: - co._make_proxy(self, name = co._label) + co._make_proxy(self, name=co._label) else: - co._make_proxy(self) + co._make_proxy(self, name=co.key) def _get_col_by_original(self, column): if self.use_labels: diff --git a/test/mapper.py b/test/mapper.py index 90d182b6a7..cc792109da 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -120,11 +120,13 @@ class MapperTest(MapperSuperTest): def testfunction(self): - s = select([users, (users.c.user_id * 2).label('concat'), func.count(users.c.user_id).label('count')], group_by=[c for c in users.c], use_labels=True) - m = mapper(User, s.alias('test')) + s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')], + users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c]) + m = mapper(User, s, primarytable=users) + print [c.key for c in m.c] l = m.select() - print [repr(x.__dict__) for x in l] - + for u in l: + print "User", u.user_id, u.user_name, u.concat, u.count def testmultitable(self): usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) @@ -363,6 +365,7 @@ class LazyTest(MapperSuperTest): # use a union all to get a lot of rows to join against u2 = users.alias('u2') s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u') + print [key for key in s.c.keys()] l = m.select(s.c.u2_user_id==User.c.user_id, distinct=True) self.assert_result(l, User, *user_all_result) diff --git a/test/select.py b/test/select.py index cba3325788..1fa2fd456b 100644 --- a/test/select.py +++ b/test/select.py @@ -79,18 +79,19 @@ myothertable.othername FROM mytable, myothertable") #) s = select([table], table.c.name == 'jack') + print [key for key in s.c.keys()] self.runtest( select( [s], s.c.id == 7 ) , - "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name) WHERE myid = :myid") + "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable WHERE mytable.name = :mytable_name) WHERE id = :id") sq = select([table]) self.runtest( sq.select(), - "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable)" + "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable)" ) sq = subquery( @@ -100,8 +101,8 @@ myothertable.othername FROM mytable, myothertable") self.runtest( sq.select(sq.c.id == 7), - "SELECT sq.myid, sq.name, sq.description FROM \ -(SELECT mytable.myid, mytable.name, mytable.description FROM mytable) AS sq WHERE sq.myid = :sq_myid" + "SELECT sq.id, sq.name, sq.description FROM \ +(SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.id = :sq_id" ) sq = subquery( @@ -368,7 +369,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable def testcorrelatedsubquery(self): self.runtest( table.select(table.c.id == select([table2.c.id], table.c.name == table2.c.name)), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid FROM myothertable WHERE mytable.name = myothertable.othername)" + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid AS id FROM myothertable WHERE mytable.name = myothertable.othername)" ) self.runtest( @@ -380,19 +381,19 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable s = subquery('sq2', [talias], exists([1], table2.c.id == talias.c.id)) self.runtest( select([s, table]) - ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid, ta.name, ta.description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable") + ,"SELECT sq2.id, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS id, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable") s = select([addresses.c.street], addresses.c.user_id==users.c.user_id).alias('s') self.runtest( select([users, s.c.street], from_obj=[s]), - """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") + """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""") def testin(self): self.runtest(select([table], table.c.id.in_(1, 2, 3)), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_myid_1, :mytable_myid_2)") self.runtest(select([table], table.c.id.in_(select([table2.c.id]))), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid FROM myothertable)") + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid AS id FROM myothertable)") def testlateargs(self): """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments