]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
extract() is now dialect-sensitive and supports SQLite and others.
authorJason Kirtland <jek@discorporate.us>
Mon, 30 Mar 2009 20:41:48 +0000 (20:41 +0000)
committerJason Kirtland <jek@discorporate.us>
Mon, 30 Mar 2009 20:41:48 +0000 (20:41 +0000)
17 files changed:
CHANGES
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/dialect/access.py
test/dialect/mssql.py
test/dialect/mysql.py
test/dialect/postgres.py
test/dialect/sqlite.py
test/dialect/sybase.py
test/sql/functions.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 8aa1709c6e8b9215def8955ebc512fa50f541f1f..172b2c23c64bfa072f293a21b8c290073331491f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -3,33 +3,37 @@
 =======
 CHANGES
 =======
+
 0.5.4
 =====
+
 - orm
-    - Fixed the "set collection" function on "dynamic" relations
-      to initiate events correctly.  Previously a collection
-      could only be assigned to a pending parent instance,
-      otherwise modified events would not be fired correctly.
-      Set collection is now compatible with merge(), 
-      fixes [ticket:1352].
-
-    - Lazy loader will not use get() if the "lazy load"
-      SQL clause matches the clause used by get(), but
-      contains some parameters hardcoded.  Previously
-      the lazy strategy would fail with the get().  Ideally
-      get() would be used with the hardcoded parameters
-      but this would require further development.
+    - Fixed the "set collection" function on "dynamic" relations to
+      initiate events correctly.  Previously a collection could only
+      be assigned to a pending parent instance, otherwise modified
+      events would not be fired correctly.  Set collection is now
+      compatible with merge(), fixes [ticket:1352].
+
+    - Lazy loader will not use get() if the "lazy load" SQL clause
+      matches the clause used by get(), but contains some parameters
+      hardcoded.  Previously the lazy strategy would fail with the
+      get().  Ideally get() would be used with the hardcoded
+      parameters but this would require further development.
       [ticket:1357]
 
 - sql
-    - Fixed __repr__() and other _get_colspec() methods on 
+    - ``sqlalchemy.extract()`` is now dialect sensitive and can
+      extract components of timestamps idiomatically across the
+      supported databases, including SQLite.
+
+    - Fixed __repr__() and other _get_colspec() methods on
       ForeignKey constructed from __clause_element__() style
       construct (i.e. declarative columns).  [ticket:1353]
-      
+
 - mssql
     - Corrected problem with information schema not working with a
-      binary collation based database. Cleaned up information
-      schema since it is only used by mssql now. [ticket:1343]
+      binary collation based database. Cleaned up information schema
+      since it is only used by mssql now. [ticket:1343]
 
 0.5.3
 =====
index 67af4a7a4a6808b0adb57a87c57c3fa81d875510..56c28b8cc612152c006995db68de4ed3f252c006 100644 (file)
@@ -328,6 +328,20 @@ class AccessDialect(default.DefaultDialect):
 
 
 class AccessCompiler(compiler.DefaultCompiler):
+    extract_map = compiler.DefaultCompiler.extract_map.copy()
+    extract_map.update ({
+            'month': 'm',
+            'day': 'd',
+            'year': 'yyyy',
+            'second': 's',
+            'hour': 'h',
+            'doy': 'y',
+            'minute': 'n',
+            'quarter': 'q',
+            'dow': 'w',
+            'week': 'ww'
+    })
+
     def visit_select_precolumns(self, select):
         """Access puts TOP, it's version of LIMIT here """
         s = select.distinct and "DISTINCT " or ""
@@ -375,6 +389,10 @@ class AccessCompiler(compiler.DefaultCompiler):
         return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " INNER JOIN ") + \
             self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
 
+    def visit_extract(self, extract):
+        field = self.extract_map.get(extract.field, extract.field)
+        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
 
 class AccessSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
@@ -422,4 +440,4 @@ dialect.schemagenerator = AccessSchemaGenerator
 dialect.schemadropper = AccessSchemaDropper
 dialect.preparer = AccessIdentifierPreparer
 dialect.defaultrunner = AccessDefaultRunner
-dialect.execution_ctx_cls = AccessExecutionContext
\ No newline at end of file
+dialect.execution_ctx_cls = AccessExecutionContext
index 63ec8da15c0e8ae5058b10827fd6968d057624ac..03cf73eee328f1572d8eff04cc87ad9600ad7536 100644 (file)
@@ -1515,6 +1515,14 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         }
     )
 
+    extract_map = compiler.DefaultCompiler.extract_map.copy()
+    extract_map.update ({
+        'doy': 'dayofyear',
+        'dow': 'weekday',
+        'milliseconds': 'millisecond',
+        'microseconds': 'microsecond'
+    })
+
     def __init__(self, *args, **kwargs):
         super(MSSQLCompiler, self).__init__(*args, **kwargs)
         self.tablealiases = {}
@@ -1586,6 +1594,10 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         kwargs['mssql_aliased'] = True
         return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
 
+    def visit_extract(self, extract):
+        field = self.extract_map.get(extract.field, extract.field)
+        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
     def visit_savepoint(self, savepoint_stmt):
         util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
         return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
index 3d71bb72324fbfb58b3b9b2fd334f2f9fd798be3..c2b233a6e9dafc03a901402d9c8301111452169d 100644 (file)
@@ -1914,6 +1914,10 @@ class MySQLCompiler(compiler.DefaultCompiler):
         "utc_timestamp":"UTC_TIMESTAMP"
         })
 
+    extract_map = compiler.DefaultCompiler.extract_map.copy()
+    extract_map.update ({
+        'milliseconds': 'millisecond',
+    })
 
     def visit_typeclause(self, typeclause):
         type_ = typeclause.type.dialect_impl(self.dialect)
index 038a9e8df47aed5e2b53f8a206ae2a95fea4ce39..068afaf3dd2b786bf35cdae76502a0735516472e 100644 (file)
@@ -792,6 +792,12 @@ class PGCompiler(compiler.DefaultCompiler):
         else:
             return text
 
+    def visit_extract(self, extract, **kwargs):
+        field = self.extract_map.get(extract.field, extract.field)
+        return "EXTRACT(%s FROM %s::timestamp)" % (
+            field, self.process(extract.expr))
+
+
 class PGSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
index 77eb30ff81012957260f0205f812a12f0a03cf31..b77a315b80ebee81ffa443a20d66cfe2760dba25 100644 (file)
@@ -557,12 +557,34 @@ class SQLiteCompiler(compiler.DefaultCompiler):
         }
     )
 
+    extract_map = compiler.DefaultCompiler.extract_map.copy()
+    extract_map.update({
+        'month': '%m',
+        'day': '%d',
+        'year': '%Y',
+        'second': '%S',
+        'hour': '%H',
+        'doy': '%j',
+        'minute': '%M',
+        'epoch': '%s',
+        'dow': '%w',
+        'week': '%W'
+    })
+
     def visit_cast(self, cast, **kwargs):
         if self.dialect.supports_cast:
             return super(SQLiteCompiler, self).visit_cast(cast)
         else:
             return self.process(cast.clause)
 
+    def visit_extract(self, extract):
+        try:
+            return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
+                self.extract_map[extract.field], self.process(extract.expr))
+        except KeyError:
+            raise exc.ArgumentError(
+                "%s is not a valid extract argument." % extract.field)
+
     def limit_clause(self, select):
         text = ""
         if select._limit is not None:
index 6007315f264334c502eaf09565bd3005da5d9f0c..f5b48e1479c23776a6e20b574341ad725099b9a6 100644 (file)
@@ -733,6 +733,14 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
         sql_operators.mod: lambda x, y: "MOD(%s, %s)" % (x, y),
     })
 
+    extract_map = compiler.DefaultCompiler.extract_map.copy()
+    extract_map.update ({
+        'doy': 'dayofyear',
+        'dow': 'weekday',
+        'milliseconds': 'millisecond'
+    })
+
+
     def bindparam_string(self, name):
         res = super(SybaseSQLCompiler, self).bindparam_string(name)
         if name.lower().startswith('literal'):
@@ -786,6 +794,10 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
             res = "CAST(%s AS %s)" % (res, self.process(cast.typeclause))
         return res
 
+    def visit_extract(self, extract):
+        field = self.extract_map.get(extract.field, extract.field)
+        return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
+
     def for_update_clause(self, select):
         # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
         return ''
index 3a982f23c43e32c296040bfcff17f468c7d939ba..5042959b25443355449fa3abfaffc6f62c2cd4a0 100644 (file)
@@ -108,6 +108,23 @@ FUNCTIONS = {
     functions.user: 'USER'
 }
 
+EXTRACT_MAP = {
+    'month': 'month',
+    'day': 'day',
+    'year': 'year',
+    'second': 'second',
+    'hour': 'hour',
+    'doy': 'doy',
+    'minute': 'minute',
+    'quarter': 'quarter',
+    'dow': 'dow',
+    'week': 'week',
+    'epoch': 'epoch',
+    'milliseconds': 'milliseconds',
+    'microseconds': 'microseconds',
+    'timezone_hour': 'timezone_hour',
+    'timezone_minute': 'timezone_minute'
+}
 
 class _CompileLabel(visitors.Visitable):
     """lightweight label object which acts as an expression._Label."""
@@ -133,6 +150,7 @@ class DefaultCompiler(engine.Compiled):
 
     operators = OPERATORS
     functions = FUNCTIONS
+    extract_map = EXTRACT_MAP
 
     # if we are insert/update/delete. 
     # set to true when we visit an INSERT, UPDATE or DELETE
@@ -346,6 +364,10 @@ class DefaultCompiler(engine.Compiled):
     def visit_cast(self, cast, **kwargs):
         return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
 
+    def visit_extract(self, extract, **kwargs):
+        field = self.extract_map.get(extract.field, extract.field)
+        return "EXTRACT(%s FROM %s)" % (field, self.process(extract.expr))
+
     def visit_function(self, func, result_map=None, **kwargs):
         if result_map is not None:
             result_map[func.name.lower()] = (func.name, None, func.type)
index 5a0d5b04309b788043af5e023c39b717be322b5d..56f358db8277530b4499067ea76b2a96e488f9b5 100644 (file)
@@ -484,8 +484,7 @@ def cast(clause, totype, **kwargs):
 def extract(field, expr):
     """Return the clause ``extract(field FROM expr)``."""
 
-    expr = _BinaryExpression(text(field), expr, operators.from_)
-    return func.extract(expr)
+    return _Extract(field, expr)
 
 def collate(expression, collation):
     """Return the clause ``expression COLLATE collation``."""
@@ -2313,6 +2312,27 @@ class _Cast(ColumnElement):
         return self.clause._from_objects
 
 
+class _Extract(ColumnElement):
+
+    __visit_name__ = 'extract'
+
+    def __init__(self, field, expr, **kwargs):
+        self.type = sqltypes.Integer()
+        self.field = field
+        self.expr = _literal_as_binds(expr, None)
+
+    def _copy_internals(self, clone=_clone):
+        self.field = clone(self.field)
+        self.expr = clone(self.expr)
+
+    def get_children(self, **kwargs):
+        return self.field, self.expr
+
+    @property
+    def _from_objects(self):
+        return self.expr._from_objects
+
+
 class _UnaryExpression(ColumnElement):
 
     __visit_name__ = 'unary'
index 311231947e567b005dabfb9dd75eacd63c9eab7e..57af45a9d6ac1642ff4cef5c50e32e542d0318f6 100644 (file)
@@ -1,14 +1,33 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
+from sqlalchemy import sql
 from sqlalchemy.databases import access
 from testlib import *
 
 
-class BasicTest(TestBase, AssertsExecutionResults):
-    # A simple import of the database/ module should work on all systems.
-    def test_import(self):
-        # we got this far, right?
-        return True
+class CompileTest(TestBase, AssertsCompiledSQL):
+    __dialect__ = access.dialect()
+
+    def test_extract(self):
+        t = sql.table('t', sql.column('col1'))
+
+        mapping = {
+            'month': 'm',
+            'day': 'd',
+            'year': 'yyyy',
+            'second': 's',
+            'hour': 'h',
+            'doy': 'y',
+            'minute': 'n',
+            'quarter': 'q',
+            'dow': 'w',
+            'week': 'ww'
+            }
+
+        for field, subst in mapping.items():
+            self.assert_compile(
+                select([extract(field, t.c.col1)]),
+                'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst)
 
 
 if __name__ == "__main__":
index 0962f59f3ba6a39fa73f6cdd21e2a69a49043d11..de9c5cd62bfea0b7f1217d1556a230f7b09390d7 100755 (executable)
@@ -125,6 +125,14 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(func.current_date(), "GETDATE()")
         self.assert_compile(func.length(3), "LEN(:length_1)")
 
+    def test_extract(self):
+        t = table('t', column('col1'))
+
+        for field in 'day', 'month', 'year':
+            self.assert_compile(
+                select([extract(field, t.c.col1)]),
+                'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % field)
+
 
 class IdentityInsertTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mssql'
index a233c25f544eda3eca601a581a89b7d18c3ae525..fa8a85ec453d312c20b8928af5e2ea204efd62b8 100644 (file)
@@ -982,6 +982,19 @@ class SQLTest(TestBase, AssertsCompiledSQL):
         for type_, expected in specs:
             self.assert_compile(cast(t.c.col, type_), expected)
 
+    def test_extract(self):
+        t = sql.table('t', sql.column('col1'))
+
+        for field in 'year', 'month', 'day':
+            self.assert_compile(
+                select([extract(field, t.c.col1)]),
+                "SELECT EXTRACT(%s FROM t.col1) AS anon_1 FROM t" % field)
+
+        # millsecondS to millisecond
+        self.assert_compile(
+            select([extract('milliseconds', t.c.col1)]),
+            "SELECT EXTRACT(millisecond FROM t.col1) AS anon_1 FROM t")
+
 
 class RawReflectionTest(TestBase):
     def setUp(self):
index 3867d1b019406644d3258292c77e44d9efced1a9..d613ad2ddf0e522396e1b88976ffc30d6fad89b8 100644 (file)
@@ -22,6 +22,8 @@ class SequenceTest(TestBase, AssertsCompiledSQL):
         assert dialect.identifier_preparer.format_sequence(seq) == '"Some_Schema"."My_Seq"'
 
 class CompileTest(TestBase, AssertsCompiledSQL):
+    __dialect__ = postgres.dialect()
+
     def test_update_returning(self):
         dialect = postgres.dialect()
         table1 = table('mytable',
@@ -58,6 +60,16 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         i = insert(table1, values=dict(name='foo'), postgres_returning=[func.length(table1.c.name)])
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING length(mytable.name)", dialect=dialect)
 
+    def test_extract(self):
+        t = table('t', column('col1'))
+
+        for field in 'year', 'month', 'day':
+            self.assert_compile(
+                select([extract(field, t.c.col1)]),
+                "SELECT EXTRACT(%s FROM t.col1::timestamp) AS anon_1 "
+                "FROM t" % field)
+
+
 class ReturningTest(TestBase, AssertsExecutionResults):
     __only_on__ = 'postgres'
 
index 97d12bf60357e1fb7ec5bdd08e09577ae1c0931e..005fad66b4652bcb45d593f02d326b03990d23f9 100644 (file)
@@ -3,7 +3,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
-from sqlalchemy import exc
+from sqlalchemy import exc, sql
 from sqlalchemy.databases import sqlite
 from testlib import *
 
@@ -283,6 +283,36 @@ class DialectTest(TestBase, AssertsExecutionResults):
                 pass
             raise
 
+
+class SQLTest(TestBase, AssertsCompiledSQL):
+    """Tests SQLite-dialect specific compilation."""
+
+    __dialect__ = sqlite.dialect()
+
+
+    def test_extract(self):
+        t = sql.table('t', sql.column('col1'))
+
+        mapping = {
+            'month': '%m',
+            'day': '%d',
+            'year': '%Y',
+            'second': '%S',
+            'hour': '%H',
+            'doy': '%j',
+            'minute': '%M',
+            'epoch': '%s',
+            'dow': '%w',
+            'week': '%W',
+            }
+
+        for field, subst in mapping.items():
+            self.assert_compile(
+                select([extract(field, t.c.col1)]),
+                "SELECT CAST(STRFTIME('%s', t.col1) AS INTEGER) AS anon_1 "
+                "FROM t" % subst)
+
+
 class InsertTest(TestBase, AssertsExecutionResults):
     """Tests inserts and autoincrement."""
 
index 19cca465bd0cf47bc1dc4960f41f82c65c727fbc..32b9904d8a5d6bf22dd49deb41b6c29bd3a0b48b 100644 (file)
@@ -1,14 +1,30 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
+from sqlalchemy import sql
 from sqlalchemy.databases import sybase
 from testlib import *
 
 
-class BasicTest(TestBase, AssertsExecutionResults):
-    # A simple import of the database/ module should work on all systems.
-    def test_import(self):
-        # we got this far, right?
-        return True
+class CompileTest(TestBase, AssertsCompiledSQL):
+    __dialect__ = sybase.dialect()
+
+    def test_extract(self):
+        t = sql.table('t', sql.column('col1'))
+
+        mapping = {
+            'day': 'day',
+            'doy': 'dayofyear',
+            'dow': 'weekday',
+            'milliseconds': 'millisecond',
+            'millisecond': 'millisecond',
+            'year': 'year',
+            }
+
+        for field, subst in mapping.items():
+            self.assert_compile(
+                select([extract(field, t.c.col1)]),
+                'SELECT DATEPART("%s", t.col1) AS anon_1 FROM t' % subst)
+
 
 
 if __name__ == "__main__":
index 1519575036901bcb97588f3d96316f739ec6d9d1..17d8a35e97464ec893679aa6cd80a95fb167e72c 100644 (file)
@@ -271,6 +271,44 @@ class ExecuteTest(TestBase):
 
         assert x == y == z == w == q == r
 
+    def test_extract_bind(self):
+        """Basic common denominator execution tests for extract()"""
+
+        date = datetime.date(2010, 5, 1)
+
+        def execute(field):
+            return testing.db.execute(select([extract(field, date)])).scalar()
+
+        assert execute('year') == 2010
+        assert execute('month') == 5
+        assert execute('day') == 1
+
+        date = datetime.datetime(2010, 5, 1, 12, 11, 10)
+
+        assert execute('year') == 2010
+        assert execute('month') == 5
+        assert execute('day') == 1
+
+    def test_extract_expression(self):
+        meta = MetaData(testing.db)
+        table = Table('test', meta,
+                      Column('dt', DateTime),
+                      Column('d', Date))
+        meta.create_all()
+        try:
+            table.insert().execute(
+                {'dt': datetime.datetime(2010, 5, 1, 12, 11, 10),
+                 'd': datetime.date(2010, 5, 1) })
+            rs = select([extract('year', table.c.dt),
+                         extract('month', table.c.d)]).execute()
+            row = rs.fetchone()
+            assert row[0] == 2010
+            assert row[1] == 5
+            rs.close()
+        finally:
+            meta.drop_all()
+
+
 def exec_sorted(statement, *args, **kw):
     """Executes a statement and returns a sorted list plain tuple rows."""
 
index 15c47a6747860278fdfd9465e7c9fc1b03a03dd2..52c382f817692ebfd5b7d752492da3baddbcc6e1 100644 (file)
@@ -831,12 +831,6 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
              "SELECT values.id FROM values WHERE values.val1 / (values.val2 - values.val1) / values.val1 > :param_1"
          )
 
-    def test_extract(self):
-        """test the EXTRACT function"""
-        self.assert_compile(select([extract("month", table3.c.otherstuff)]), "SELECT extract(month FROM thirdtable.otherstuff) AS extract_1 FROM thirdtable")
-
-        self.assert_compile(select([extract("day", func.to_date("03/20/2005", "MM/DD/YYYY"))]), "SELECT extract(day FROM to_date(:to_date_1, :to_date_2)) AS extract_1")
-
     def test_collate(self):
         for expr in (select([table1.c.name.collate('latin1_german2_ci')]),
                      select([collate(table1.c.name, 'latin1_german2_ci')])):