From 09553dc90f4a95b314994b48068b046de1413104 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 28 Jan 2012 15:20:21 -0500 Subject: [PATCH] - [feature] Dialect-specific compilers now raise CompileException for all type/statement compilation issues, instead of InvalidRequestError or ArgumentError. The DDL for CREATE TABLE will re-raise CompileExceptions to include table/column information for the problematic column. [ticket:2361] --- CHANGES | 7 ++++ lib/sqlalchemy/dialects/maxdb/base.py | 2 +- lib/sqlalchemy/dialects/mssql/base.py | 4 +- lib/sqlalchemy/dialects/mysql/base.py | 4 +- lib/sqlalchemy/dialects/postgresql/base.py | 2 +- lib/sqlalchemy/dialects/sqlite/base.py | 2 +- lib/sqlalchemy/dialects/sybase/base.py | 2 +- lib/sqlalchemy/sql/compiler.py | 43 +++++++++++++++------- test/dialect/test_mysql.py | 19 +++++++++- test/dialect/test_postgresql.py | 4 +- test/lib/testing.py | 4 +- test/sql/test_compiler.py | 39 +++++++++++++++++++- 12 files changed, 104 insertions(+), 28 deletions(-) diff --git a/CHANGES b/CHANGES index 9c86e50a81..0eeb2e6064 100644 --- a/CHANGES +++ b/CHANGES @@ -84,6 +84,13 @@ CHANGES constructs to sqlalchemy.sql namespace, though not part of __all__ as of yet. + - [feature] Dialect-specific compilers now raise + CompileException for all type/statement compilation + issues, instead of InvalidRequestError or ArgumentError. + The DDL for CREATE TABLE will re-raise + CompileExceptions to include table/column information + for the problematic column. [ticket:2361] + - [bug] Fixed issue where the "required" exception would not be raised for bindparam() with required=True, if the statement were given no parameters at all. diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index 027efbb899..ce3aaaa1e7 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -635,7 +635,7 @@ class MaxDBCompiler(compiler.SQLCompiler): # LIMIT. Right? Other dialects seem to get away with # dropping order. if select._limit: - raise exc.InvalidRequestError( + raise exc.CompileError( "MaxDB does not support ORDER BY in subqueries") else: return "" diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 4d7dd1c582..f7c94aabc2 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -804,7 +804,7 @@ class MSSQLCompiler(compiler.SQLCompiler): # to use ROW_NUMBER(), an ORDER BY is required. orderby = self.process(select._order_by_clause) if not orderby: - raise exc.InvalidRequestError('MSSQL requires an order_by when ' + raise exc.CompileError('MSSQL requires an order_by when ' 'using an offset.') _offset = select._offset @@ -1029,7 +1029,7 @@ class MSDDLCompiler(compiler.DDLCompiler): colspec += " NULL" if column.table is None: - raise exc.InvalidRequestError( + raise exc.CompileError( "mssql requires Table-bound columns " "in order to generate DDL") diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index b6982c6c3d..6aa250d2d8 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1651,7 +1651,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): if type_.length: return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: - raise exc.InvalidRequestError( + raise exc.CompileError( "VARCHAR requires a length on dialect %s" % self.dialect.name) @@ -1667,7 +1667,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): if type_.length: return self._extend_string(type_, {'national':True}, "VARCHAR(%(length)s)" % {'length': type_.length}) else: - raise exc.InvalidRequestError( + raise exc.CompileError( "NVARCHAR requires a length on dialect %s" % self.dialect.name) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index f2510744f7..69c11d80fa 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -874,7 +874,7 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): def format_type(self, type_, use_schema=True): if not type_.name: - raise exc.ArgumentError("Postgresql ENUM type requires a name.") + raise exc.CompileError("Postgresql ENUM type requires a name.") name = self.quote(type_.name, type_.quote) if not self.omit_schema and use_schema and type_.schema is not None: diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 06c41b2eef..08a5204937 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -322,7 +322,7 @@ class SQLiteCompiler(compiler.SQLCompiler): return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( self.extract_map[extract.field], self.process(extract.expr, **kw)) except KeyError: - raise exc.ArgumentError( + raise exc.CompileError( "%s is not a valid extract argument." % extract.field) def limit_clause(self, select): diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 3c4706043b..4b8cc08bed 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -321,7 +321,7 @@ class SybaseDDLCompiler(compiler.DDLCompiler): self.dialect.type_compiler.process(column.type) if column.table is None: - raise exc.InvalidRequestError( + raise exc.CompileError( "The Sybase dialect requires Table-bound " "columns in order to generate DDL") seq_col = column.table._autoincrement_column diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 93e2473d94..2690dd896d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -23,6 +23,7 @@ To generate user-defined SQL strings, see """ import re +import sys from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions, util as sql_util, \ visitors @@ -1379,19 +1380,35 @@ class DDLCompiler(engine.Compiled): # if only one primary key, specify it along with the column first_pk = False for column in table.columns: - text += separator - separator = ", \n" - text += "\t" + self.get_column_specification( - column, - first_pk=column.primary_key and \ - not first_pk - ) - if column.primary_key: - first_pk = True - const = " ".join(self.process(constraint) \ - for constraint in column.constraints) - if const: - text += " " + const + try: + text += separator + separator = ", \n" + text += "\t" + self.get_column_specification( + column, + first_pk=column.primary_key and \ + not first_pk + ) + if column.primary_key: + first_pk = True + const = " ".join(self.process(constraint) \ + for constraint in column.constraints) + if const: + text += " " + const + except exc.CompileError, ce: + # Py3K + #raise exc.CompileError("(in table '%s', column '%s'): %s" + # % ( + # table.description, + # column.name, + # ce.args[0] + # )) from ce + # Py2K + raise exc.CompileError("(in table '%s', column '%s'): %s" + % ( + table.description, + column.name, + ce.args[0] + )), None, sys.exc_info()[2] const = self.create_table_constraints(table) if const: diff --git a/test/dialect/test_mysql.py b/test/dialect/test_mysql.py index 51b4062fcb..acb4aa5e45 100644 --- a/test/dialect/test_mysql.py +++ b/test/dialect/test_mysql.py @@ -1,6 +1,6 @@ # coding: utf-8 -from test.lib.testing import eq_, assert_raises +from test.lib.testing import eq_, assert_raises, assert_raises_message # Py2K import sets @@ -1185,7 +1185,22 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): Unicode(), ): type_ = sqltypes.to_instance(type_) - assert_raises(exc.InvalidRequestError, type_.compile, dialect=mysql.dialect()) + assert_raises_message( + exc.CompileError, + "VARCHAR requires a length on dialect mysql", + type_.compile, + dialect=mysql.dialect()) + + t1 = Table('sometable', MetaData(), + Column('somecolumn', type_) + ) + assert_raises_message( + exc.CompileError, + r"\(in table 'sometable', column 'somecolumn'\)\: " + r"(?:N)?VARCHAR requires a length on dialect mysql", + schema.CreateTable(t1).compile, + dialect=mysql.dialect() + ) def test_update_limit(self): t = sql.table('t', sql.column('col1'), sql.column('col2')) diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 7279508ba9..769f18ce9a 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -401,8 +401,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL): def test_name_required(self): metadata = MetaData(testing.db) etype = Enum('four', 'five', 'six', metadata=metadata) - assert_raises(exc.ArgumentError, etype.create) - assert_raises(exc.ArgumentError, etype.compile, + assert_raises(exc.CompileError, etype.create) + assert_raises(exc.CompileError, etype.compile, dialect=postgresql.dialect()) @testing.fails_on('postgresql+zxjdbc', diff --git a/test/lib/testing.py b/test/lib/testing.py index a84c5a7ae8..e30603f6e6 100644 --- a/test/lib/testing.py +++ b/test/lib/testing.py @@ -522,8 +522,8 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): callable_(*args, **kwargs) assert False, "Callable did not raise an exception" except except_cls, e: - assert re.search(msg, str(e)), "%r !~ %s" % (msg, e) - print str(e) + assert re.search(msg, unicode(e)), u"%r !~ %s" % (msg, e) + print unicode(e).encode('utf-8') def fail(msg): assert False, msg diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index b84a566d56..6e67431b45 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -1,12 +1,15 @@ +#! coding:utf-8 + from test.lib.testing import eq_, assert_raises, assert_raises_message import datetime, re, operator, decimal from sqlalchemy import * -from sqlalchemy import exc, sql, util +from sqlalchemy import exc, sql, util, types, schema from sqlalchemy.sql import table, column, label, compiler from sqlalchemy.sql.expression import ClauseList, _literal_as_text from sqlalchemy.engine import default from sqlalchemy.databases import * from test.lib import * +from sqlalchemy.ext.compiler import compiles table1 = table('mytable', column('myid', Integer), @@ -2812,6 +2815,40 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): "UPDATE foo SET id=:id, foo_id=:foo_id WHERE foo.id = :foo_id_1" ) +class DDLTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = 'default' + + def _illegal_type_fixture(self): + class MyType(types.TypeEngine): + pass + @compiles(MyType) + def compile(element, compiler, **kw): + raise exc.CompileError("Couldn't compile type") + return MyType + + def test_reraise_of_column_spec_issue(self): + MyType = self._illegal_type_fixture() + t1 = Table('t', MetaData(), + Column('x', MyType()) + ) + assert_raises_message( + exc.CompileError, + r"\(in table 't', column 'x'\): Couldn't compile type", + schema.CreateTable(t1).compile + ) + + def test_reraise_of_column_spec_issue_unicode(self): + MyType = self._illegal_type_fixture() + t1 = Table('t', MetaData(), + Column(u'méil', MyType()) + ) + assert_raises_message( + exc.CompileError, + ur"\(in table 't', column 'méil'\): Couldn't compile type", + schema.CreateTable(t1).compile + ) + + class InlineDefaultTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' -- 2.47.2