]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The typing system now handles the task of rendering "literal bind" values,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Oct 2013 20:59:56 +0000 (16:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Oct 2013 20:59:56 +0000 (16:59 -0400)
e.g. values that are normally bound parameters but due to context must
be rendered as strings, typically within DDL constructs such as
CHECK constraints and indexes (note that "literal bind" values
become used by DDL as of :ticket:`2742`).  A new method
:meth:`.TypeEngine.literal_processor` serves as the base, and
:meth:`.TypeDecorator.process_literal_param` is added to allow wrapping
of a native literal rendering method. [ticket:2838]
- enhance _get_colparams so that we can send flags like literal_binds into
INSERT statements
- add support in PG for inspecting standard_conforming_strings
- add a new series of roundtrip tests based on INSERT of literal plus SELECT
for basic literal rendering in dialect suite

12 files changed:
doc/build/changelog/changelog_09.rst
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_types.py
lib/sqlalchemy/types.py
test/requirements.py
test/sql/test_types.py

index 51ba9225f6077cb1433c898d8d324fdc425b2436..c6901c14a72474a11640aed61fceb20a6b9734c4 100644 (file)
 .. changelog::
     :version: 0.9.0
 
+    .. change::
+        :tags: feature, sql
+        :tickets: 2838
+
+        The typing system now handles the task of rendering "literal bind" values,
+        e.g. values that are normally bound parameters but due to context must
+        be rendered as strings, typically within DDL constructs such as
+        CHECK constraints and indexes (note that "literal bind" values
+        become used by DDL as of :ticket:`2742`).  A new method
+        :meth:`.TypeEngine.literal_processor` serves as the base, and
+        :meth:`.TypeDecorator.process_literal_param` is added to allow wrapping
+        of a native literal rendering method.
+
     .. change::
         :tags: feature, sql
         :tickets: 2716
index fdb6e3b4a6a7c31170210bd61bd9e9cb0b03bb46..55c6b315a108d108bfd244ae92121146a5eeccdf 100644 (file)
@@ -210,7 +210,7 @@ import re
 
 from ... import sql, schema, exc, util
 from ...engine import default, reflection
-from ...sql import compiler, expression, util as sql_util, operators
+from ...sql import compiler, expression, operators
 from ... import types as sqltypes
 
 try:
@@ -954,25 +954,30 @@ class PGCompiler(compiler.SQLCompiler):
 
     def visit_ilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
+
         return '%s ILIKE %s' % \
                 (self.process(binary.left, **kw),
                     self.process(binary.right, **kw)) \
-                + (escape and
-                        (' ESCAPE ' + self.render_literal_value(escape, None))
-                        or '')
+            + (
+                ' ESCAPE ' +
+                self.render_literal_value(escape, sqltypes.STRINGTYPE)
+                if escape else ''
+            )
 
     def visit_notilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s NOT ILIKE %s' % \
                 (self.process(binary.left, **kw),
                     self.process(binary.right, **kw)) \
-                + (escape and
-                        (' ESCAPE ' + self.render_literal_value(escape, None))
-                        or '')
+            + (
+                ' ESCAPE ' +
+                self.render_literal_value(escape, sqltypes.STRINGTYPE)
+                if escape else ''
+            )
 
     def render_literal_value(self, value, type_):
         value = super(PGCompiler, self).render_literal_value(value, type_)
-        # TODO: need to inspect "standard_conforming_strings"
+
         if self.dialect._backslash_escapes:
             value = value.replace('\\', '\\\\')
         return value
@@ -1357,7 +1362,6 @@ class PGDialect(default.DefaultDialect):
     inspector = PGInspector
     isolation_level = None
 
-    # TODO: need to inspect "standard_conforming_strings"
     _backslash_escapes = True
 
     def __init__(self, isolation_level=None, **kwargs):
@@ -1379,6 +1383,9 @@ class PGDialect(default.DefaultDialect):
         # http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
         self.supports_smallserial = self.server_version_info >= (9, 2)
 
+        self._backslash_escapes = connection.scalar(
+                                    "show standard_conforming_strings"
+                                    ) == 'off'
 
     def on_connect(self):
         if self.isolation_level is not None:
index fb7d968be48920bde8dcdae50f6030f0c04b04af..00c9801033514bd9d2e5ab2789bb371be7da56bf 100644 (file)
@@ -160,6 +160,13 @@ class _DateTimeMixin(object):
             kw["regexp"] = self._reg
         return util.constructor_copy(self, cls, **kw)
 
+    def literal_processor(self, dialect):
+        bp = self.bind_processor(dialect)
+        def process(value):
+            return "'%s'" % bp(value)
+        return process
+
+
 class DATETIME(_DateTimeMixin, sqltypes.DateTime):
     """Represent a Python datetime object in SQLite using a string.
 
@@ -211,6 +218,7 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime):
                 "%(hour)02d:%(minute)02d:%(second)02d"
             )
 
+
     def bind_processor(self, dialect):
         datetime_datetime = datetime.datetime
         datetime_date = datetime.date
index 22906af5410bf487605dd7d8c562b891172f7d96..5c7a29f99ce7ffdc1363ac45e22209a44b79b803 100644 (file)
@@ -827,7 +827,7 @@ class SQLCompiler(Compiled):
 
     @util.memoized_property
     def _like_percent_literal(self):
-        return elements.literal_column("'%'", type_=sqltypes.String())
+        return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
 
     def visit_contains_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
@@ -871,39 +871,49 @@ class SQLCompiler(Compiled):
 
     def visit_like_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
+
+        # TODO: use ternary here, not "and"/ "or"
         return '%s LIKE %s' % (
                             binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and
-                    (' ESCAPE ' + self.render_literal_value(escape, None))
-                    or '')
+            + (
+                ' ESCAPE ' +
+                self.render_literal_value(escape, sqltypes.STRINGTYPE)
+                if escape else ''
+            )
 
     def visit_notlike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s NOT LIKE %s' % (
                             binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and
-                    (' ESCAPE ' + self.render_literal_value(escape, None))
-                    or '')
+            + (
+                ' ESCAPE ' +
+                self.render_literal_value(escape, sqltypes.STRINGTYPE)
+                if escape else ''
+            )
 
     def visit_ilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) LIKE lower(%s)' % (
                             binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and
-                    (' ESCAPE ' + self.render_literal_value(escape, None))
-                    or '')
+            + (
+                ' ESCAPE ' +
+                self.render_literal_value(escape, sqltypes.STRINGTYPE)
+                if escape else ''
+            )
 
     def visit_notilike_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return 'lower(%s) NOT LIKE lower(%s)' % (
                             binary.left._compiler_dispatch(self, **kw),
                             binary.right._compiler_dispatch(self, **kw)) \
-            + (escape and
-                    (' ESCAPE ' + self.render_literal_value(escape, None))
-                    or '')
+            + (
+                ' ESCAPE ' +
+                self.render_literal_value(escape, sqltypes.STRINGTYPE)
+                if escape else ''
+            )
 
     def visit_bindparam(self, bindparam, within_columns_clause=False,
                                             literal_binds=False,
@@ -954,9 +964,6 @@ class SQLCompiler(Compiled):
 
     def render_literal_bindparam(self, bindparam, **kw):
         value = bindparam.value
-        processor = bindparam.type._cached_bind_processor(self.dialect)
-        if processor:
-            value = processor(value)
         return self.render_literal_value(value, bindparam.type)
 
     def render_literal_value(self, value, type_):
@@ -969,22 +976,10 @@ class SQLCompiler(Compiled):
         of the DBAPI.
 
         """
-        if isinstance(value, util.string_types):
-            value = value.replace("'", "''")
-            return "'%s'" % value
-        elif value is None:
-            return "NULL"
-        elif isinstance(value, (float, ) + util.int_types):
-            return repr(value)
-        elif isinstance(value, decimal.Decimal):
-            return str(value)
-        elif isinstance(value, util.binary_type):
-            # only would occur on py3k b.c. on 2k the string_types
-            # directive above catches this.
-            # see #2838
-            value = value.decode(self.dialect.encoding).replace("'", "''")
-            return "'%s'" % value
 
+        processor = type_._cached_literal_processor(self.dialect)
+        if processor:
+            return processor(value)
         else:
             raise NotImplementedError(
                         "Don't know how to literal-quote value %r" % value)
@@ -1599,7 +1594,7 @@ class SQLCompiler(Compiled):
 
     def visit_insert(self, insert_stmt, **kw):
         self.isinsert = True
-        colparams = self._get_colparams(insert_stmt)
+        colparams = self._get_colparams(insert_stmt, **kw)
 
         if not colparams and \
                 not self.dialect.supports_default_values and \
@@ -1732,7 +1727,7 @@ class SQLCompiler(Compiled):
         table_text = self.update_tables_clause(update_stmt, update_stmt.table,
                                                extra_froms, **kw)
 
-        colparams = self._get_colparams(update_stmt, extra_froms)
+        colparams = self._get_colparams(update_stmt, extra_froms, **kw)
 
         if update_stmt._hints:
             dialect_hints = dict([
@@ -1801,7 +1796,7 @@ class SQLCompiler(Compiled):
         bindparam._is_crud = True
         return bindparam._compiler_dispatch(self)
 
-    def _get_colparams(self, stmt, extra_tables=None):
+    def _get_colparams(self, stmt, extra_tables=None, **kw):
         """create a set of tuples representing column/string pairs for use
         in an INSERT or UPDATE statement.
 
@@ -1853,9 +1848,9 @@ class SQLCompiler(Compiled):
                     # add it to values() in an "as-is" state,
                     # coercing right side to bound param
                     if elements._is_literal(v):
-                        v = self.process(elements.BindParameter(None, v, type_=k.type))
+                        v = self.process(elements.BindParameter(None, v, type_=k.type), **kw)
                     else:
-                        v = self.process(v.self_group())
+                        v = self.process(v.self_group(), **kw)
 
                     values.append((k, v))
 
@@ -1903,7 +1898,7 @@ class SQLCompiler(Compiled):
                                 c, value, required=value is REQUIRED)
                         else:
                             self.postfetch.append(c)
-                            value = self.process(value.self_group())
+                            value = self.process(value.self_group(), **kw)
                         values.append((c, value))
             # determine tables which are actually
             # to be updated - process onupdate and
@@ -1915,7 +1910,7 @@ class SQLCompiler(Compiled):
                     elif c.onupdate is not None and not c.onupdate.is_sequence:
                         if c.onupdate.is_clause_element:
                             values.append(
-                                (c, self.process(c.onupdate.arg.self_group()))
+                                (c, self.process(c.onupdate.arg.self_group(), **kw))
                             )
                             self.postfetch.append(c)
                         else:
@@ -1941,14 +1936,14 @@ class SQLCompiler(Compiled):
                                     )
                 elif c.primary_key and implicit_returning:
                     self.returning.append(c)
-                    value = self.process(value.self_group())
+                    value = self.process(value.self_group(), **kw)
                 elif implicit_return_defaults and \
                     c in implicit_return_defaults:
                     self.returning.append(c)
-                    value = self.process(value.self_group())
+                    value = self.process(value.self_group(), **kw)
                 else:
                     self.postfetch.append(c)
-                    value = self.process(value.self_group())
+                    value = self.process(value.self_group(), **kw)
                 values.append((c, value))
 
             elif self.isinsert:
@@ -1966,13 +1961,13 @@ class SQLCompiler(Compiled):
                                 if self.dialect.supports_sequences and \
                                     (not c.default.optional or \
                                     not self.dialect.sequences_optional):
-                                    proc = self.process(c.default)
+                                    proc = self.process(c.default, **kw)
                                     values.append((c, proc))
                                 self.returning.append(c)
                             elif c.default.is_clause_element:
                                 values.append(
                                     (c,
-                                    self.process(c.default.arg.self_group()))
+                                    self.process(c.default.arg.self_group(), **kw))
                                 )
                                 self.returning.append(c)
                             else:
@@ -2000,7 +1995,7 @@ class SQLCompiler(Compiled):
                         if self.dialect.supports_sequences and \
                             (not c.default.optional or \
                             not self.dialect.sequences_optional):
-                            proc = self.process(c.default)
+                            proc = self.process(c.default, **kw)
                             values.append((c, proc))
                             if implicit_return_defaults and \
                                 c in implicit_return_defaults:
@@ -2009,7 +2004,7 @@ class SQLCompiler(Compiled):
                                 self.postfetch.append(c)
                     elif c.default.is_clause_element:
                         values.append(
-                            (c, self.process(c.default.arg.self_group()))
+                            (c, self.process(c.default.arg.self_group(), **kw))
                         )
 
                         if implicit_return_defaults and \
@@ -2037,7 +2032,7 @@ class SQLCompiler(Compiled):
                 if c.onupdate is not None and not c.onupdate.is_sequence:
                     if c.onupdate.is_clause_element:
                         values.append(
-                            (c, self.process(c.onupdate.arg.self_group()))
+                            (c, self.process(c.onupdate.arg.self_group(), **kw))
                         )
                         if implicit_return_defaults and \
                             c in implicit_return_defaults:
index 1d7dacb915820ce210b18273114de2abf86bcc44..01d9181200d12a1f63ee6e603a150e9648e015d5 100644 (file)
@@ -154,6 +154,12 @@ class String(Concatenable, TypeEngine):
         self.unicode_error = unicode_error
         self._warn_on_bytestring = _warn_on_bytestring
 
+    def literal_processor(self, dialect):
+        def process(value):
+            value = value.replace("'", "''")
+            return "'%s'" % value
+        return process
+
     def bind_processor(self, dialect):
         if self.convert_unicode or dialect.convert_unicode:
             if dialect.supports_unicode_binds and \
@@ -345,6 +351,11 @@ class Integer(_DateAffinity, TypeEngine):
     def python_type(self):
         return int
 
+    def literal_processor(self, dialect):
+        def process(value):
+            return str(value)
+        return process
+
     @util.memoized_property
     def _expression_adaptations(self):
         # TODO: need a dictionary object that will
@@ -481,6 +492,11 @@ class Numeric(_DateAffinity, TypeEngine):
     def get_dbapi_type(self, dbapi):
         return dbapi.NUMBER
 
+    def literal_processor(self, dialect):
+        def process(value):
+            return str(value)
+        return process
+
     @property
     def python_type(self):
         if self.asdecimal:
@@ -728,6 +744,12 @@ class _Binary(TypeEngine):
     def __init__(self, length=None):
         self.length = length
 
+    def literal_processor(self, dialect):
+        def process(value):
+            value = value.decode(self.dialect.encoding).replace("'", "''")
+            return "'%s'" % value
+        return process
+
     @property
     def python_type(self):
         return util.binary_type
@@ -1500,6 +1522,11 @@ class NullType(TypeEngine):
 
     _isnull = True
 
+    def literal_processor(self, dialect):
+        def process(value):
+            return "NULL"
+        return process
+
     class Comparator(TypeEngine.Comparator):
         def _adapt_expression(self, op, other_comparator):
             if isinstance(other_comparator, NullType.Comparator) or \
index 83b8ec570672e3d05503aaddfc4a284931095f74..698e17472ba6331a105456072f1e8efb1a000354 100644 (file)
@@ -75,6 +75,19 @@ class TypeEngine(Visitable):
     def copy_value(self, value):
         return value
 
+    def literal_processor(self, dialect):
+        """Return a conversion function for processing literal values that are
+        to be rendered directly without using binds.
+
+        This function is used when the compiler makes use of the
+        "literal_binds" flag, typically used in DDL generation as well
+        as in certain scenarios where backends don't accept bound parameters.
+
+        .. versionadded:: 0.9.0
+
+        """
+        return None
+
     def bind_processor(self, dialect):
         """Return a conversion function for processing bind values.
 
@@ -265,6 +278,16 @@ class TypeEngine(Visitable):
         except KeyError:
             return self._dialect_info(dialect)['impl']
 
+
+    def _cached_literal_processor(self, dialect):
+        """Return a dialect-specific literal processor for this type."""
+        try:
+            return dialect._type_memos[self]['literal']
+        except KeyError:
+            d = self._dialect_info(dialect)
+            d['literal'] = lp = d['impl'].literal_processor(dialect)
+            return lp
+
     def _cached_bind_processor(self, dialect):
         """Return a dialect-specific bind processor for this type."""
 
@@ -673,6 +696,22 @@ class TypeDecorator(TypeEngine):
         implementation."""
         return getattr(self.impl, key)
 
+    def process_literal_param(self, value, dialect):
+        """Receive a literal parameter value to be rendered inline within
+        a statement.
+
+        This method is used when the compiler renders a
+        literal value without using binds, typically within DDL
+        such as in the "server default" of a column or an expression
+        within a CHECK constraint.
+
+        The returned string will be rendered into the output string.
+
+        .. versionadded:: 0.9.0
+
+        """
+        raise NotImplementedError()
+
     def process_bind_param(self, value, dialect):
         """Receive a bound parameter value to be converted.
 
@@ -737,6 +776,40 @@ class TypeDecorator(TypeEngine):
         return self.__class__.process_bind_param.__code__ \
             is not TypeDecorator.process_bind_param.__code__
 
+    @util.memoized_property
+    def _has_literal_processor(self):
+        """memoized boolean, check if process_literal_param is implemented.
+
+
+        """
+
+        return self.__class__.process_literal_param.__code__ \
+            is not TypeDecorator.process_literal_param.__code__
+
+    def literal_processor(self, dialect):
+        """Provide a literal processing function for the given
+        :class:`.Dialect`.
+
+        Subclasses here will typically override :meth:`.TypeDecorator.process_literal_param`
+        instead of this method directly.
+
+        .. versionadded:: 0.9.0
+
+        """
+        if self._has_literal_processor:
+            process_param = self.process_literal_param
+            impl_processor = self.impl.literal_processor(dialect)
+            if impl_processor:
+                def process(value):
+                    return impl_processor(process_param(value, dialect))
+            else:
+                def process(value):
+                    return process_param(value, dialect)
+
+            return process
+        else:
+            return self.impl.literal_processor(dialect)
+
     def bind_processor(self, dialect):
         """Provide a bound value processing function for the
         given :class:`.Dialect`.
index 062fffb1819a48a1a24177aed42f01e3f3cfdbb8..0d43d0e0456edf81263dda8b8c9f29463557c123 100644 (file)
@@ -187,7 +187,8 @@ class AssertsCompiledSQL(object):
                         checkparams=None, dialect=None,
                         checkpositional=None,
                         use_default_dialect=False,
-                        allow_dialect_select=False):
+                        allow_dialect_select=False,
+                        literal_binds=False):
         if use_default_dialect:
             dialect = default.DefaultDialect()
         elif allow_dialect_select:
@@ -205,14 +206,22 @@ class AssertsCompiledSQL(object):
 
 
         kw = {}
+        compile_kwargs = {}
+
         if params is not None:
             kw['column_keys'] = list(params)
 
+        if literal_binds:
+            compile_kwargs['literal_binds'] = True
+
         if isinstance(clause, orm.Query):
             context = clause._compile_context()
             context.statement.use_labels = True
             clause = context.statement
 
+        if compile_kwargs:
+            kw['compile_kwargs'] = compile_kwargs
+
         c = clause.compile(dialect=dialect, **kw)
 
         param_str = repr(getattr(c, 'params', {}))
index d301dc69f1e179f45b36a2c431106bd3fe460104..7dc6ea40beb8780e22865fa73156ae51e1cf5e04 100644 (file)
@@ -295,6 +295,15 @@ class SuiteRequirements(Requirements):
         """Target driver must support some degree of non-ascii symbol names."""
         return exclusions.closed()
 
+    @property
+    def datetime_literals(self):
+        """target dialect supports rendering of a date, time, or datetime as a
+        literal string, e.g. via the TypeEngine.literal_processor() method.
+
+        """
+
+        return exclusions.closed()
+
     @property
     def datetime(self):
         """target dialect supports representation of Python
index 0de462eb7e7511ec5d4e3cff1d0b14d78e056b06..5523523aa1562ba1ec6cd1f167a618052fb060b1 100644 (file)
@@ -5,7 +5,7 @@ from ..assertions import eq_
 from ..config import requirements
 from sqlalchemy import Integer, Unicode, UnicodeText, select
 from sqlalchemy import Date, DateTime, Time, MetaData, String, \
-            Text, Numeric, Float
+            Text, Numeric, Float, literal
 from ..schema import Table, Column
 from ... import testing
 import decimal
@@ -13,7 +13,31 @@ import datetime
 from ...util import u
 from ... import util
 
-class _UnicodeFixture(object):
+
+class _LiteralRoundTripFixture(object):
+    @testing.provide_metadata
+    def _literal_round_trip(self, type_, input_, output):
+        """test literal rendering """
+
+        # for literal, we test the literal render in an INSERT
+        # into a typed column.  we can then SELECT it back as it's
+        # official type; ideally we'd be able to use CAST here
+        # but MySQL in particular can't CAST fully
+        t = Table('t', self.metadata, Column('x', type_))
+        t.create()
+
+        for value in input_:
+            ins = t.insert().values(x=literal(value)).compile(
+                            dialect=testing.db.dialect,
+                            compile_kwargs=dict(literal_binds=True)
+                        )
+            testing.db.execute(ins)
+
+        for row in t.select().execute():
+            assert row[0] in output
+
+
+class _UnicodeFixture(_LiteralRoundTripFixture):
     __requires__ = 'unicode_data',
 
     data = u("Alors vous imaginez ma surprise, au lever du jour, "\
@@ -87,6 +111,9 @@ class _UnicodeFixture(object):
                 ).first()
         eq_(row, (u(''),))
 
+    def test_literal(self):
+        self._literal_round_trip(self.datatype, [self.data], [self.data])
+
 
 class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
     __requires__ = 'unicode_data',
@@ -107,7 +134,7 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
     def test_empty_strings_text(self):
         self._test_empty_strings()
 
-class TextTest(fixtures.TablesTest):
+class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('text_table', metadata,
@@ -140,8 +167,18 @@ class TextTest(fixtures.TablesTest):
                 ).first()
         eq_(row, ('',))
 
+    def test_literal(self):
+        self._literal_round_trip(Text, ["some text"], ["some text"])
+
+    def test_literal_quoting(self):
+        data = '''some 'text' hey "hi there" that's text'''
+        self._literal_round_trip(Text, [data], [data])
+
+    def test_literal_backslashes(self):
+        data = r'backslash one \ backslash two \\ end'
+        self._literal_round_trip(Text, [data], [data])
 
-class StringTest(fixtures.TestBase):
+class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
     @requirements.unbounded_varchar
     def test_nolength_string(self):
         metadata = MetaData()
@@ -152,8 +189,19 @@ class StringTest(fixtures.TestBase):
         foo.create(config.db)
         foo.drop(config.db)
 
+    def test_literal(self):
+        self._literal_round_trip(String(40), ["some text"], ["some text"])
 
-class _DateFixture(object):
+    def test_literal_quoting(self):
+        data = '''some 'text' hey "hi there" that's text'''
+        self._literal_round_trip(String(40), [data], [data])
+
+    def test_literal_backslashes(self):
+        data = r'backslash one \ backslash two \\ end'
+        self._literal_round_trip(Text, [data], [data])
+
+
+class _DateFixture(_LiteralRoundTripFixture):
     compare = None
 
     @classmethod
@@ -198,6 +246,12 @@ class _DateFixture(object):
                 ).first()
         eq_(row, (None,))
 
+    @testing.requires.datetime_literals
+    def test_literal(self):
+        compare = self.compare or self.data
+        self._literal_round_trip(self.datatype, [self.data], [compare])
+
+
 
 class DateTimeTest(_DateFixture, fixtures.TablesTest):
     __requires__ = 'datetime',
@@ -247,7 +301,12 @@ class DateHistoricTest(_DateFixture, fixtures.TablesTest):
     datatype = Date
     data = datetime.date(1727, 4, 1)
 
-class NumericTest(fixtures.TestBase):
+
+class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
+    def test_literal(self):
+        self._literal_round_trip(Integer, [5], [5])
+
+class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
 
     @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
     @testing.provide_metadata
@@ -269,6 +328,30 @@ class NumericTest(fixtures.TestBase):
                 [str(x) for x in output],
             )
 
+
+    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+    def test_render_literal_numeric(self):
+        self._literal_round_trip(
+            Numeric(precision=8, scale=4),
+            [15.7563, decimal.Decimal("15.7563")],
+            [decimal.Decimal("15.7563")],
+        )
+
+    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+    def test_render_literal_numeric_asfloat(self):
+        self._literal_round_trip(
+            Numeric(precision=8, scale=4, asdecimal=False),
+            [15.7563, decimal.Decimal("15.7563")],
+            [15.7563],
+        )
+
+    def test_render_literal_float(self):
+        self._literal_round_trip(
+            Float(4),
+            [15.7563, decimal.Decimal("15.7563")],
+            [15.7563],
+        )
+
     def test_numeric_as_decimal(self):
         self._do_test(
             Numeric(precision=8, scale=4),
@@ -291,6 +374,7 @@ class NumericTest(fixtures.TestBase):
             [decimal.Decimal("15.7563"), None],
         )
 
+
     def test_float_as_float(self):
         self._do_test(
             Float(precision=8),
@@ -299,6 +383,7 @@ class NumericTest(fixtures.TestBase):
             filter_=lambda n: n is not None and round(n, 5) or None
         )
 
+
     @testing.requires.precision_numerics_general
     def test_precision_decimal(self):
         numbers = set([
@@ -313,6 +398,7 @@ class NumericTest(fixtures.TestBase):
             numbers,
         )
 
+
     @testing.requires.precision_numerics_enotation_large
     def test_enotation_decimal(self):
         """test exceedingly small decimals.
@@ -342,6 +428,7 @@ class NumericTest(fixtures.TestBase):
             numbers
         )
 
+
     @testing.requires.precision_numerics_enotation_large
     def test_enotation_decimal_large(self):
         """test exceedingly large decimals.
@@ -389,7 +476,7 @@ class NumericTest(fixtures.TestBase):
 
 __all__ = ('UnicodeVarcharTest', 'UnicodeTextTest',
             'DateTest', 'DateTimeTest', 'TextTest',
-            'NumericTest',
+            'NumericTest', 'IntegerTest',
             'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest',
             'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest',
             'DateHistoricTest', 'StringTest')
index 3a2154fc574dabc3857abef0a0161b8025bd5ac9..e64b67fcf965c91484db70b6138f9be3d7a44e78 100644 (file)
@@ -62,6 +62,7 @@ from .sql.sqltypes import (
     SMALLINT,
     SmallInteger,
     String,
+    STRINGTYPE,
     TEXT,
     TIME,
     TIMESTAMP,
index cd59e524925d0d803c2e26ee64c8e75b4ceaa762..e7728d6e0abee62132ffff13af2e4aceb7ca45c8 100644 (file)
@@ -431,6 +431,15 @@ class DefaultRequirements(SuiteRequirements):
         return fails_on_everything_except('postgresql', 'oracle', 'mssql',
                     'sybase')
 
+    @property
+    def datetime_literals(self):
+        """target dialect supports rendering of a date, time, or datetime as a
+        literal string, e.g. via the TypeEngine.literal_processor() method.
+
+        """
+
+        return fails_on_everything_except("sqlite")
+
     @property
     def datetime(self):
         """target dialect supports representation of Python
index d122aef6a8e1ae21a54a82d22df56f5b2fadb9e3..a2791ee29a6ce0711a7ab2e0e8a6fe7432fca64b 100644 (file)
@@ -273,6 +273,20 @@ class UserDefinedTest(fixtures.TablesTest, AssertsCompiledSQL):
             for col in row[3], row[4]:
                 assert isinstance(col, util.text_type)
 
+    def test_typedecorator_literal_render(self):
+        class MyType(types.TypeDecorator):
+            impl = String
+
+            def process_literal_param(self, value, dialect):
+                return "HI->%s<-THERE" % value
+
+        self.assert_compile(
+            select([literal("test", MyType)]),
+            "SELECT 'HI->test<-THERE' AS anon_1",
+            dialect='default',
+            literal_binds=True
+        )
+
     def test_typedecorator_impl(self):
         for impl_, exp, kw in [
             (Float, "FLOAT", {}),