-from sqlalchemy import func, Integer, String, ForeignKey
+from sqlalchemy import func, Integer, Numeric, String, ForeignKey
from sqlalchemy.orm import relationship, Session, aliased
from sqlalchemy.testing.schema import Column
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext import hybrid
-from sqlalchemy.testing import eq_, AssertsCompiledSQL, assert_raises_message
+from sqlalchemy.testing import eq_, is_, AssertsCompiledSQL, \
+ assert_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy import inspect
+from decimal import Decimal
class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL):
@hybrid.hybrid_property
def value(self):
+ "This is a docstring"
return self._value - 5
@value.comparator
"FROM a AS a_1 WHERE upper(a_1.value) = upper(:upper_1)"
)
+ def test_docstring(self):
+ A = self._fixture()
+ eq_(A.value.__doc__, "This is a docstring")
+
+
class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = 'default'
+
def _fixture(self):
Base = declarative_base()
@hybrid.hybrid_property
def value(self):
+ "This is an instance-level docstring"
return int(self._value) - 5
@value.expression
def value(cls):
+ "This is a class-level docstring"
return func.foo(cls._value) + cls.bar_value
@value.setter
def test_expression(self):
A = self._fixture()
self.assert_compile(
- A.value,
+ A.value.__clause_element__(),
"foo(a.value) + bar(a.value)"
)
def test_aliased_expression(self):
A = self._fixture()
self.assert_compile(
- aliased(A).value,
+ aliased(A).value.__clause_element__(),
"foo(a_1.value) + bar(a_1.value)"
)
"FROM a AS a_1 WHERE foo(a_1.value) + bar(a_1.value) = :param_1"
)
+ def test_docstring(self):
+ A = self._fixture()
+ eq_(A.value.__doc__, "This is a class-level docstring")
+
+ # no docstring here since we get a literal
+ a1 = A(_value=10)
+ eq_(a1.value, 5)
+
+
class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = 'default'
+
def _fixture(self, assignable):
Base = declarative_base()
delattr, a1, 'value'
)
-
def test_set_get(self):
A = self._fixture(True)
a1 = A(value=5)
eq_(a1.value, 5)
eq_(a1._value, 10)
+
class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = 'default'
+
def _fixture(self):
Base = declarative_base()
@hybrid.hybrid_method
def value(self, x):
+ "This is an instance-level docstring"
return int(self._value) + x
@value.expression
def value(cls, value):
+ "This is a class-level docstring"
+ return func.foo(cls._value, value) + value
+
+ @hybrid.hybrid_method
+ def other_value(self, x):
+ "This is an instance-level docstring"
+ return int(self._value) + x
+
+ @other_value.expression
+ def other_value(cls, value):
return func.foo(cls._value, value) + value
return A
{"some key": "some value"}
)
-
def test_aliased_expression(self):
A = self._fixture()
self.assert_compile(
sess.query(aliased(A).value(5)),
"SELECT foo(a_1.value, :foo_1) + :foo_2 AS anon_1 FROM a AS a_1"
)
+
+ def test_docstring(self):
+ A = self._fixture()
+ eq_(A.value.__doc__, "This is a class-level docstring")
+ eq_(A.other_value.__doc__, "This is an instance-level docstring")
+ a1 = A(_value=10)
+
+ # a1.value is still a method, so it has a
+ # docstring
+ eq_(a1.value.__doc__, "This is an instance-level docstring")
+
+ eq_(a1.other_value.__doc__, "This is an instance-level docstring")
+
+
+class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL):
+ """tests against hybrids that return a non-ClauseElement.
+
+ use cases derived from the example at
+ http://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/
+
+ """
+ __dialect__ = 'default'
+
+ @classmethod
+ def setup_class(cls):
+ from sqlalchemy import literal
+
+ symbols = ('usd', 'gbp', 'cad', 'eur', 'aud')
+ currency_lookup = dict(
+ ((currency_from, currency_to), Decimal(str(rate)))
+ for currency_to, values in zip(
+ symbols,
+ [
+ (1, 1.59009, 0.988611, 1.37979, 1.02962),
+ (0.628895, 1, 0.621732, 0.867748, 0.647525),
+ (1.01152, 1.6084, 1, 1.39569, 1.04148),
+ (0.724743, 1.1524, 0.716489, 1, 0.746213),
+ (0.971228, 1.54434, 0.960166, 1.34009, 1),
+ ])
+ for currency_from, rate in zip(symbols, values)
+ )
+
+ class Amount(object):
+ def __init__(self, amount, currency):
+ self.currency = currency
+ self.amount = amount
+
+ def __add__(self, other):
+ return Amount(
+ self.amount +
+ other.as_currency(self.currency).amount,
+ self.currency
+ )
+
+ def __sub__(self, other):
+ return Amount(
+ self.amount -
+ other.as_currency(self.currency).amount,
+ self.currency
+ )
+
+ def __lt__(self, other):
+ return self.amount < other.as_currency(self.currency).amount
+
+ def __gt__(self, other):
+ return self.amount > other.as_currency(self.currency).amount
+
+ def __eq__(self, other):
+ return self.amount == other.as_currency(self.currency).amount
+
+ def as_currency(self, other_currency):
+ return Amount(
+ currency_lookup[(self.currency, other_currency)] *
+ self.amount,
+ other_currency
+ )
+
+ def __clause_element__(self):
+ # helper method for SQLAlchemy to interpret
+ # the Amount object as a SQL element
+ if isinstance(self.amount, (float, int, Decimal)):
+ return literal(self.amount)
+ else:
+ return self.amount
+
+ def __str__(self):
+ return "%2.4f %s" % (self.amount, self.currency)
+
+ def __repr__(self):
+ return "Amount(%r, %r)" % (self.amount, self.currency)
+
+ Base = declarative_base()
+
+ class BankAccount(Base):
+ __tablename__ = 'bank_account'
+ id = Column(Integer, primary_key=True)
+
+ _balance = Column('balance', Numeric)
+
+ @hybrid.hybrid_property
+ def balance(self):
+ """Return an Amount view of the current balance."""
+ return Amount(self._balance, "usd")
+
+ @balance.setter
+ def balance(self, value):
+ self._balance = value.as_currency("usd").amount
+
+ cls.Amount = Amount
+ cls.BankAccount = BankAccount
+
+ def test_instance_one(self):
+ BankAccount, Amount = self.BankAccount, self.Amount
+ account = BankAccount(balance=Amount(4000, "usd"))
+
+ # 3b. print balance in usd
+ eq_(account.balance.amount, 4000)
+
+ def test_instance_two(self):
+ BankAccount, Amount = self.BankAccount, self.Amount
+ account = BankAccount(balance=Amount(4000, "usd"))
+
+ # 3c. print balance in gbp
+ eq_(account.balance.as_currency("gbp").amount, Decimal('2515.58'))
+
+ def test_instance_three(self):
+ BankAccount, Amount = self.BankAccount, self.Amount
+ account = BankAccount(balance=Amount(4000, "usd"))
+
+ # 3d. perform currency-agnostic comparisons, math
+ is_(account.balance > Amount(500, "cad"), True)
+
+ def test_instance_four(self):
+ BankAccount, Amount = self.BankAccount, self.Amount
+ account = BankAccount(balance=Amount(4000, "usd"))
+ eq_(
+ account.balance + Amount(500, "cad") - Amount(50, "eur"),
+ Amount(Decimal("4425.316"), "usd")
+ )
+
+ def test_query_one(self):
+ BankAccount, Amount = self.BankAccount, self.Amount
+ session = Session()
+
+ query = session.query(BankAccount).\
+ filter(BankAccount.balance == Amount(10000, "cad"))
+
+ self.assert_compile(
+ query,
+ "SELECT bank_account.balance AS bank_account_balance, "
+ "bank_account.id AS bank_account_id FROM bank_account "
+ "WHERE bank_account.balance = :balance_1",
+ checkparams={'balance_1': Decimal('9886.110000')}
+ )
+
+ def test_query_two(self):
+ BankAccount, Amount = self.BankAccount, self.Amount
+ session = Session()
+
+ # alternatively we can do the calc on the DB side.
+ query = session.query(BankAccount).\
+ filter(
+ BankAccount.balance.as_currency("cad") > Amount(9999, "cad")).\
+ filter(
+ BankAccount.balance.as_currency("cad") < Amount(10001, "cad"))
+ self.assert_compile(
+ query,
+ "SELECT bank_account.balance AS bank_account_balance, "
+ "bank_account.id AS bank_account_id "
+ "FROM bank_account "
+ "WHERE :balance_1 * bank_account.balance > :param_1 "
+ "AND :balance_2 * bank_account.balance < :param_2",
+ checkparams={
+ 'balance_1': Decimal('1.01152'),
+ 'balance_2': Decimal('1.01152'),
+ 'param_1': Decimal('9999'),
+ 'param_2': Decimal('10001')}
+ )
+
+ def test_query_three(self):
+ BankAccount = self.BankAccount
+ session = Session()
+
+ query = session.query(BankAccount).\
+ filter(
+ BankAccount.balance.as_currency("cad") >
+ BankAccount.balance.as_currency("eur"))
+ self.assert_compile(
+ query,
+ "SELECT bank_account.balance AS bank_account_balance, "
+ "bank_account.id AS bank_account_id FROM bank_account "
+ "WHERE :balance_1 * bank_account.balance > "
+ ":param_1 * :balance_2 * bank_account.balance",
+ checkparams={
+ 'balance_1': Decimal('1.01152'),
+ 'balance_2': Decimal('0.724743'),
+ 'param_1': Decimal('1.39569')}
+ )
+
+ def test_query_four(self):
+ BankAccount = self.BankAccount
+ session = Session()
+
+ # 4c. query all amounts, converting to "CAD" on the DB side
+ query = session.query(BankAccount.balance.as_currency("cad").amount)
+ self.assert_compile(
+ query,
+ "SELECT :balance_1 * bank_account.balance AS anon_1 "
+ "FROM bank_account",
+ checkparams={'balance_1': Decimal('1.01152')}
+ )
+
+ def test_query_five(self):
+ BankAccount = self.BankAccount
+ session = Session()
+
+ # 4d. average balance in EUR
+ query = session.query(func.avg(BankAccount.balance.as_currency("eur")))
+ self.assert_compile(
+ query,
+ "SELECT avg(:balance_1 * bank_account.balance) AS avg_1 "
+ "FROM bank_account",
+ checkparams={'balance_1': Decimal('0.724743')}
+ )
+
+ def test_docstring(self):
+ BankAccount = self.BankAccount
+ eq_(
+ BankAccount.balance.__doc__,
+ "Return an Amount view of the current balance.")
+