From: Jason Kirtland Date: Mon, 30 Mar 2009 20:41:48 +0000 (+0000) Subject: extract() is now dialect-sensitive and supports SQLite and others. X-Git-Tag: rel_0_5_4~42 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=aca84bebb091a51ceeb911249c366e17b954826a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git extract() is now dialect-sensitive and supports SQLite and others. --- diff --git a/CHANGES b/CHANGES index 8aa1709c6e..172b2c23c6 100644 --- a/CHANGES +++ b/CHANGES @@ -3,33 +3,37 @@ ======= CHANGES ======= + 0.5.4 ===== + - orm - - Fixed the "set collection" function on "dynamic" relations - to initiate events correctly. Previously a collection - could only be assigned to a pending parent instance, - otherwise modified events would not be fired correctly. - Set collection is now compatible with merge(), - fixes [ticket:1352]. - - - Lazy loader will not use get() if the "lazy load" - SQL clause matches the clause used by get(), but - contains some parameters hardcoded. Previously - the lazy strategy would fail with the get(). Ideally - get() would be used with the hardcoded parameters - but this would require further development. + - Fixed the "set collection" function on "dynamic" relations to + initiate events correctly. Previously a collection could only + be assigned to a pending parent instance, otherwise modified + events would not be fired correctly. Set collection is now + compatible with merge(), fixes [ticket:1352]. + + - Lazy loader will not use get() if the "lazy load" SQL clause + matches the clause used by get(), but contains some parameters + hardcoded. Previously the lazy strategy would fail with the + get(). Ideally get() would be used with the hardcoded + parameters but this would require further development. [ticket:1357] - sql - - Fixed __repr__() and other _get_colspec() methods on + - ``sqlalchemy.extract()`` is now dialect sensitive and can + extract components of timestamps idiomatically across the + supported databases, including SQLite. + + - Fixed __repr__() and other _get_colspec() methods on ForeignKey constructed from __clause_element__() style construct (i.e. declarative columns). [ticket:1353] - + - mssql - Corrected problem with information schema not working with a - binary collation based database. Cleaned up information - schema since it is only used by mssql now. [ticket:1343] + binary collation based database. Cleaned up information schema + since it is only used by mssql now. [ticket:1343] 0.5.3 ===== diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 67af4a7a4a..56c28b8cc6 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -328,6 +328,20 @@ class AccessDialect(default.DefaultDialect): class AccessCompiler(compiler.DefaultCompiler): + extract_map = compiler.DefaultCompiler.extract_map.copy() + extract_map.update ({ + 'month': 'm', + 'day': 'd', + 'year': 'yyyy', + 'second': 's', + 'hour': 'h', + 'doy': 'y', + 'minute': 'n', + 'quarter': 'q', + 'dow': 'w', + 'week': 'ww' + }) + def visit_select_precolumns(self, select): """Access puts TOP, it's version of LIMIT here """ s = select.distinct and "DISTINCT " or "" @@ -375,6 +389,10 @@ class AccessCompiler(compiler.DefaultCompiler): return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + class AccessSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): @@ -422,4 +440,4 @@ dialect.schemagenerator = AccessSchemaGenerator dialect.schemadropper = AccessSchemaDropper dialect.preparer = AccessIdentifierPreparer dialect.defaultrunner = AccessDefaultRunner -dialect.execution_ctx_cls = AccessExecutionContext \ No newline at end of file +dialect.execution_ctx_cls = AccessExecutionContext diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 63ec8da15c..03cf73eee3 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -1515,6 +1515,14 @@ class MSSQLCompiler(compiler.DefaultCompiler): } ) + extract_map = compiler.DefaultCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'microseconds': 'microsecond' + }) + def __init__(self, *args, **kwargs): super(MSSQLCompiler, self).__init__(*args, **kwargs) self.tablealiases = {} @@ -1586,6 +1594,10 @@ class MSSQLCompiler(compiler.DefaultCompiler): kwargs['mssql_aliased'] = True return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + def visit_savepoint(self, savepoint_stmt): util.warn("Savepoint support in mssql is experimental and may lead to data loss.") return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 3d71bb7232..c2b233a6e9 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1914,6 +1914,10 @@ class MySQLCompiler(compiler.DefaultCompiler): "utc_timestamp":"UTC_TIMESTAMP" }) + extract_map = compiler.DefaultCompiler.extract_map.copy() + extract_map.update ({ + 'milliseconds': 'millisecond', + }) def visit_typeclause(self, typeclause): type_ = typeclause.type.dialect_impl(self.dialect) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 038a9e8df4..068afaf3dd 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -792,6 +792,12 @@ class PGCompiler(compiler.DefaultCompiler): else: return text + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s::timestamp)" % ( + field, self.process(extract.expr)) + + class PGSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 77eb30ff81..b77a315b80 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -557,12 +557,34 @@ class SQLiteCompiler(compiler.DefaultCompiler): } ) + extract_map = compiler.DefaultCompiler.extract_map.copy() + extract_map.update({ + 'month': '%m', + 'day': '%d', + 'year': '%Y', + 'second': '%S', + 'hour': '%H', + 'doy': '%j', + 'minute': '%M', + 'epoch': '%s', + 'dow': '%w', + 'week': '%W' + }) + def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: return super(SQLiteCompiler, self).visit_cast(cast) else: return self.process(cast.clause) + def visit_extract(self, extract): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], self.process(extract.expr)) + except KeyError: + raise exc.ArgumentError( + "%s is not a valid extract argument." % extract.field) + def limit_clause(self, select): text = "" if select._limit is not None: diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py index 6007315f26..f5b48e1479 100644 --- a/lib/sqlalchemy/databases/sybase.py +++ b/lib/sqlalchemy/databases/sybase.py @@ -733,6 +733,14 @@ class SybaseSQLCompiler(compiler.DefaultCompiler): sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y), }) + extract_map = compiler.DefaultCompiler.extract_map.copy() + extract_map.update ({ + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond' + }) + + def bindparam_string(self, name): res = super(SybaseSQLCompiler, self).bindparam_string(name) if name.lower().startswith('literal'): @@ -786,6 +794,10 @@ class SybaseSQLCompiler(compiler.DefaultCompiler): res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause)) return res + def visit_extract(self, extract): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr)) + def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use return '' diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3a982f23c4..5042959b25 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -108,6 +108,23 @@ FUNCTIONS = { functions.user: 'USER' } +EXTRACT_MAP = { + 'month': 'month', + 'day': 'day', + 'year': 'year', + 'second': 'second', + 'hour': 'hour', + 'doy': 'doy', + 'minute': 'minute', + 'quarter': 'quarter', + 'dow': 'dow', + 'week': 'week', + 'epoch': 'epoch', + 'milliseconds': 'milliseconds', + 'microseconds': 'microseconds', + 'timezone_hour': 'timezone_hour', + 'timezone_minute': 'timezone_minute' +} class _CompileLabel(visitors.Visitable): """lightweight label object which acts as an expression._Label.""" @@ -133,6 +150,7 @@ class DefaultCompiler(engine.Compiled): operators = OPERATORS functions = FUNCTIONS + extract_map = EXTRACT_MAP # if we are insert/update/delete. # set to true when we visit an INSERT, UPDATE or DELETE @@ -346,6 +364,10 @@ class DefaultCompiler(engine.Compiled): def visit_cast(self, cast, **kwargs): return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause)) + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr)) + def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: result_map[func.name.lower()] = (func.name, None, func.type) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 5a0d5b0430..56f358db82 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -484,8 +484,7 @@ def cast(clause, totype, **kwargs): def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" - expr = _BinaryExpression(text(field), expr, operators.from_) - return func.extract(expr) + return _Extract(field, expr) def collate(expression, collation): """Return the clause ``expression COLLATE collation``.""" @@ -2313,6 +2312,27 @@ class _Cast(ColumnElement): return self.clause._from_objects +class _Extract(ColumnElement): + + __visit_name__ = 'extract' + + def __init__(self, field, expr, **kwargs): + self.type = sqltypes.Integer() + self.field = field + self.expr = _literal_as_binds(expr, None) + + def _copy_internals(self, clone=_clone): + self.field = clone(self.field) + self.expr = clone(self.expr) + + def get_children(self, **kwargs): + return self.field, self.expr + + @property + def _from_objects(self): + return self.expr._from_objects + + class _UnaryExpression(ColumnElement): __visit_name__ = 'unary' diff --git a/test/dialect/access.py b/test/dialect/access.py index 311231947e..57af45a9d6 100644 --- a/test/dialect/access.py +++ b/test/dialect/access.py @@ -1,14 +1,33 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import sql from sqlalchemy.databases import access from testlib import * -class BasicTest(TestBase, AssertsExecutionResults): - # A simple import of the database/ module should work on all systems. - def test_import(self): - # we got this far, right? - return True +class CompileTest(TestBase, AssertsCompiledSQL): + __dialect__ = access.dialect() + + def test_extract(self): + t = sql.table('t', sql.column('col1')) + + mapping = { + 'month': 'm', + 'day': 'd', + 'year': 'yyyy', + 'second': 's', + 'hour': 'h', + 'doy': 'y', + 'minute': 'n', + 'quarter': 'q', + 'dow': 'w', + 'week': 'ww' + } + + for field, subst in mapping.items(): + self.assert_compile( + select([extract(field, t.c.col1)]), + 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst) if __name__ == "__main__": diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index 0962f59f3b..de9c5cd62b 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -125,6 +125,14 @@ class CompileTest(TestBase, AssertsCompiledSQL): self.assert_compile(func.current_date(), "GETDATE()") self.assert_compile(func.length(3), "LEN(:length_1)") + def test_extract(self): + t = table('t', column('col1')) + + for field in 'day', 'month', 'year': + self.assert_compile( + select([extract(field, t.c.col1)]), + 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field) + class IdentityInsertTest(TestBase, AssertsCompiledSQL): __only_on__ = 'mssql' diff --git a/test/dialect/mysql.py b/test/dialect/mysql.py index a233c25f54..fa8a85ec45 100644 --- a/test/dialect/mysql.py +++ b/test/dialect/mysql.py @@ -982,6 +982,19 @@ class SQLTest(TestBase, AssertsCompiledSQL): for type_, expected in specs: self.assert_compile(cast(t.c.col, type_), expected) + def test_extract(self): + t = sql.table('t', sql.column('col1')) + + for field in 'year', 'month', 'day': + self.assert_compile( + select([extract(field, t.c.col1)]), + "SELECT EXTRACT(%s FROM t.col1) AS anon_1 FROM t" % field) + + # millsecondS to millisecond + self.assert_compile( + select([extract('milliseconds', t.c.col1)]), + "SELECT EXTRACT(millisecond FROM t.col1) AS anon_1 FROM t") + class RawReflectionTest(TestBase): def setUp(self): diff --git a/test/dialect/postgres.py b/test/dialect/postgres.py index 3867d1b019..d613ad2ddf 100644 --- a/test/dialect/postgres.py +++ b/test/dialect/postgres.py @@ -22,6 +22,8 @@ class SequenceTest(TestBase, AssertsCompiledSQL): assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"' class CompileTest(TestBase, AssertsCompiledSQL): + __dialect__ = postgres.dialect() + def test_update_returning(self): dialect = postgres.dialect() table1 = table('mytable', @@ -58,6 +60,16 @@ class CompileTest(TestBase, AssertsCompiledSQL): i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)]) self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect) + def test_extract(self): + t = table('t', column('col1')) + + for field in 'year', 'month', 'day': + self.assert_compile( + select([extract(field, t.c.col1)]), + "SELECT EXTRACT(%s FROM t.col1::timestamp) AS anon_1 " + "FROM t" % field) + + class ReturningTest(TestBase, AssertsExecutionResults): __only_on__ = 'postgres' diff --git a/test/dialect/sqlite.py b/test/dialect/sqlite.py index 97d12bf603..005fad66b4 100644 --- a/test/dialect/sqlite.py +++ b/test/dialect/sqlite.py @@ -3,7 +3,7 @@ import testenv; testenv.configure_for_tests() import datetime from sqlalchemy import * -from sqlalchemy import exc +from sqlalchemy import exc, sql from sqlalchemy.databases import sqlite from testlib import * @@ -283,6 +283,36 @@ class DialectTest(TestBase, AssertsExecutionResults): pass raise + +class SQLTest(TestBase, AssertsCompiledSQL): + """Tests SQLite-dialect specific compilation.""" + + __dialect__ = sqlite.dialect() + + + def test_extract(self): + t = sql.table('t', sql.column('col1')) + + mapping = { + 'month': '%m', + 'day': '%d', + 'year': '%Y', + 'second': '%S', + 'hour': '%H', + 'doy': '%j', + 'minute': '%M', + 'epoch': '%s', + 'dow': '%w', + 'week': '%W', + } + + for field, subst in mapping.items(): + self.assert_compile( + select([extract(field, t.c.col1)]), + "SELECT CAST(STRFTIME('%s', t.col1) AS INTEGER) AS anon_1 " + "FROM t" % subst) + + class InsertTest(TestBase, AssertsExecutionResults): """Tests inserts and autoincrement.""" diff --git a/test/dialect/sybase.py b/test/dialect/sybase.py index 19cca465bd..32b9904d8a 100644 --- a/test/dialect/sybase.py +++ b/test/dialect/sybase.py @@ -1,14 +1,30 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * +from sqlalchemy import sql from sqlalchemy.databases import sybase from testlib import * -class BasicTest(TestBase, AssertsExecutionResults): - # A simple import of the database/ module should work on all systems. - def test_import(self): - # we got this far, right? - return True +class CompileTest(TestBase, AssertsCompiledSQL): + __dialect__ = sybase.dialect() + + def test_extract(self): + t = sql.table('t', sql.column('col1')) + + mapping = { + 'day': 'day', + 'doy': 'dayofyear', + 'dow': 'weekday', + 'milliseconds': 'millisecond', + 'millisecond': 'millisecond', + 'year': 'year', + } + + for field, subst in mapping.items(): + self.assert_compile( + select([extract(field, t.c.col1)]), + 'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst) + if __name__ == "__main__": diff --git a/test/sql/functions.py b/test/sql/functions.py index 1519575036..17d8a35e97 100644 --- a/test/sql/functions.py +++ b/test/sql/functions.py @@ -271,6 +271,44 @@ class ExecuteTest(TestBase): assert x == y == z == w == q == r + def test_extract_bind(self): + """Basic common denominator execution tests for extract()""" + + date = datetime.date(2010, 5, 1) + + def execute(field): + return testing.db.execute(select([extract(field, date)])).scalar() + + assert execute('year') == 2010 + assert execute('month') == 5 + assert execute('day') == 1 + + date = datetime.datetime(2010, 5, 1, 12, 11, 10) + + assert execute('year') == 2010 + assert execute('month') == 5 + assert execute('day') == 1 + + def test_extract_expression(self): + meta = MetaData(testing.db) + table = Table('test', meta, + Column('dt', DateTime), + Column('d', Date)) + meta.create_all() + try: + table.insert().execute( + {'dt': datetime.datetime(2010, 5, 1, 12, 11, 10), + 'd': datetime.date(2010, 5, 1) }) + rs = select([extract('year', table.c.dt), + extract('month', table.c.d)]).execute() + row = rs.fetchone() + assert row[0] == 2010 + assert row[1] == 5 + rs.close() + finally: + meta.drop_all() + + def exec_sorted(statement, *args, **kw): """Executes a statement and returns a sorted list plain tuple rows.""" diff --git a/test/sql/select.py b/test/sql/select.py index 15c47a6747..52c382f817 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -831,12 +831,6 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today "SELECT values.id FROM values WHERE values.val1 / (values.val2 - values.val1) / values.val1 > :param_1" ) - def test_extract(self): - """test the EXTRACT function""" - self.assert_compile(select([extract("month", table3.c.otherstuff)]), "SELECT extract(month FROM thirdtable.otherstuff) AS extract_1 FROM thirdtable") - - self.assert_compile(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date_1, :to_date_2)) AS extract_1") - def test_collate(self): for expr in (select([table1.c.name.collate('latin1_german2_ci')]), select([collate(table1.c.name, 'latin1_german2_ci')])):