From: Mike Bayer Date: Tue, 24 Jul 2007 20:05:10 +0000 (+0000) Subject: - deprecated scalar=True argument on select(). its replaced X-Git-Tag: rel_0_4_6~32 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0f7e7c3a6e4aad83cb404c094a05b4eff2a19e31;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - deprecated scalar=True argument on select(). its replaced by select().scalar() which returns a _ScalarSelect object, that obeys the ColumnElement interface fully - removed _selectable() method. replaced with __selectable__() as an optional duck-typer; subclassing Selectable (without any __selectable__()) is equivalent - query._col_aggregate() was assuming bound metadata. ick ! - probably should deprecate ClauseElement.scalar(), in favor of ClauseElement.execute().scalar()... otherwise might need to rename select().scalar() --- diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 3cffdd098b..3ff8f3ee70 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -392,7 +392,7 @@ class SelectableClassType(type): def update(cls, whereclause=None, values=None, **kwargs): _ddl_error(cls) - def _selectable(cls): + def __selectable__(cls): return cls._table def __getattr__(cls, attr): @@ -434,9 +434,7 @@ def _selectable_name(selectable): return x def class_for_table(selectable, **mapper_kwargs): - if not hasattr(selectable, '_selectable') \ - or selectable._selectable() != selectable: - raise ArgumentError('class_for_table requires a selectable as its argument') + selectable = sql._selectable(selectable) mapname = 'Mapped' + _selectable_name(selectable) if isinstance(selectable, Table): klass = TableClassType(mapname, (object,), {}) @@ -520,7 +518,7 @@ class SqlSoup: def with_labels(self, item): # TODO give meaningful aliases - return self.map(item._selectable().select(use_labels=True).alias('foo')) + return self.map(sql._selectable(item).select(use_labels=True).alias('foo')) def join(self, *args, **kwargs): j = join(*args, **kwargs) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 68a71e5655..da5bc753fb 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -400,9 +400,9 @@ class Query(object): if self._order_by is not False: s1 = sql.select([col], self._criterion, **ops).alias('u') - return sql.select([func(s1.corresponding_column(col))]).scalar() + return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar() else: - return sql.select([func(col)], self._criterion, **ops).scalar() + return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar() def min(self, col): """Execute the SQL ``min()`` function against the given column.""" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c463e1e995..351ac0d1ed 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -202,12 +202,8 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): will attempt to provide similar functionality. scalar=False - when ``True``, indicates that the resulting ``Select`` object - is to be used in the "columns" clause of another select statement, - where the evaluated value of the column is the scalar result of - this statement. Normally, placing any ``Selectable`` within the - columns clause of a ``select()`` call will expand the member - columns of the ``Selectable`` individually. + deprecated. use select(...).scalar() to create a "scalar column" + proxy for an existing Select object. correlate=True indicates that this ``Select`` object should have its contained @@ -218,8 +214,12 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): rendered in the ``FROM`` clause of this select statement. """ - - return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) + scalar = kwargs.pop('scalar', False) + s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs) + if scalar: + return s.scalar() + else: + return s def subquery(alias, *args, **kwargs): """Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select]. @@ -762,6 +762,14 @@ def _literal_as_binds(element, name='literal', type_=None): return _BindParamClause(name, element, shortname=name, type_=type_, unique=True) else: return element + +def _selectable(element): + if hasattr(element, '__selectable__'): + return element.__selectable__() + elif isinstance(element, Selectable): + return element + else: + raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) def is_column(col): return isinstance(col, ColumnElement) @@ -1348,7 +1356,7 @@ class _CompareMixin(ColumnOperators): if _is_literal(o) or isinstance( o, _CompareMixin): return self.__eq__( o) #single item -> == else: - assert hasattr( o, '_selectable') #better check? + assert isinstance(o, Selectable) return self.__compare( op, o, negate=negate_op) #single selectable args = [] @@ -1446,14 +1454,10 @@ class Selectable(ClauseElement): columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""") - def _selectable(self): - return self - def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) - class ColumnElement(Selectable, _CompareMixin): """Represent an element that is useable within the "column clause" portion of a ``SELECT`` statement. @@ -1806,8 +1810,9 @@ class FromClause(Selectable): """return the list of ColumnElements represented within this FromClause's _exportable_columns""" export = self._exportable_columns() for column in export: - if hasattr(column, '_selectable'): - s = column._selectable() + # TODO: is this conditional needed ? + if isinstance(column, Selectable): + s = column else: continue for co in s.columns: @@ -2081,7 +2086,7 @@ class _CalculatedClause(ColumnElement): return select([self]) def scalar(self): - return select([self]).scalar() + return select([self]).execute().scalar() def execute(self): return select([self]).execute() @@ -2254,8 +2259,8 @@ class Join(FromClause): """ def __init__(self, left, right, onclause=None, isouter = False): - self.left = left._selectable() - self.right = right._selectable().self_group() + self.left = _selectable(left) + self.right = _selectable(right).self_group() if onclause is None: self.onclause = self._match_primaries(self.left, self.right) else: @@ -2501,34 +2506,39 @@ class Alias(FromClause): bind = property(lambda s: s.selectable.bind) -class _Grouping(ColumnElement): +class _ColumnElementAdapter(ColumnElement): + """adapts a ClauseElement which may or may not be a + ColumnElement subclass itself into an object which + acts like a ColumnElement. + """ + def __init__(self, elem): self.elem = elem self.type = getattr(elem, 'type', None) + self.orig_set = getattr(elem, 'orig_set', util.Set()) - key = property(lambda s: s.elem.key) _label = property(lambda s: s.elem._label) - orig_set = property(lambda s:s.elem.orig_set) columns = c = property(lambda s:s.elem.columns) - + def _copy_internals(self): - print "GROPING COPY INTERNALS" self.elem = self.elem._clone() - print "NEW ID", id(self.elem) - + def get_children(self, **kwargs): return self.elem, - + def _hide_froms(self, **modifiers): return self.elem._hide_froms(**modifiers) - + def _get_from_objects(self, **modifiers): return self.elem._get_from_objects(**modifiers) def __getattr__(self, attr): return getattr(self.elem, attr) +class _Grouping(_ColumnElementAdapter): + pass + class _Label(ColumnElement): """represent a label, as typically applied to any column-level element using the ``AS`` sql keyword. @@ -2764,22 +2774,25 @@ class TableClause(FromClause): def _get_from_objects(self, **modifiers): return [self] + class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" - def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, scalar=False): + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None): self.use_labels = use_labels self.for_update = for_update self._limit = limit self._offset = offset self._bind = bind - self.is_scalar = scalar - if self.is_scalar: - # allow corresponding_column to return None - self.orig_set = util.Set() self.append_order_by(*util.to_list(order_by, [])) self.append_group_by(*util.to_list(group_by, [])) + + def scalar(self): + return _ScalarSelect(self) + + def label(self, name): + return self.scalar().label(name) def supports_execution(self): return True @@ -2829,11 +2842,29 @@ class _SelectBaseMixin(object): return select([self], whereclauses, **params) def _get_from_objects(self, is_where=False, **modifiers): - if is_where or self.is_scalar: + if is_where: return [] else: return [self] +class _ScalarSelect(_Grouping): + __visit_name__ = 'grouping' + + def __init__(self, elem): + super(_ScalarSelect, self).__init__(elem) + self.type = list(elem.inner_columns)[0].type + + columns = property(lambda self:[self]) + + def self_group(self, **kwargs): + return self + + def _make_proxy(self, selectable, name): + return list(self.inner_columns)[0]._make_proxy(selectable, name) + + def _get_from_objects(self, **modifiers): + return [] + class CompoundSelect(_SelectBaseMixin, FromClause): def __init__(self, keyword, *selects, **kwargs): self._should_correlate = kwargs.pop('correlate', False) @@ -3077,10 +3108,8 @@ class Select(_SelectBaseMixin, FromClause): def _get_inner_columns(self): for c in self._raw_columns: - # TODO: need to have Select, as well as a Select inside a _Grouping, - # give us a clearer idea of if we want its column list or not - if hasattr(c, '_selectable') and not getattr(c, 'is_scalar', False): - for co in c._selectable().columns: + if isinstance(c, Selectable): + for co in c.columns: yield co else: yield c @@ -3160,7 +3189,7 @@ class Select(_SelectBaseMixin, FromClause): if _is_literal(column): column = literal_column(str(column)) - if isinstance(column, Select) and column.is_scalar: + if isinstance(column, _ScalarSelect): column = column.self_group(against=ColumnOperators.comma_op) self._raw_columns.append(column) @@ -3182,25 +3211,6 @@ class Select(_SelectBaseMixin, FromClause): fromclause = FromClause(fromclause) self._froms.add(fromclause) - def _make_proxy(self, selectable, name): - if self.is_scalar: - return list(self.inner_columns)[0]._make_proxy(selectable, name) - else: - raise exceptions.InvalidRequestError("Not a scalar select statement") - - def label(self, name): - if not self.is_scalar: - raise exceptions.InvalidRequestError("Not a scalar select statement") - else: - return label(name, self) - - def _get_type(self): - if self.is_scalar: - return list(self.inner_columns)[0].type - else: - return None - type = property(_get_type) - def _exportable_columns(self): return [c for c in self._raw_columns if isinstance(c, Selectable)] diff --git a/test/orm/generative.py b/test/orm/generative.py index 5106388d73..4a90c13cb1 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -24,7 +24,7 @@ class GenerativeQueryTest(PersistTest): mapper(Foo, foo) metadata.create_all() - sess = create_session() + sess = create_session(bind=testbase.db) for i in range(100): sess.save(Foo(bar=i, range=i%10)) sess.flush() @@ -34,13 +34,13 @@ class GenerativeQueryTest(PersistTest): clear_mappers() def test_selectby(self): - res = create_session().query(Foo).filter_by(range=5) + res = create_session(bind=testbase.db).query(Foo).filter_by(range=5) assert res.order_by([Foo.c.bar])[0].bar == 5 assert res.order_by([desc(Foo.c.bar)])[0].bar == 95 @testing.unsupported('mssql') def test_slice(self): - sess = create_session() + sess = create_session(bind=testbase.db) query = sess.query(Foo) orig = query.all() assert query[1] == orig[1] @@ -54,14 +54,14 @@ class GenerativeQueryTest(PersistTest): @testing.supported('mssql') def test_slice_mssql(self): - sess = create_session() + sess = create_session(bind=testbase.db) query = sess.query(Foo) orig = query.all() assert list(query[:10]) == orig[:10] assert list(query[:10]) == orig[:10] def test_aggregate(self): - sess = create_session() + sess = create_session(bind=testbase.db) query = sess.query(Foo) assert query.count() == 100 assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0 @@ -72,34 +72,34 @@ class GenerativeQueryTest(PersistTest): @testing.unsupported('mysql') def test_aggregate_1(self): # this one fails in mysql as the result comes back as a string - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435 @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_2(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5 @testing.supported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_2_int(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14 @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_3(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5 assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() == 14.5 def test_filter(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert query.count() == 100 assert query.filter(Foo.c.bar < 30).count() == 30 res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10) assert res2.count() == 19 def test_options(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) class ext1(MapperExtension): def populate_instance(self, mapper, selectcontext, row, instance, **flags): instance.TEST = "hello world" @@ -107,16 +107,16 @@ class GenerativeQueryTest(PersistTest): assert query.options(extension(ext1()))[0].TEST == "hello world" def test_order_by(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert query.order_by([Foo.c.bar])[0].bar == 0 assert query.order_by([desc(Foo.c.bar)])[0].bar == 99 def test_offset(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10 def test_offset(self): - query = create_session().query(Foo) + query = create_session(bind=testbase.db).query(Foo) assert len(list(query.limit(10))) == 10 class Obj1(object): @@ -127,7 +127,7 @@ class Obj2(object): class GenerativeTest2(PersistTest): def setUpAll(self): global metadata, table1, table2 - metadata = MetaData(testbase.db) + metadata = MetaData() table1 = Table('Table1', metadata, Column('id', Integer, primary_key=True), ) @@ -137,17 +137,17 @@ class GenerativeTest2(PersistTest): ) mapper(Obj1, table1) mapper(Obj2, table2) - metadata.create_all() - table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4}) - table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\ + metadata.create_all(bind=testbase.db) + testbase.db.execute(table1.insert(), {'id':1},{'id':2},{'id':3},{'id':4}) + testbase.db.execute(table2.insert(), {'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\ {'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3}) def tearDownAll(self): - metadata.drop_all() + metadata.drop_all(bind=testbase.db) clear_mappers() def test_distinctcount(self): - query = create_session().query(Obj1) + query = create_session(bind=testbase.db).query(Obj1) assert query.count() == 4 res = query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)) assert res.count() == 3 @@ -169,7 +169,7 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session() + session = create_session(bind=testbase.db) query = session.query(tables.User) x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2) print x.compile() @@ -181,7 +181,7 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session() + session = create_session(bind=testbase.db) query = session.query(tables.User) x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) print x.compile() @@ -193,7 +193,7 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session() + session = create_session(bind=testbase.db) query = session.query(tables.User) x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count() assert x==2 @@ -203,7 +203,7 @@ class RelationsTest(AssertMixin): 'items':relation(mapper(tables.Item, tables.orderitems)) })) }) - session = create_session() + session = create_session(bind=testbase.db) query = session.query(tables.User) x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\ filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) @@ -234,7 +234,7 @@ class CaseSensitiveTest(PersistTest): clear_mappers() def test_distinctcount(self): - q = create_session().query(Obj1) + q = create_session(bind=testbase.db).query(Obj1) assert q.count() == 4 res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)) assert res.count() == 3 @@ -251,7 +251,7 @@ class SelfRefTest(ORMTest): def test_noautojoin(self): class T(object):pass mapper(T, t1, properties={'children':relation(T)}) - sess = create_session() + sess = create_session(bind=testbase.db) try: sess.query(T).join('children').select_by(id=7) assert False diff --git a/test/orm/query.py b/test/orm/query.py index 9f14d9a25a..d4a5120eaf 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -773,7 +773,7 @@ class ExternalColumnsTest(QueryTest): mapper(User, users, properties={ 'concat': column_property(f), - 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id, scalar=True).correlate(users).label('count')) + 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).label('count')) }) mapper(Address, addresses, properties={ diff --git a/test/sql/query.py b/test/sql/query.py index 61734d5d92..11b7beffb2 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -274,7 +274,7 @@ class QueryTest(PersistTest): x = testbase.db.func.current_date().execute().scalar() y = testbase.db.func.current_date().select().execute().scalar() z = testbase.db.func.current_date().scalar() - assert x == y == z + assert (x == y == z) is True x = testbase.db.func.current_date(type_=Date) assert isinstance(x.type, Date) @@ -288,8 +288,8 @@ class QueryTest(PersistTest): z = conn.scalar(func.current_date()) finally: conn.close() - assert x == y == z - + assert (x == y == z) is True + def test_update_functions(self): """test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances, and that column-level defaults get overridden""" diff --git a/test/sql/select.py b/test/sql/select.py index 18709f55d9..ad5df2e06c 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -199,7 +199,20 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A s = select([table1.c.myid], scalar=True) self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") - + + s = select([table1.c.myid]).correlate(None).scalar() + self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable") + + s = select([table1.c.myid]).scalar() + self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable") + + # test expressions against scalar selects + self.runtest(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal") + self.runtest(select([select([table1.c.name]).scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal") + self.runtest(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal") + + self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo") + zips = table('zips', column('zipcode'), @@ -211,8 +224,8 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A column('nm') ) zip = '12345' - qlat = select([zips.c.latitude], zips.c.zipcode == zip, scalar=True, correlate=False) - qlng = select([zips.c.longitude], zips.c.zipcode == zip, scalar=True, correlate=False) + qlat = select([zips.c.latitude], zips.c.zipcode == zip).correlate(None).scalar() + qlng = select([zips.c.longitude], zips.c.zipcode == zip).correlate(None).scalar() q = select([places.c.id, places.c.nm, zips.c.zipcode, func.latlondist(qlat, qlng).label('dist')], zips.c.zipcode==zip,