From: Mike Bayer Date: Tue, 11 Jan 2011 20:22:46 +0000 (-0500) Subject: - A TypeDecorator of Integer can be used with a primary key X-Git-Tag: rel_0_7b1~83 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=67e0f356b2093fdc03303d50be1f89e75e847c7f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - A TypeDecorator of Integer can be used with a primary key column, and the "autoincrement" feature of various dialects as well as the "sqlite_autoincrement" flag will honor the underlying database type as being Integer-based. [ticket:2005] - Result-row processors are applied to pre-executed SQL defaults, as well as cursor.lastrowid, when determining the contents of result.inserted_primary_key. [ticket:2006] - Bind parameters present in the "columns clause" of a select are now auto-labeled like other "anonymous" clauses, which among other things allows their "type" to be meaningful when the row is fetched, as in result row processors. - TypeDecorator is present in the "sqlalchemy" import space. --- diff --git a/CHANGES b/CHANGES index 4826eb6ab0..1c2ea954ea 100644 --- a/CHANGES +++ b/CHANGES @@ -15,6 +15,24 @@ CHANGES definition, using strings as column names, as an alternative to the creation of the index outside of the Table. + - A TypeDecorator of Integer can be used with a primary key + column, and the "autoincrement" feature of various dialects + as well as the "sqlite_autoincrement" flag will honor + the underlying database type as being Integer-based. + [ticket:2005] + + - Result-row processors are applied to pre-executed SQL + defaults, as well as cursor.lastrowid, when determining + the contents of result.inserted_primary_key. + [ticket:2006] + + - Bind parameters present in the "columns clause" of a select + are now auto-labeled like other "anonymous" clauses, + which among other things allows their "type" to be meaningful + when the row is fetched, as in result row processors. + + - TypeDecorator is present in the "sqlalchemy" import space. + - mssql - the String/Unicode types, and their counterparts VARCHAR/ NVARCHAR, emit "max" as the length when no length is diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index e9976cd130..239b2b363c 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -83,6 +83,7 @@ from sqlalchemy.types import ( TIMESTAMP, Text, Time, + TypeDecorator, Unicode, UnicodeText, VARCHAR, diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index b9f6e0b3e3..4043cd6c3b 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -331,12 +331,13 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): class FBExecutionContext(default.DefaultExecutionContext): - def fire_sequence(self, seq): + def fire_sequence(self, seq, proc): """Get the next value from the sequence using ``gen_id()``.""" return self._execute_scalar( "SELECT gen_id(%s, 1) FROM rdb$database" % - self.dialect.identifier_preparer.format_sequence(seq) + self.dialect.identifier_preparer.format_sequence(seq), + proc ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e26d83f0aa..61dd99b85b 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1309,16 +1309,8 @@ class MySQLDDLCompiler(compiler.DDLCompiler): elif column.nullable and is_timestamp and default is None: colspec.append('NULL') - if column.primary_key and column.autoincrement: - try: - first = [c for c in column.table.primary_key.columns - if (c.autoincrement and - isinstance(c.type, sqltypes.Integer) and - not c.foreign_keys)].pop(0) - if column is first: - colspec.append('AUTO_INCREMENT') - except IndexError: - pass + if column is column.table._autoincrement_column: + colspec.append('AUTO_INCREMENT') return ' '.join(colspec) @@ -1335,7 +1327,8 @@ class MySQLDDLCompiler(compiler.DDLCompiler): arg = "'%s'" % arg.replace("\\", "\\\\").replace("'", "''") if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY', - 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', 'DEFAULT_CHARSET', + 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', + 'DEFAULT_CHARSET', 'DEFAULT_COLLATE'): opt = opt.replace('_', ' ') diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index bacad37041..0b0622f845 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -593,10 +593,10 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer): class OracleExecutionContext(default.DefaultExecutionContext): - def fire_sequence(self, seq): + def fire_sequence(self, seq, proc): return int(self._execute_scalar("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval FROM DUAL")) + ".nextval FROM DUAL"), proc) class OracleDialect(default.DefaultDialect): name = 'oracle' diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 31f699d2bb..84fd96edd1 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -506,15 +506,16 @@ class PGCompiler(compiler.SQLCompiler): class PGDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + type_affinity = column.type._type_affinity if column.primary_key and \ len(column.foreign_keys)==0 and \ column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and \ - not isinstance(column.type, sqltypes.SmallInteger) and \ + issubclass(type_affinity, sqltypes.Integer) and \ + not issubclass(type_affinity, sqltypes.SmallInteger) and \ (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): - if isinstance(column.type, sqltypes.BigInteger): + if issubclass(type_affinity, sqltypes.BigInteger): colspec += " BIGSERIAL" else: colspec += " SERIAL" @@ -680,21 +681,21 @@ class DropEnumType(schema._CreateDropBase): __visit_name__ = "drop_enum_type" class PGExecutionContext(default.DefaultExecutionContext): - def fire_sequence(self, seq): + def fire_sequence(self, seq, proc): if not seq.optional: return self._execute_scalar(("select nextval('%s')" % \ - self.dialect.identifier_preparer.format_sequence(seq))) + self.dialect.identifier_preparer.format_sequence(seq)), proc) else: return None - def get_insert_default(self, column): + def get_insert_default(self, column, proc): if column.primary_key: if (isinstance(column.server_default, schema.DefaultClause) and column.server_default.arg is not None): # pre-execute passive defaults on primary key columns return self._execute_scalar("select %s" % - column.server_default.arg) + column.server_default.arg, proc) elif column is column.table._autoincrement_column \ and (column.default is None or @@ -713,9 +714,9 @@ class PGExecutionContext(default.DefaultExecutionContext): exc = "select nextval('\"%s_%s_seq\"')" % \ (column.table.name, column.name) - return self._execute_scalar(exc) + return self._execute_scalar(exc, proc) - return super(PGExecutionContext, self).get_insert_default(column) + return super(PGExecutionContext, self).get_insert_default(column, proc) class PGDialect(default.DefaultDialect): name = 'postgresql' diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c52668762b..ac0fde8468 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -249,7 +249,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if column.primary_key and \ column.table.kwargs.get('sqlite_autoincrement', False) and \ len(column.table.primary_key.columns) == 1 and \ - isinstance(column.type, sqltypes.Integer) and \ + issubclass(column.type._type_affinity, sqltypes.Integer) and \ not column.foreign_keys: colspec += " PRIMARY KEY AUTOINCREMENT" @@ -263,7 +263,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): c = list(constraint)[0] if c.primary_key and \ c.table.kwargs.get('sqlite_autoincrement', False) and \ - isinstance(c.type, sqltypes.Integer) and \ + issubclass(c.type._type_affinity, sqltypes.Integer) and \ not c.foreign_keys: return None diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index eb48c29d61..3bdcad2ac8 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1233,7 +1233,7 @@ class Connection(Connectable): self._handle_dbapi_exception(e, None, None, None, None) raise - ret = ctx._exec_default(default) + ret = ctx._exec_default(default, None) if self.should_close_with_result: self.close() return ret diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 76077778ce..eacbef8f9d 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -418,9 +418,9 @@ class DefaultExecutionContext(base.ExecutionContext): self.cursor = self.create_cursor() if self.isinsert or self.isupdate: - self.__process_defaults() self.postfetch_cols = self.compiled.postfetch self.prefetch_cols = self.compiled.prefetch + self.__process_defaults() processors = compiled._bind_processors @@ -532,7 +532,7 @@ class DefaultExecutionContext(base.ExecutionContext): else: return autocommit - def _execute_scalar(self, stmt): + def _execute_scalar(self, stmt, proc): """Execute a string statement on the current cursor, returning a scalar result. @@ -553,7 +553,11 @@ class DefaultExecutionContext(base.ExecutionContext): default_params = {} conn._cursor_execute(self.cursor, stmt, default_params) - return self.cursor.fetchone()[0] + r = self.cursor.fetchone()[0] + if proc: + return proc(r) + else: + return r @property def connection(self): @@ -623,8 +627,14 @@ class DefaultExecutionContext(base.ExecutionContext): table = self.compiled.statement.table lastrowid = self.get_lastrowid() self.inserted_primary_key = [ - c is table._autoincrement_column and lastrowid or v - for c, v in zip(table.primary_key, self.inserted_primary_key) + c is table._autoincrement_column and ( + proc and proc(lastrowid) + or lastrowid + ) or v + for c, v, proc in zip( + table.primary_key, + self.inserted_primary_key, + self.compiled._pk_processors) ] def _fetch_implicit_returning(self, resultproxy): @@ -688,9 +698,9 @@ class DefaultExecutionContext(base.ExecutionContext): self.root_connection._handle_dbapi_exception(e, None, None, None, self) raise - def _exec_default(self, default): + def _exec_default(self, default, proc): if default.is_sequence: - return self.fire_sequence(default) + return self.fire_sequence(default, proc) elif default.is_callable: return default.arg(self) elif default.is_clause_element: @@ -702,17 +712,17 @@ class DefaultExecutionContext(base.ExecutionContext): else: return default.arg - def get_insert_default(self, column): + def get_insert_default(self, column, proc): if column.default is None: return None else: - return self._exec_default(column.default) + return self._exec_default(column.default, proc) - def get_update_default(self, column): + def get_update_default(self, column, proc): if column.onupdate is None: return None else: - return self._exec_default(column.onupdate) + return self._exec_default(column.onupdate, proc) def __process_defaults(self): """Generate default values for compiled insert/update statements, @@ -726,7 +736,7 @@ class DefaultExecutionContext(base.ExecutionContext): # pre-determine scalar Python-side defaults # to avoid many calls of get_insert_default()/ # get_update_default() - for c in self.compiled.prefetch: + for c in self.prefetch_cols: if self.isinsert and c.default and c.default.is_scalar: scalar_defaults[c] = c.default.arg elif self.isupdate and c.onupdate and c.onupdate.is_scalar: @@ -734,13 +744,14 @@ class DefaultExecutionContext(base.ExecutionContext): for param in self.compiled_parameters: self.current_parameters = param - for c in self.compiled.prefetch: + for c, proc in zip(self.prefetch_cols, + self.compiled._prefetch_processors): if c in scalar_defaults: val = scalar_defaults[c] elif self.isinsert: - val = self.get_insert_default(c) + val = self.get_insert_default(c, proc) else: - val = self.get_update_default(c) + val = self.get_update_default(c, proc) if val is not None: param[c.key] = val del self.current_parameters @@ -748,11 +759,12 @@ class DefaultExecutionContext(base.ExecutionContext): self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] - for c in self.compiled.prefetch: + for c, proc in zip(self.compiled.prefetch, + self.compiled._prefetch_processors): if self.isinsert: - val = self.get_insert_default(c) + val = self.get_insert_default(c, proc) else: - val = self.get_update_default(c) + val = self.get_update_default(c, proc) if val is not None: compiled_parameters[c.key] = val diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index f4b2140613..a818745009 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -323,9 +323,10 @@ class Table(SchemaItem, expression.TableClause): def _autoincrement_column(self): for col in self.primary_key: if col.autoincrement and \ - isinstance(col.type, types.Integer) and \ + issubclass(col.type._type_affinity, types.Integer) and \ not col.foreign_keys and \ - isinstance(col.default, (type(None), Sequence)): + isinstance(col.default, (type(None), Sequence)) and \ + col.server_default is None: return col @@ -544,7 +545,7 @@ class Column(SchemaItem, expression.ColumnClause): The setting *only* has an effect for columns which are: - * Integer derived (i.e. INT, SMALLINT, BIGINT) + * Integer derived (i.e. INT, SMALLINT, BIGINT). * Part of the primary key diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 39d320edeb..92110ca2a4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -268,6 +268,20 @@ class SQLCompiler(engine.Compiled): if value is not None ) + @util.memoized_property + def _pk_processors(self): + return [ + col.type._cached_result_processor(self.dialect, None) + for col in self.statement.table.primary_key + ] + + @util.memoized_property + def _prefetch_processors(self): + return [ + col.type._cached_result_processor(self.dialect, None) + for col in self.prefetch + ] + def is_subquery(self): return len(self.stack) > 1 @@ -612,6 +626,7 @@ class SQLCompiler(engine.Compiled): ) self.binds[bindparam.key] = self.binds[name] = bindparam + return self.bindparam_string(name) def render_literal_bindparam(self, bindparam, **kw): @@ -732,8 +747,7 @@ class SQLCompiler(engine.Compiled): not isinstance(column.table, sql.Select): return _CompileLabel(column, sql._generated_label(column.name)) elif not isinstance(column, - (sql._UnaryExpression, sql._TextClause, - sql._BindParamClause)) \ + (sql._UnaryExpression, sql._TextClause)) \ and (not hasattr(column, 'name') or \ isinstance(column, sql.Function)): return _CompileLabel(column, column.anon_label) diff --git a/setup.cfg b/setup.cfg index bb8fe0543d..c2983fe410 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,3 +5,4 @@ tag_build = dev with-_sqlalchemy = true exclude = ^examples first-package-wins = true +where = test diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 6cc3271514..42d3cdcd1e 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -43,7 +43,7 @@ class CompileTest(TestBase, AssertsCompiledSQL): for expr, compile in [ ( select([literal("x"), literal("y")]), - "SELECT 'x', 'y'", + "SELECT 'x' AS anon_1, 'y' AS anon_2", ), ( select([t]).where(t.c.foo.in_(['x', 'y', 'z'])), diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index e3618f8411..11b0c004eb 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -647,3 +647,17 @@ class TestAutoIncrement(TestBase, AssertsCompiledSQL): 'CREATE TABLE noautoinctable (id INTEGER ' 'NOT NULL, x INTEGER, PRIMARY KEY (id))', dialect=sqlite.dialect()) + + def test_sqlite_autoincrement_int_affinity(self): + class MyInteger(TypeDecorator): + impl = Integer + table = Table( + 'autoinctable', + MetaData(), + Column('id', MyInteger, primary_key=True), + sqlite_autoincrement=True, + ) + self.assert_compile(schema.CreateTable(table), + 'CREATE TABLE autoinctable (id INTEGER NOT ' + 'NULL PRIMARY KEY AUTOINCREMENT)', + dialect=sqlite.dialect()) diff --git a/test/lib/requires.py b/test/lib/requires.py index 993a1546f5..b689250d27 100644 --- a/test/lib/requires.py +++ b/test/lib/requires.py @@ -202,12 +202,12 @@ def offset(fn): def returning(fn): return _chain_decorators_on( fn, - no_support('access', 'not supported by database'), - no_support('sqlite', 'not supported by database'), - no_support('mysql', 'not supported by database'), - no_support('maxdb', 'not supported by database'), - no_support('sybase', 'not supported by database'), - no_support('informix', 'not supported by database'), + no_support('access', "'returning' not supported by database"), + no_support('sqlite', "'returning' not supported by database"), + no_support('mysql', "'returning' not supported by database"), + no_support('maxdb', "'returning' not supported by database"), + no_support('sybase', "'returning' not supported by database"), + no_support('informix', "'returning' not supported by database"), ) def two_phase_transactions(fn): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index d63e41e90c..5a6d46b1b4 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -244,13 +244,13 @@ class SelectTest(TestBase, AssertsCompiledSQL): self.assert_compile( select([bindparam('a'), bindparam('b'), bindparam('c')]), - "SELECT :a, :b, :c" + "SELECT :a AS anon_1, :b AS anon_2, :c AS anon_3" , dialect=default.DefaultDialect(paramstyle='named') ) self.assert_compile( select([bindparam('a'), bindparam('b'), bindparam('c')]), - "SELECT ?, ?, ?" + "SELECT ? AS anon_1, ? AS anon_2, ? AS anon_3" , dialect=default.DefaultDialect(paramstyle='qmark'), ) @@ -1262,7 +1262,7 @@ class SelectTest(TestBase, AssertsCompiledSQL): self.assert_compile( select([literal("someliteral")]), - "SELECT 'someliteral'", + "SELECT 'someliteral' AS anon_1", dialect=dialect ) @@ -1298,7 +1298,7 @@ class SelectTest(TestBase, AssertsCompiledSQL): def test_literal(self): - self.assert_compile(select([literal('foo')]), "SELECT :param_1") + self.assert_compile(select([literal('foo')]), "SELECT :param_1 AS anon_1") self.assert_compile(select([literal("foo") + literal("bar")], from_obj=[table1]), "SELECT :param_1 || :param_2 AS anon_1 FROM mytable") diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 31759a7098..49aa8d3b37 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -1,11 +1,12 @@ from test.lib.testing import eq_, assert_raises, assert_raises_message import datetime -from sqlalchemy import Sequence, Column, func from sqlalchemy.schema import CreateSequence, DropSequence -from sqlalchemy.sql import select, text +from sqlalchemy.sql import select, text, literal_column import sqlalchemy as sa from test.lib import testing, engines -from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean, exc +from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean, exc,\ + Sequence, Column, func, literal +from sqlalchemy.types import TypeDecorator from test.lib.schema import Table from test.lib.testing import eq_ from test.sql import _base @@ -690,3 +691,81 @@ class SequenceTest(testing.TestBase, testing.AssertsCompiledSQL): metadata.drop_all() +class SpecialTypePKTest(testing.TestBase): + """test process_result_value in conjunction with primary key columns. + + Also tests that "autoincrement" checks are against column.type._type_affinity, + rather than the class of "type" itself. + + """ + + @classmethod + def setup_class(cls): + class MyInteger(TypeDecorator): + impl = Integer + def process_bind_param(self, value, dialect): + return int(value[4:]) + + def process_result_value(self, value, dialect): + return "INT_%d" % value + + cls.MyInteger = MyInteger + + @testing.provide_metadata + def _run_test(self, *arg, **kw): + implicit_returning = kw.pop('implicit_returning', True) + kw['primary_key'] = True + t = Table('x', metadata, + Column('y', self.MyInteger, *arg, **kw), + Column('data', Integer), + implicit_returning=implicit_returning + ) + + t.create() + r = t.insert().values(data=5).execute() + eq_(r.inserted_primary_key, ['INT_1']) + r.close() + + eq_( + t.select().execute().first(), + ('INT_1', 5) + ) + + def test_plain(self): + # among other things, tests that autoincrement + # is enabled. + self._run_test() + + def test_literal_default_label(self): + self._run_test(default=literal("INT_1", type_=self.MyInteger).label('foo')) + + def test_literal_default_no_label(self): + self._run_test(default=literal("INT_1", type_=self.MyInteger)) + + def test_sequence(self): + self._run_test(Sequence('foo_seq')) + + @testing.fails_on('mysql', "Pending [ticket:2021]") + @testing.fails_on('sqlite', "Pending [ticket:2021]") + def test_server_default(self): + # note that the MySQL dialect has to not render AUTOINCREMENT on this one + self._run_test(server_default='1',) + + @testing.fails_on('mysql', "Pending [ticket:2021]") + @testing.fails_on('sqlite', "Pending [ticket:2021]") + def test_server_default_no_autoincrement(self): + self._run_test(server_default='1', autoincrement=False) + + def test_clause(self): + stmt = select([literal("INT_1", type_=self.MyInteger)]).as_scalar() + self._run_test(default=stmt) + + @testing.requires.returning + def test_no_implicit_returning(self): + self._run_test(implicit_returning=False) + + @testing.requires.returning + def test_server_default_no_implicit_returning(self): + self._run_test(server_default='1', autoincrement=False) + + diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 6a9055887d..cbf6e6e582 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -485,6 +485,27 @@ class QueryTest(TestBase): a_eq(prep(r"(\:that$other)"), "(:that$other)") a_eq(prep(r".\:that$ :other."), ".:that$ ?.") + def test_select_from_bindparam(self): + """Test result row processing when selecting from a plain bind param.""" + + class MyInteger(TypeDecorator): + impl = Integer + def process_bind_param(self, value, dialect): + return int(value[4:]) + + def process_result_value(self, value, dialect): + return "INT_%d" % value + + eq_( + testing.db.scalar(select([literal("INT_5", type_=MyInteger)])), + "INT_5" + ) + eq_( + testing.db.scalar(select([literal("INT_5", type_=MyInteger).label('foo')])), + "INT_5" + ) + + def test_delete(self): users.insert().execute(user_id = 7, user_name = 'jack') users.insert().execute(user_id = 8, user_name = 'fred')