return AUTOCOMMIT_RE.match(statement)
class MySQLCompiler(compiler.SQLCompiler):
- operators = compiler.SQLCompiler.operators.copy()
- operators.update({
- sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
- sql_operators.mod: '%%',
- sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
- })
- functions = compiler.SQLCompiler.functions.copy()
- functions.update ({
- sql_functions.random: 'rand%(expr)s',
- "utc_timestamp":"UTC_TIMESTAMP"
+ operators = util.update_copy(
+ compiler.SQLCompiler.operators,
+ {
+ sql_operators.concat_op: lambda x, y: "concat(%s, %s)" % (x, y),
+ sql_operators.match_op: lambda x, y: "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (x, y)
+ }
+ )
+
+ functions = util.update_copy(
+ compiler.SQLCompiler.functions,
+ {
+ sql_functions.random: 'rand%(expr)s',
+ "utc_timestamp":"UTC_TIMESTAMP"
})
def visit_typeclause(self, typeclause):
from sqlalchemy.dialects.mysql.base import MySQLDialect, MySQLExecutionContext, MySQLCompiler
from sqlalchemy.engine import base as engine_base, default
+from sqlalchemy.sql import operators as sql_operators
+
from sqlalchemy import exc, log, schema, sql, util
import re
return cursor.lastrowid
class MySQL_mysqldbCompiler(MySQLCompiler):
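+    # mysqldb uses a "format"-style paramstyle (see default_paramstyle below),
+    # so the modulo operator is rendered doubled as '%%' so that a literal '%'
+    # survives the DB-API's own parameter interpolation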
+ operators = util.update_copy(
+ MySQLCompiler.operators,
+ {
+ sql_operators.mod: '%%',
+ }
+ )
+
def post_process_text(self, text):
if '%%' in text:
util.warn("The SQLAlchemy mysql+mysqldb dialect now automatically escapes '%' in text() expressions to '%%'.")
supports_unicode_statements = False
default_paramstyle = 'format'
execution_ctx_cls = MySQL_mysqldbExecutionContext
- sql_compiler = MySQL_mysqldbCompiler
+ statement_compiler = MySQL_mysqldbCompiler
@classmethod
def dbapi(cls):
return tuple(version)
def _extract_error_code(self, exception):
- return exception.orig.args[0]
+ try:
+ return exception.orig.args[0]
+ except AttributeError:
+ return None
@engine_base.connection_memoize(('mysql', 'charset'))
def _detect_charset(self, connection):
return process
+colspecs = {
+    sqltypes.Interval: PGInterval
+}
+
+ischema_names = {
+ 'integer' : sqltypes.Integer,
+ 'bigint' : PGBigInteger,
+ 'smallint' : sqltypes.SmallInteger,
+ 'character varying' : sqltypes.String,
+ 'character' : sqltypes.CHAR,
+ 'text' : sqltypes.Text,
+ 'numeric' : sqltypes.Numeric,
+ 'float' : sqltypes.Float,
+ 'real' : sqltypes.Float,
+ 'inet': PGInet,
+ 'cidr': PGCidr,
+ 'macaddr': PGMacAddr,
+ 'double precision' : sqltypes.Float,
+ 'timestamp' : sqltypes.DateTime,
+ 'timestamp with time zone' : sqltypes.DateTime,
+ 'timestamp without time zone' : sqltypes.DateTime,
+ 'time with time zone' : sqltypes.Time,
+ 'time without time zone' : sqltypes.Time,
+ 'date' : sqltypes.Date,
+ 'time': sqltypes.Time,
+ 'bytea' : sqltypes.Binary,
+ 'boolean' : sqltypes.Boolean,
+    'interval': PGInterval,
+}
class PGCompiler(compiler.SQLCompiler):
- operators = compiler.SQLCompiler.operators.copy()
- operators.update(
+
+ operators = util.update_copy(
+ compiler.SQLCompiler.operators,
{
sql_operators.mod : '%%',
+
sql_operators.ilike_op: lambda x, y, escape=None: '%s ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
sql_operators.notilike_op: lambda x, y, escape=None: '%s NOT ILIKE %s' % (x, y) + (escape and ' ESCAPE \'%s\'' % escape or ''),
sql_operators.match_op: lambda x, y: '%s @@ to_tsquery(%s)' % (x, y),
}
)
- functions = compiler.SQLCompiler.functions.copy()
- functions.update (
+ functions = util.update_copy(
+ compiler.SQLCompiler.functions,
{
'TIMESTAMP':lambda x:'TIMESTAMP %s' % x,
}
)
+ def post_process_text(self, text):
+ if '%%' in text:
+ util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
+ return text.replace('%', '%%')
+
def visit_sequence(self, seq):
if seq.optional:
return None
else:
return "nextval('%s')" % self.preparer.format_sequence(seq)
- def post_process_text(self, text):
- if '%%' in text:
- util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
- return text.replace('%', '%%')
-
def limit_clause(self, select):
text = ""
if select._limit is not None:
supports_default_values = True
supports_empty_insert = False
default_paramstyle = 'pyformat'
-
+ ischema_names = ischema_names
+ colspecs = colspecs
+
statement_compiler = PGCompiler
ddl_compiler = PGDDLCompiler
type_compiler = PGTypeCompiler
raise AssertionError("Could not determine version from string '%s'" % v)
return tuple([int(x) for x in m.group(1, 2, 3)])
+ def type_descriptor(self, typeobj):
+ return sqltypes.adapt_type(typeobj, self.colspecs)
+
def reflecttable(self, connection, table, include_columns):
preparer = self.identifier_preparer
if table.schema is not None:
--- /dev/null
+"""Support for the PostgreSQL database via the pg8000.
+
+Connecting
+----------
+
+URLs are of the form `postgres+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.
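+
+For example, a minimal connection might look like the following (the host,
+credentials and database name here are purely illustrative)::
+
+    from sqlalchemy import create_engine
+    engine = create_engine("postgres+pg8000://scott:tiger@localhost:5432/test")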
+
+Unicode
+-------
+
+Unicode data containing non-ascii characters does not appear to be supported yet.  Non-ascii
+schema identifiers, however, *are* supported, provided client_encoding=utf8 is set in the
+postgresql.conf file.
+
+Interval
+--------
+
+Passing data to or from the Interval type is not yet supported.
+
+"""
+
+import decimal, random, re, string
+
+from sqlalchemy import sql, schema, exc, util
+from sqlalchemy.engine import base, default
+from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import operators as sql_operators
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \
+ PGBigInteger, PGInterval
+
+class PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect):
+ if self.asdecimal:
+ return None
+ else:
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return float(value)
+ else:
+ return value
+ return process
+
+class Postgres_pg8000ExecutionContext(default.DefaultExecutionContext):
+ pass
+
+class Postgres_pg8000(PGDialect):
+ driver = 'pg8000'
+
+ supports_unicode_statements = False #True
+
+    # this one doesn't matter; non-ascii data can't be passed through
+    # anyway, pending further investigation
+ supports_unicode_binds = False #True
+
+ default_paramstyle = 'format'
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = Postgres_pg8000ExecutionContext
+
+ @classmethod
+ def dbapi(cls):
+ return __import__('pg8000').dbapi
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username='user')
+ if 'port' in opts:
+ opts['port'] = int(opts['port'])
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e):
+ return "connection is closed" in e
+
+dialect = Postgres_pg8000
from sqlalchemy.sql import compiler, expression
from sqlalchemy.sql import operators as sql_operators
from sqlalchemy import types as sqltypes
-from sqlalchemy.dialects.postgres.base import PGDialect, PGInet, PGCidr, PGMacAddr, PGArray, \
- PGBigInteger, PGInterval
+from sqlalchemy.dialects.postgres.base import PGDialect, PGCompiler, PGInet, PGCidr, PGMacAddr, PGArray, \
+ PGBigInteger, PGInterval, colspecs
class PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
return process
-colspecs = {
+colspecs = PGDialect.colspecs.copy()
+colspecs.update({
sqltypes.Numeric : PGNumeric,
sqltypes.Float: sqltypes.Float, # prevents PGNumeric from being used
-}
-
-ischema_names = {
- 'integer' : sqltypes.Integer,
- 'bigint' : PGBigInteger,
- 'smallint' : sqltypes.SmallInteger,
- 'character varying' : sqltypes.String,
- 'character' : sqltypes.CHAR,
- 'text' : sqltypes.Text,
- 'numeric' : PGNumeric,
- 'float' : sqltypes.Float,
- 'real' : sqltypes.Float,
- 'inet': PGInet,
- 'cidr': PGCidr,
- 'macaddr': PGMacAddr,
- 'double precision' : sqltypes.Float,
- 'timestamp' : sqltypes.DateTime,
- 'timestamp with time zone' : sqltypes.DateTime,
- 'timestamp without time zone' : sqltypes.DateTime,
- 'time with time zone' : sqltypes.Time,
- 'time without time zone' : sqltypes.Time,
- 'date' : sqltypes.Date,
- 'time': sqltypes.Time,
- 'bytea' : sqltypes.Binary,
- 'boolean' : sqltypes.Boolean,
- 'interval':PGInterval,
-}
+})
# TODO: filter out 'FOR UPDATE' statements
SERVER_SIDE_CURSOR_RE = re.compile(
else:
return base.ResultProxy(self)
+class Postgres_psycopg2Compiler(PGCompiler):
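+    # psycopg2 uses the "pyformat" paramstyle, so the modulo operator is
+    # rendered as '%%' so that a literal '%' survives the DB-API's own
+    # parameter interpolation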
+ operators = util.update_copy(
+ PGCompiler.operators,
+ {
+ sql_operators.mod : '%%',
+ }
+ )
+
+ def post_process_text(self, text):
+ if '%%' in text:
+ util.warn("The SQLAlchemy postgres dialect now automatically escapes '%' in text() expressions to '%%'.")
+ return text.replace('%', '%%')
+
class Postgres_psycopg2(PGDialect):
driver = 'psycopg2'
supports_unicode_statements = False
default_paramstyle = 'pyformat'
supports_sane_multi_rowcount = False
execution_ctx_cls = Postgres_psycopg2ExecutionContext
- ischema_names = ischema_names
+ statement_compiler = Postgres_psycopg2Compiler
def __init__(self, server_side_cursors=False, **kwargs):
PGDialect.__init__(self, **kwargs)
return value and True or False
return process
+colspecs = {
+ sqltypes.Boolean: SLBoolean,
+ sqltypes.Date: SLDate,
+ sqltypes.DateTime: SLDateTime,
+ sqltypes.Float: SLFloat,
+ sqltypes.Numeric: SLNumeric,
+ sqltypes.Time: SLTime,
+}
+
+ischema_names = {
+ 'BLOB': sqltypes.Binary,
+ 'BOOL': sqltypes.Boolean,
+ 'BOOLEAN': sqltypes.Boolean,
+ 'CHAR': sqltypes.CHAR,
+ 'DATE': sqltypes.Date,
+ 'DATETIME': sqltypes.DateTime,
+ 'DECIMAL': sqltypes.Numeric,
+ 'FLOAT': sqltypes.Numeric,
+ 'INT': sqltypes.Integer,
+ 'INTEGER': sqltypes.Integer,
+ 'NUMERIC': sqltypes.Numeric,
+ 'REAL': sqltypes.Numeric,
+ 'SMALLINT': sqltypes.SmallInteger,
+ 'TEXT': sqltypes.Text,
+ 'TIME': sqltypes.Time,
+ 'TIMESTAMP': sqltypes.DateTime,
+ 'VARCHAR': sqltypes.String,
+}
+
+
class SQLiteCompiler(compiler.SQLCompiler):
functions = compiler.SQLCompiler.functions.copy()
name = 'sqlite'
supports_alter = False
supports_unicode_statements = True
+ supports_unicode_binds = True
supports_default_values = True
supports_empty_insert = False
supports_cast = True
ddl_compiler = SQLiteDDLCompiler
type_compiler = SQLiteTypeCompiler
preparer = SQLiteIdentifierPreparer
+ ischema_names = ischema_names
+
+ def type_descriptor(self, typeobj):
+ return sqltypes.adapt_type(typeobj, colspecs)
def table_names(self, connection, schema):
if schema is not None:
"""
-from sqlalchemy.dialects.sqlite.base import SLNumeric, SLFloat, SQLiteDialect, SLBoolean, SLDate, SLDateTime, SLTime
+from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy import schema, exc, pool
from sqlalchemy.engine import default
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from types import NoneType
-class SLUnicodeMixin(object):
- def bind_processor(self, dialect):
- if self.convert_unicode or dialect.convert_unicode:
- if self.assert_unicode is None:
- assert_unicode = dialect.assert_unicode
- else:
- assert_unicode = self.assert_unicode
-
- if not assert_unicode:
- return None
-
- def process(value):
- if not isinstance(value, (unicode, NoneType)):
- if assert_unicode == 'warn':
- util.warn("Unicode type received non-unicode bind "
- "param value %r" % value)
- return value
- else:
- raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
- else:
- return value
- return process
- else:
- return None
-
- def result_processor(self, dialect):
- return None
-
-class SLText(SLUnicodeMixin, sqltypes.Text):
- pass
-
-class SLString(SLUnicodeMixin, sqltypes.String):
- pass
-
-class SLChar(SLUnicodeMixin, sqltypes.CHAR):
- pass
-
-
-colspecs = {
- sqltypes.Boolean: SLBoolean,
- sqltypes.CHAR: SLChar,
- sqltypes.Date: SLDate,
- sqltypes.DateTime: SLDateTime,
- sqltypes.Float: SLFloat,
- sqltypes.NCHAR: SLChar,
- sqltypes.Numeric: SLNumeric,
- sqltypes.String: SLString,
- sqltypes.Text: SLText,
- sqltypes.Time: SLTime,
-}
-
-ischema_names = {
- 'BLOB': sqltypes.Binary,
- 'BOOL': SLBoolean,
- 'BOOLEAN': SLBoolean,
- 'CHAR': SLChar,
- 'DATE': SLDate,
- 'DATETIME': SLDateTime,
- 'DECIMAL': SLNumeric,
- 'FLOAT': SLNumeric,
- 'INT': sqltypes.Integer,
- 'INTEGER': sqltypes.Integer,
- 'NUMERIC': SLNumeric,
- 'REAL': SLNumeric,
- 'SMALLINT': sqltypes.SmallInteger,
- 'TEXT': SLText,
- 'TIME': SLTime,
- 'TIMESTAMP': SLDateTime,
- 'VARCHAR': SLString,
-}
-
-
class SQLite_pysqliteExecutionContext(default.DefaultExecutionContext):
def post_exec(self):
if self.isinsert and not self.executemany:
poolclass = pool.SingletonThreadPool
execution_ctx_cls = SQLite_pysqliteExecutionContext
driver = 'pysqlite'
- ischema_names = ischema_names
def __init__(self, **kwargs):
SQLiteDialect.__init__(self, **kwargs)
return ([filename], opts)
- def type_descriptor(self, typeobj):
- return sqltypes.adapt_type(typeobj, colspecs)
-
def is_disconnect(self, e):
return isinstance(e, self.dbapi.ProgrammingError) and "Cannot operate on a closed database." in str(e)
supports_unicode_statements
Indicate whether the DB-API can receive SQL statements as Python unicode strings
+ supports_unicode_binds
+ Indicate whether the DB-API can receive string bind parameters as Python unicode strings
+
supports_sane_rowcount
Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements.
supports_sequences = False
sequences_optional = False
supports_unicode_statements = False
+ supports_unicode_binds = False
+
max_identifier_length = 9999
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
raise NotImplementedError("Can't generate DDL for the null type")
def visit_type_decorator(self, type_):
- return self.process(type_.dialect_impl(self.dialect).impl)
+ return self.process(type_.type_engine(self.dialect))
def visit_user_defined(self, type_):
return type_.get_col_spec()
"""
return op
+ def get_search_list(self):
+ """return a list of classes to test for a match
+ when adapting this type to a dialect-specific type.
+
+ """
+
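+        # i.e. every class in this type's MRO except 'object' itself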
+ return self.__class__.__mro__[0:-1]
+
def __repr__(self):
return "%s(%s)" % (
self.__class__.__name__,
def adapt(self, cls):
return cls()
- def get_search_list(self):
- """return a list of classes to test for a match
- when adapting this type to a dialect-specific type.
-
- """
-
- return self.__class__.__mro__[0:-1]
-
class UserDefinedType(TypeEngine):
"""Base for user defined types.
raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
self.impl = self.__class__.impl(*args, **kwargs)
- def dialect_impl(self, dialect, **kwargs):
+ def dialect_impl(self, dialect):
try:
return self._impl_dict[dialect]
except AttributeError:
except KeyError:
pass
+        # adapt the TypeDecorator itself first, in case the dialect
+        # maps this TypeDecorator to one of its native types
+        # (e.g. PGInterval on postgres)
+ adapted = dialect.type_descriptor(self)
+ if adapted is not self:
+ self._impl_dict[dialect] = adapted
+ return adapted
+
+ # otherwise adapt the impl type, link
+ # to a copy of this TypeDecorator and return
+ # that.
typedesc = self.load_dialect_impl(dialect)
tt = self.copy()
if not isinstance(tt, self.__class__):
self._impl_dict[dialect] = tt
return tt
+ def type_engine(self, dialect):
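+        """Return a dialect-specific TypeEngine instance for this TypeDecorator.
+
+        If the dialect-level implementation is itself a TypeDecorator,
+        its underlying ``impl`` is returned instead.
+
+        """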
+ impl = self.dialect_impl(dialect)
+ if not isinstance(impl, TypeDecorator):
+ return impl
+ else:
+ return impl.impl
+
def load_dialect_impl(self, dialect):
"""Loads the dialect-specific implementation of this type.
by default calls dialect.type_descriptor(self.impl), but
can be overridden to provide different behavior.
+
"""
-
if isinstance(self.impl, TypeDecorator):
return self.impl.dialect_impl(dialect)
else:
assert_unicode = dialect.assert_unicode
else:
assert_unicode = self.assert_unicode
- def process(value):
- if isinstance(value, unicode):
- return value.encode(dialect.encoding)
- elif assert_unicode and not isinstance(value, (unicode, NoneType)):
- if assert_unicode == 'warn':
- util.warn("Unicode type received non-unicode bind "
- "param value %r" % value)
+
+ if dialect.supports_unicode_binds and assert_unicode:
+ def process(value):
+ if not isinstance(value, (unicode, NoneType)):
+ if assert_unicode == 'warn':
+ util.warn("Unicode type received non-unicode bind "
+ "param value %r" % value)
+ return value
+ else:
+ raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+ else:
return value
+ elif dialect.supports_unicode_binds:
+ return None
+ else:
+ def process(value):
+ if isinstance(value, unicode):
+ return value.encode(dialect.encoding)
+ elif assert_unicode and not isinstance(value, (unicode, NoneType)):
+ if assert_unicode == 'warn':
+ util.warn("Unicode type received non-unicode bind "
+ "param value %r" % value)
+ return value
+ else:
+ raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
else:
- raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
- else:
- return value
+ return value
return process
else:
return None
"""
__visit_name__ = 'text'
-
- def dialect_impl(self, dialect, **kwargs):
- return TypeEngine.dialect_impl(self, dialect, **kwargs)
class Unicode(String):
"""A variable length Unicode string.
"""
- impl = TypeEngine
-
- def __init__(self):
- super(Interval, self).__init__()
- import sqlalchemy.dialects.postgres.base as pg
- self.__supported = {pg.PGDialect:pg.PGInterval}
- del pg
-
- def load_dialect_impl(self, dialect):
- if dialect.__class__ in self.__supported:
- return self.__supported[dialect.__class__]()
- else:
- return dialect.type_descriptor(DateTime)
+ impl = DateTime
def process_bind_param(self, value, dialect):
- if dialect.__class__ in self.__supported:
- return value
- else:
- if value is None:
- return None
- return dt.datetime.utcfromtimestamp(0) + value
+ if value is None:
+ return None
+ return dt.datetime.utcfromtimestamp(0) + value
def process_result_value(self, value, dialect):
- if dialect.__class__ in self.__supported:
- return value
- else:
- if value is None:
- return None
- return value - dt.datetime.utcfromtimestamp(0)
+ if value is None:
+ return None
+ return value - dt.datetime.utcfromtimestamp(0)
class FLOAT(Float):
"""The SQL FLOAT type."""
def decode_slice(slc):
return (slc.start, slc.stop, slc.step)
+def update_copy(d, _new=None, **kw):
+ """Copy the given dict and update with the given values."""
+
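+    # e.g. update_copy({'a': 1}, {'b': 2}, c=3) -> {'a': 1, 'b': 2, 'c': 3},
+    # leaving the original dict unmodified (values here are illustrative)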
+ d = d.copy()
+ if _new:
+ d.update(_new)
+ d.update(**kw)
+ return d
+
def flatten_iterator(x):
"""Given an iterator of which further sub-elements may also be
iterators, flatten the sub-elements into a single iterator.
meta.drop_all()
engine.dispose()
- @testing.fails_on('mysql', 'FIXME: unknown')
+ @testing.fails_on('+mysqldb', "Buffers the result set and doesn't check for connection close")
+ @testing.fails_on('+pg8000', "Buffers the result set and doesn't check for connection close")
def test_invalidate_on_results(self):
conn = engine.connect()
engine.test_shutdown()
try:
- result.fetchone()
+ print "ghost result: %r" % result.fetchone()
assert False
except tsa.exc.DBAPIError, e:
if not e.connection_invalidated:
eq_(expr.execute().fetchall(), result)
+ @testing.fails_on("+pg8000", "can't interpret result column from '%%'")
@testing.emits_warning('.*now automatically escapes.*')
def test_percents_in_text(self):
for expr, result in (
assert unicode_table.c.unicode_varchar.type.length == 250
rawdata = 'Alors vous imaginez ma surprise, au lever du jour, quand une dr\xc3\xb4le de petit voix m\xe2\x80\x99a r\xc3\xa9veill\xc3\xa9. Elle disait: \xc2\xab S\xe2\x80\x99il vous pla\xc3\xaet\xe2\x80\xa6 dessine-moi un mouton! \xc2\xbb\n'
unicodedata = rawdata.decode('utf-8')
- if testing.against('sqlite'):
+
+        if testing.against('sqlite', '>2.4'):
rawdata = "something"
unicode_table.insert().execute(unicode_varchar=unicodedata,
x = unicode_table.select().execute().fetchone()
self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata)
self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata)
+
if isinstance(x['plain_varchar'], unicode):
            # SQLite and MSSQL return non-unicode data as unicode
self.assert_(testing.against('sqlite', '+pyodbc'))
if not testing.against('sqlite'):
self.assert_(x['plain_varchar'] == unicodedata)
- print "it's %s!" % testing.db.name
else:
self.assert_(not isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == rawdata)
def tearDownAll(self):
metadata.drop_all()
+ @testing.fails_on("+pg8000", "Not yet known how to pass values of the INTERVAL type")
def test_roundtrip(self):
delta = datetime.datetime(2006, 10, 5) - datetime.datetime(2005, 8, 17)
interval_table.insert().execute(interval=delta)
fn,
no_support('access', 'not supported by database'),
no_support('firebird', 'no SA implementation'),
+ no_support('+pg8000', 'FIXME: not sure how to accomplish'),
no_support('maxdb', 'not supported by database'),
no_support('mssql', 'FIXME: guessing, needs confirmation'),
no_support('oracle', 'no SA implementation'),