def update(cls, whereclause=None, values=None, **kwargs):
_ddl_error(cls)
- def _selectable(cls):
+ def __selectable__(cls):
return cls._table
def __getattr__(cls, attr):
return x
def class_for_table(selectable, **mapper_kwargs):
- if not hasattr(selectable, '_selectable') \
- or selectable._selectable() != selectable:
- raise ArgumentError('class_for_table requires a selectable as its argument')
+ selectable = sql._selectable(selectable)
mapname = 'Mapped' + _selectable_name(selectable)
if isinstance(selectable, Table):
klass = TableClassType(mapname, (object,), {})
def with_labels(self, item):
# TODO give meaningful aliases
- return self.map(item._selectable().select(use_labels=True).alias('foo'))
+ return self.map(sql._selectable(item).select(use_labels=True).alias('foo'))
def join(self, *args, **kwargs):
j = join(*args, **kwargs)
if self._order_by is not False:
s1 = sql.select([col], self._criterion, **ops).alias('u')
- return sql.select([func(s1.corresponding_column(col))]).scalar()
+ return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar()
else:
- return sql.select([func(col)], self._criterion, **ops).scalar()
+ return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar()
def min(self, col):
"""Execute the SQL ``min()`` function against the given column."""
will attempt to provide similar functionality.
scalar=False
- when ``True``, indicates that the resulting ``Select`` object
- is to be used in the "columns" clause of another select statement,
- where the evaluated value of the column is the scalar result of
- this statement. Normally, placing any ``Selectable`` within the
- columns clause of a ``select()`` call will expand the member
- columns of the ``Selectable`` individually.
+ deprecated. use select(...).scalar() to create a "scalar column"
+ proxy for an existing Select object.
correlate=True
indicates that this ``Select`` object should have its contained
rendered in the ``FROM`` clause of this select statement.
"""
-
- return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
+ scalar = kwargs.pop('scalar', False)
+ s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
+ if scalar:
+ return s.scalar()
+ else:
+ return s
def subquery(alias, *args, **kwargs):
"""Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select].
return _BindParamClause(name, element, shortname=name, type_=type_, unique=True)
else:
return element
+
+def _selectable(element):
+ if hasattr(element, '__selectable__'):
+ return element.__selectable__()
+ elif isinstance(element, Selectable):
+ return element
+ else:
+ raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
def is_column(col):
return isinstance(col, ColumnElement)
if _is_literal(o) or isinstance( o, _CompareMixin):
return self.__eq__( o) #single item -> ==
else:
- assert hasattr( o, '_selectable') #better check?
+ assert isinstance(o, Selectable)
return self.__compare( op, o, negate=negate_op) #single selectable
args = []
columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""")
- def _selectable(self):
- return self
-
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
-
class ColumnElement(Selectable, _CompareMixin):
"""Represent an element that is useable within the
"column clause" portion of a ``SELECT`` statement.
"""return the list of ColumnElements represented within this FromClause's _exportable_columns"""
export = self._exportable_columns()
for column in export:
- if hasattr(column, '_selectable'):
- s = column._selectable()
+ # TODO: is this conditional needed ?
+ if isinstance(column, Selectable):
+ s = column
else:
continue
for co in s.columns:
return select([self])
def scalar(self):
- return select([self]).scalar()
+ return select([self]).execute().scalar()
def execute(self):
return select([self]).execute()
"""
def __init__(self, left, right, onclause=None, isouter = False):
- self.left = left._selectable()
- self.right = right._selectable().self_group()
+ self.left = _selectable(left)
+ self.right = _selectable(right).self_group()
if onclause is None:
self.onclause = self._match_primaries(self.left, self.right)
else:
bind = property(lambda s: s.selectable.bind)
-class _Grouping(ColumnElement):
+class _ColumnElementAdapter(ColumnElement):
+ """adapts a ClauseElement which may or may not be a
+ ColumnElement subclass itself into an object which
+ acts like a ColumnElement.
+ """
+
def __init__(self, elem):
self.elem = elem
self.type = getattr(elem, 'type', None)
+ self.orig_set = getattr(elem, 'orig_set', util.Set())
-
key = property(lambda s: s.elem.key)
_label = property(lambda s: s.elem._label)
- orig_set = property(lambda s:s.elem.orig_set)
columns = c = property(lambda s:s.elem.columns)
-
+
def _copy_internals(self):
- print "GROPING COPY INTERNALS"
self.elem = self.elem._clone()
- print "NEW ID", id(self.elem)
-
+
def get_children(self, **kwargs):
return self.elem,
-
+
def _hide_froms(self, **modifiers):
return self.elem._hide_froms(**modifiers)
-
+
def _get_from_objects(self, **modifiers):
return self.elem._get_from_objects(**modifiers)
def __getattr__(self, attr):
return getattr(self.elem, attr)
+class _Grouping(_ColumnElementAdapter):
+ pass
+
class _Label(ColumnElement):
"""represent a label, as typically applied to any column-level element
using the ``AS`` sql keyword.
def _get_from_objects(self, **modifiers):
return [self]
+
class _SelectBaseMixin(object):
"""Base class for ``Select`` and ``CompoundSelects``."""
- def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, scalar=False):
+ def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None):
self.use_labels = use_labels
self.for_update = for_update
self._limit = limit
self._offset = offset
self._bind = bind
- self.is_scalar = scalar
- if self.is_scalar:
- # allow corresponding_column to return None
- self.orig_set = util.Set()
self.append_order_by(*util.to_list(order_by, []))
self.append_group_by(*util.to_list(group_by, []))
+
+ def scalar(self):
+ return _ScalarSelect(self)
+
+ def label(self, name):
+ return self.scalar().label(name)
def supports_execution(self):
return True
return select([self], whereclauses, **params)
def _get_from_objects(self, is_where=False, **modifiers):
- if is_where or self.is_scalar:
+ if is_where:
return []
else:
return [self]
+class _ScalarSelect(_Grouping):
+ __visit_name__ = 'grouping'
+
+ def __init__(self, elem):
+ super(_ScalarSelect, self).__init__(elem)
+ self.type = list(elem.inner_columns)[0].type
+
+ columns = property(lambda self:[self])
+
+ def self_group(self, **kwargs):
+ return self
+
+ def _make_proxy(self, selectable, name):
+ return list(self.inner_columns)[0]._make_proxy(selectable, name)
+
+ def _get_from_objects(self, **modifiers):
+ return []
+
class CompoundSelect(_SelectBaseMixin, FromClause):
def __init__(self, keyword, *selects, **kwargs):
self._should_correlate = kwargs.pop('correlate', False)
def _get_inner_columns(self):
for c in self._raw_columns:
- # TODO: need to have Select, as well as a Select inside a _Grouping,
- # give us a clearer idea of if we want its column list or not
- if hasattr(c, '_selectable') and not getattr(c, 'is_scalar', False):
- for co in c._selectable().columns:
+ if isinstance(c, Selectable):
+ for co in c.columns:
yield co
else:
yield c
if _is_literal(column):
column = literal_column(str(column))
- if isinstance(column, Select) and column.is_scalar:
+ if isinstance(column, _ScalarSelect):
column = column.self_group(against=ColumnOperators.comma_op)
self._raw_columns.append(column)
fromclause = FromClause(fromclause)
self._froms.add(fromclause)
- def _make_proxy(self, selectable, name):
- if self.is_scalar:
- return list(self.inner_columns)[0]._make_proxy(selectable, name)
- else:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
-
- def label(self, name):
- if not self.is_scalar:
- raise exceptions.InvalidRequestError("Not a scalar select statement")
- else:
- return label(name, self)
-
- def _get_type(self):
- if self.is_scalar:
- return list(self.inner_columns)[0].type
- else:
- return None
- type = property(_get_type)
-
def _exportable_columns(self):
return [c for c in self._raw_columns if isinstance(c, Selectable)]
mapper(Foo, foo)
metadata.create_all()
- sess = create_session()
+ sess = create_session(bind=testbase.db)
for i in range(100):
sess.save(Foo(bar=i, range=i%10))
sess.flush()
clear_mappers()
def test_selectby(self):
- res = create_session().query(Foo).filter_by(range=5)
+ res = create_session(bind=testbase.db).query(Foo).filter_by(range=5)
assert res.order_by([Foo.c.bar])[0].bar == 5
assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
@testing.unsupported('mssql')
def test_slice(self):
- sess = create_session()
+ sess = create_session(bind=testbase.db)
query = sess.query(Foo)
orig = query.all()
assert query[1] == orig[1]
@testing.supported('mssql')
def test_slice_mssql(self):
- sess = create_session()
+ sess = create_session(bind=testbase.db)
query = sess.query(Foo)
orig = query.all()
assert list(query[:10]) == orig[:10]
assert list(query[:10]) == orig[:10]
def test_aggregate(self):
- sess = create_session()
+ sess = create_session(bind=testbase.db)
query = sess.query(Foo)
assert query.count() == 100
assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
@testing.unsupported('mysql')
def test_aggregate_1(self):
# this one fails in mysql as the result comes back as a string
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
@testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_2(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
@testing.supported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_2_int(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
@testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
def test_aggregate_3(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5
assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() == 14.5
def test_filter(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert query.count() == 100
assert query.filter(Foo.c.bar < 30).count() == 30
res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
assert res2.count() == 19
def test_options(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
class ext1(MapperExtension):
def populate_instance(self, mapper, selectcontext, row, instance, **flags):
instance.TEST = "hello world"
assert query.options(extension(ext1()))[0].TEST == "hello world"
def test_order_by(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert query.order_by([Foo.c.bar])[0].bar == 0
assert query.order_by([desc(Foo.c.bar)])[0].bar == 99
def test_offset(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10
def test_offset(self):
- query = create_session().query(Foo)
+ query = create_session(bind=testbase.db).query(Foo)
assert len(list(query.limit(10))) == 10
class Obj1(object):
class GenerativeTest2(PersistTest):
def setUpAll(self):
global metadata, table1, table2
- metadata = MetaData(testbase.db)
+ metadata = MetaData()
table1 = Table('Table1', metadata,
Column('id', Integer, primary_key=True),
)
)
mapper(Obj1, table1)
mapper(Obj2, table2)
- metadata.create_all()
- table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4})
- table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
+ metadata.create_all(bind=testbase.db)
+ testbase.db.execute(table1.insert(), {'id':1},{'id':2},{'id':3},{'id':4})
+ testbase.db.execute(table2.insert(), {'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3})
def tearDownAll(self):
- metadata.drop_all()
+ metadata.drop_all(bind=testbase.db)
clear_mappers()
def test_distinctcount(self):
- query = create_session().query(Obj1)
+ query = create_session(bind=testbase.db).query(Obj1)
assert query.count() == 4
res = query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
assert res.count() == 3
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2)
print x.compile()
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
print x.compile()
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
assert x==2
'items':relation(mapper(tables.Item, tables.orderitems))
}))
})
- session = create_session()
+ session = create_session(bind=testbase.db)
query = session.query(tables.User)
x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\
filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
clear_mappers()
def test_distinctcount(self):
- q = create_session().query(Obj1)
+ q = create_session(bind=testbase.db).query(Obj1)
assert q.count() == 4
res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1))
assert res.count() == 3
def test_noautojoin(self):
class T(object):pass
mapper(T, t1, properties={'children':relation(T)})
- sess = create_session()
+ sess = create_session(bind=testbase.db)
try:
sess.query(T).join('children').select_by(id=7)
assert False
mapper(User, users, properties={
'concat': column_property(f),
- 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id, scalar=True).correlate(users).label('count'))
+ 'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).label('count'))
})
mapper(Address, addresses, properties={
x = testbase.db.func.current_date().execute().scalar()
y = testbase.db.func.current_date().select().execute().scalar()
z = testbase.db.func.current_date().scalar()
- assert x == y == z
+ assert (x == y == z) is True
x = testbase.db.func.current_date(type_=Date)
assert isinstance(x.type, Date)
z = conn.scalar(func.current_date())
finally:
conn.close()
- assert x == y == z
-
+ assert (x == y == z) is True
+
def test_update_functions(self):
"""test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances,
and that column-level defaults get overridden"""
s = select([table1.c.myid], scalar=True)
self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
-
+
+ s = select([table1.c.myid]).correlate(None).scalar()
+ self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
+
+ s = select([table1.c.myid]).scalar()
+ self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
+
+ # test expressions against scalar selects
+ self.runtest(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal")
+ self.runtest(select([select([table1.c.name]).scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal")
+ self.runtest(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal")
+
+ self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo")
+
zips = table('zips',
column('zipcode'),
column('nm')
)
zip = '12345'
- qlat = select([zips.c.latitude], zips.c.zipcode == zip, scalar=True, correlate=False)
- qlng = select([zips.c.longitude], zips.c.zipcode == zip, scalar=True, correlate=False)
+ qlat = select([zips.c.latitude], zips.c.zipcode == zip).correlate(None).scalar()
+ qlng = select([zips.c.longitude], zips.c.zipcode == zip).correlate(None).scalar()
q = select([places.c.id, places.c.nm, zips.c.zipcode, func.latlondist(qlat, qlng).label('dist')],
zips.c.zipcode==zip,