CHANGES
=======
+0.5.8
+=====
+- postgresql
+ - The extract() function, which was slightly improved in
+ 0.5.7, needed a lot more work to generate the correct
+ typecast (the typecasts appear to be necessary in PG's
+ EXTRACT quite a lot of the time). The typecast is
+ now generated using a rule dictionary based
+ on PG's documentation for date/time/interval arithmetic.
+ It also accepts text() constructs again, which was broken
+ in 0.5.7. [ticket:1647]
+
0.5.7
=====
- orm
from sqlalchemy import sql, schema, exc, util
from sqlalchemy.engine import base, default
-from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import compiler, expression, util as sql_util
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
+ affinity = sql_util.determine_date_affinity(extract.expr)
+
+ casts = {sqltypes.Date:'date', sqltypes.DateTime:'timestamp', sqltypes.Interval:'interval', sqltypes.Time:'time'}
+ cast = casts.get(affinity, None)
+ if isinstance(extract.expr, sql.ColumnElement) and cast is not None:
+ expr = extract.expr.op('::')(sql.literal_column(cast))
+ else:
+ expr = extract.expr
return "EXTRACT(%s FROM %s)" % (
- field, self.process(extract.expr.op('::')(sql.literal_column('timestamp'))))
+ field, self.process(expr))
class PGSchemaGenerator(compiler.SchemaGenerator):
-from sqlalchemy import exc, schema, topological, util, sql
+from sqlalchemy import exc, schema, topological, util, sql, types as sqltypes
from sqlalchemy.sql import expression, operators, visitors
from itertools import chain
else:
return None, None
+_date_affinities = None
+def determine_date_affinity(expr):
+ """Given an expression, determine if it returns 'interval', 'date', or 'datetime'.
+
+ the PG dialect uses this to generate the extract() function.
+
+ It's less than ideal since it basically needs to duplicate PG's
+ date arithmetic rules.
+
+ Rules are based on http://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ Returns None if operators other than + or - are detected as well as types
+ outside of those above.
+
+ """
+
+ global _date_affinities
+ if _date_affinities is None:
+ Date, DateTime, Integer, \
+ Numeric, Interval, Time = \
+ sqltypes.Date, sqltypes.DateTime,\
+ sqltypes.Integer, sqltypes.Numeric,\
+ sqltypes.Interval, sqltypes.Time
+
+ _date_affinities = {
+ operators.add:{
+ (Date, Integer):Date,
+ (Date, Interval):DateTime,
+ (Date, Time):DateTime,
+ (Interval, Interval):Interval,
+ (DateTime, Interval):DateTime,
+ (Interval, Time):Time,
+ },
+ operators.sub:{
+ (Date, Integer):Date,
+ (Date, Interval):DateTime,
+ (Time, Time):Interval,
+ (Time, Interval):Time,
+ (DateTime, Interval):DateTime,
+ (Interval, Interval):Interval,
+ (DateTime, DateTime):Interval,
+ },
+ operators.mul:{
+ (Integer, Interval):Interval,
+ (Interval, Numeric):Interval,
+ },
+ operators.div: {
+ (Interval, Numeric):Interval
+ }
+ }
+
+ if isinstance(expr, expression._BinaryExpression):
+ if expr.operator not in _date_affinities:
+ return None
+
+ left_affin, right_affin = \
+ determine_date_affinity(expr.left), \
+ determine_date_affinity(expr.right)
+
+ if operators.is_commutative(expr.operator):
+ key = tuple(sorted([left_affin, right_affin], key=lambda cls:cls.__name__))
+ else:
+ key = (left_affin, right_affin)
+
+ lookup = _date_affinities[expr.operator]
+ return lookup.get(key, None)
+
+ # work around the fact that expressions put the wrong type
+ # on generated bind params when its "datetime + timedelta"
+ # and similar
+ if isinstance(expr, expression._BindParamClause):
+ type_ = sqltypes.type_map.get(type(expr.value), sqltypes.NullType)()
+ else:
+ type_ = expr.type
+
+ affinities = set([sqltypes.Date, sqltypes.DateTime,
+ sqltypes.Interval, sqltypes.Time, sqltypes.Integer])
+
+ if type_ is not None and type_._type_affinity in affinities:
+ return type_._type_affinity
+ else:
+ return None
+
+
def find_tables(clause, check_columns=False,
include_aliases=False, include_joins=False,
import sqlalchemy.databases.postgres as pg
self.__supported = {pg.PGDialect:pg.PGInterval}
del pg
-
+
+ @property
+ def _type_affinity(self):
+ return Interval
+
def load_dialect_impl(self, dialect):
if dialect.__class__ in self.__supported:
return self.__supported[dialect.__class__]()
self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
def test_extract(self):
- t = table('t', column('col1'))
-
- for field in 'year', 'month', 'day':
- self.assert_compile(
- select([extract(field, t.c.col1)]),
- "SELECT EXTRACT(%s FROM t.col1 :: timestamp) AS anon_1 "
- "FROM t" % field)
-
- for field in 'year', 'month', 'day':
- self.assert_compile(
- select([extract(field, func.timestamp() - datetime.timedelta(days =5))]),
- "SELECT EXTRACT(%s FROM (timestamp() - %%(timestamp_1)s) :: timestamp) AS anon_1"
- % field)
+
+ t = table('t', column('col1', DateTime), column('col2', Date), column('col3', Time))
+
+ for field in 'year', 'month', 'day', 'epoch', 'hour':
+ for expr, compiled_expr in [
+ ( t.c.col1, "t.col1 :: timestamp" ),
+ ( t.c.col2, "t.col2 :: date" ),
+ ( t.c.col3, "t.col3 :: time" ),
+ (func.current_timestamp() - datetime.timedelta(days=5),
+ "(CURRENT_TIMESTAMP - %(current_timestamp_1)s) :: timestamp"
+ ),
+ (func.current_timestamp() + func.current_timestamp(),
+ "CURRENT_TIMESTAMP + CURRENT_TIMESTAMP" # invalid, no cast.
+ ),
+ (text("foo.date + foo.time"),
+ "foo.date + foo.time" # plain text. no cast.
+ ),
+ (func.current_timestamp() + datetime.timedelta(days=5),
+ "(CURRENT_TIMESTAMP + %(current_timestamp_1)s) :: timestamp"
+ ),
+ (t.c.col2 + t.c.col3,
+ "(t.col2 + t.col3) :: timestamp"
+ ),
+ # addition is commutative
+ (t.c.col2 + datetime.timedelta(days=5),
+ "(t.col2 + %(col2_1)s) :: timestamp"
+ ),
+ (datetime.timedelta(days=5) + t.c.col2,
+ "(%(col2_1)s + t.col2) :: timestamp"
+ ),
+ # subtraction is not
+ (t.c.col1 - datetime.timedelta(seconds=30),
+ "(t.col1 - %(col1_1)s) :: timestamp"
+ ),
+ (datetime.timedelta(seconds=30) - t.c.col1,
+ "%(col1_1)s - t.col1" # invalid - no cast.
+ ),
+ (func.coalesce(t.c.col1, func.current_timestamp()),
+ "coalesce(t.col1, CURRENT_TIMESTAMP) :: timestamp"
+ ),
+ (t.c.col3 + datetime.timedelta(seconds=30),
+ "(t.col3 + %(col3_1)s) :: time"
+ ),
+ (func.current_timestamp() - func.coalesce(t.c.col1, func.current_timestamp()),
+ "(CURRENT_TIMESTAMP - coalesce(t.col1, CURRENT_TIMESTAMP)) :: interval",
+ ),
+ (3 * func.foobar(type_=Interval),
+ "(%(foobar_1)s * foobar()) :: interval"
+ ),
+ (literal(datetime.timedelta(seconds=10)) - literal(datetime.timedelta(seconds=10)),
+ "(%(param_1)s - %(param_2)s) :: interval"
+ ),
+ ]:
+ self.assert_compile(
+ select([extract(field, expr)]).select_from(t),
+ "SELECT EXTRACT(%s FROM %s) AS anon_1 FROM t" % (
+ field,
+ compiled_expr
+ )
+ )
class ReturningTest(TestBase, AssertsExecutionResults):
__only_on__ = 'postgres'