From 6378c347994c902f7d4e65e54f2b76d01ce603d2 Mon Sep 17 00:00:00 2001 From: Jason Kirtland Date: Tue, 23 Oct 2007 07:38:07 +0000 Subject: [PATCH] - Added initial version of MaxDB dialect. - All optional test Sequences are now optional=True --- lib/sqlalchemy/databases/__init__.py | 6 +- lib/sqlalchemy/databases/maxdb.py | 1083 ++++++++++++++++++++++++++ test/dialect/maxdb.py | 92 +++ test/engine/execute.py | 2 +- test/engine/reflection.py | 25 +- test/engine/transaction.py | 12 +- test/orm/assorted_eager.py | 10 +- test/orm/eager_relations.py | 6 +- test/orm/entity.py | 4 +- test/orm/inheritance/basic.py | 21 +- test/orm/inheritance/manytomany.py | 65 +- test/orm/lazy_relations.py | 2 +- test/orm/query.py | 28 +- test/orm/unitofwork.py | 12 +- test/sql/case_statement.py | 1 + test/sql/query.py | 28 +- test/sql/rowcount.py | 3 + test/sql/testtypes.py | 26 +- test/sql/unicode.py | 10 +- 19 files changed, 1320 insertions(+), 116 deletions(-) create mode 100644 lib/sqlalchemy/databases/maxdb.py create mode 100644 test/dialect/maxdb.py diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py index 87b1e85105..c4c60c2b45 100644 --- a/lib/sqlalchemy/databases/__init__.py +++ b/lib/sqlalchemy/databases/__init__.py @@ -5,5 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -__all__ = ['sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird', - 'sybase', 'access'] +__all__ = [ + 'sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird', + 'sybase', 'access', 'maxdb', + ] diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py new file mode 100644 index 0000000000..fcf04bec90 --- /dev/null +++ b/lib/sqlalchemy/databases/maxdb.py @@ -0,0 +1,1083 @@ +# maxdb.py +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""Support for the MaxDB database. + +TODO: module docs! + +Overview +-------- + +The ``maxdb`` dialect is **experimental** and has only been tested on 7.6.03.007 +and 7.6.00.037. Of these, **only 7.6.03.007 will work** with SQLAlchemy's ORM. +The earlier version has severe ``LEFT JOIN`` limitations and will return +incorrect results from even very simple ORM queries. + +Only the native Python DB-API is currently supported. ODBC driver support +is a future enhancement. + +Implementation Notes +-------------------- + +Also check the DatabaseNotes page on the wiki for detailed information. + +For 'somecol.in_([])' to work, the IN operator's generation must be changed +to cast 'NULL' to a numeric, i.e. NUM(NULL). The DB-API doesn't accept a +bind parameter there, so that particular generation must inline the NULL value, +which depends on [ticket:807]. + +The DB-API is very picky about where bind params may be used in queries. + +Bind params for some functions (e.g. MOD) need type information supplied. +The dialect does not yet do this automatically. + +Max will occasionally throw up 'bad sql, compile again' exceptions for +perfectly valid SQL. The dialect does not currently handle these, more +research is needed. + +MaxDB 7.5 and Sap DB <= 7.4 reportedly do not support schemas. A very +slightly different version of this dialect would be required to support +those versions, and can easily be added if there is demand. Some other +required components such as an Max-aware 'old oracle style' join compiler +(thetas with (+) outer indicators) are already done and available for +integration- email the devel list if you're interested in working on +this. +""" + +import datetime, itertools, re, warnings + +from sqlalchemy import exceptions, schema, sql, util +from sqlalchemy.sql import operators as sql_operators, expression as sql_expr +from sqlalchemy.sql import compiler, visitors +from sqlalchemy.engine import base as engine_base, default +from sqlalchemy import types as sqltypes + + +__all__ = [ + 'MaxString', 'MaxUnicode', 'MaxChar', 'MaxText', 'MaxInteger', + 'MaxSmallInteger', 'MaxNumeric', 'MaxFloat', 'MaxTimestamp', + 'MaxDate', 'MaxTime', 'MaxBoolean', 'MaxBlob', + ] + + +class _StringType(sqltypes.String): + _type = None + + def __init__(self, length=None, encoding=None, **kw): + super(_StringType, self).__init__(length=length, **kw) + self.encoding = encoding + + def get_col_spec(self): + if self.length is None: + spec = 'LONG' + else: + spec = '%s(%s)' % (self._type, self.length) + + if self.encoding is not None: + spec = ' '.join([spec, self.encoding.upper()]) + return spec + + def bind_processor(self, dialect): + if self.encoding == 'unicode': + return None + else: + def process(value): + if isinstance(value, unicode): + return value.encode(dialect.encoding) + else: + return value + return process + + def result_processor(self, dialect): + def process(value): + while True: + if value is None: + return None + elif isinstance(value, unicode): + return value + elif isinstance(value, str): + if self.convert_unicode or dialect.convert_unicode: + return value.decode(dialect.encoding) + else: + return value + elif hasattr(value, 'read'): + # some sort of LONG, snarf and retry + value = value.read(value.remainingLength()) + continue + else: + # unexpected type, return as-is + return value + return process + + +class MaxString(_StringType): + _type = 'VARCHAR' + + def __init__(self, *a, **kw): + super(MaxString, self).__init__(*a, **kw) + + +class MaxUnicode(_StringType): + _type = 'VARCHAR' + + def __init__(self, length=None, **kw): + super(MaxUnicode, self).__init__(length=length, encoding='unicode') + + +class MaxChar(_StringType): + _type = 'CHAR' + + +class MaxText(_StringType): + _type = 'LONG' + + def __init__(self, *a, **kw): + super(MaxText, self).__init__(*a, **kw) + + def get_col_spec(self): + spec = 'LONG' + if self.encoding is not None: + spec = ' '.join((spec, self.encoding)) + elif self.convert_unicode: + spec = ' '.join((spec, 'UNICODE')) + + return spec + + +class MaxInteger(sqltypes.Integer): + def get_col_spec(self): + return 'INTEGER' + + +class MaxSmallInteger(MaxInteger): + def get_col_spec(self): + return 'SMALLINT' + + +class MaxNumeric(sqltypes.Numeric): + """The NUMERIC (also FIXED, DECIMAL) data type.""" + + def get_col_spec(self): + if self.length and self.precision: + return 'NUMERIC(%s, %s)' % (self.precision, self.length) + elif self.length: + return 'NUMERIC(%s)' % self.length + else: + return 'INTEGER' + + +class MaxFloat(sqltypes.Float): + """The FLOAT data type.""" + + def get_col_spec(self): + if self.precision is None: + return 'FLOAT' + else: + return 'FLOAT(%s)' % (self.precision,) + + +class MaxTimestamp(sqltypes.DateTime): + def get_col_spec(self): + return 'TIMESTAMP' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y%m%d%H%M%S" + ("%06u" % ms)) + elif dialect.datetimeformat == 'iso': + ms = getattr(value, 'microsecond', 0) + return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms)) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[4:6], value[6:8], + value[8:10], value[10:12], value[12:14], + value[14:])]) + elif dialect.datetimeformat == 'iso': + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[5:7], value[8:10], + value[11:13], value[14:16], value[17:19], + value[20:])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxDate(sqltypes.Date): + def get_col_spec(self): + return 'DATE' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%Y%m%d") + elif dialect.datetimeformat == 'iso': + return value.strftime("%Y-%m-%d") + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + return datetime.date( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + elif dialect.datetimeformat == 'iso': + return datetime.date( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxTime(sqltypes.Time): + def get_col_spec(self): + return 'TIME' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, basestring): + return value + elif dialect.datetimeformat == 'internal': + return value.strftime("%H%M%S") + elif dialect.datetimeformat == 'iso': + return value.strftime("%H-%M-%S") + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + elif dialect.datetimeformat == 'internal': + t = datetime.time( + *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) + return t + elif dialect.datetimeformat == 'iso': + return datetime.time( + *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) + else: + raise exceptions.InvalidRequestError( + "datetimeformat '%s' is not supported." % ( + dialect.datetimeformat,)) + return process + + +class MaxBoolean(sqltypes.Boolean): + def get_col_spec(self): + return 'BOOLEAN' + + +class MaxBlob(sqltypes.Binary): + def get_col_spec(self): + return 'LONG BYTE' + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return str(value) + return process + + def result_processor(self, dialect): + def process(value): + if value is None: + return None + else: + return value.read(value.remainingLength()) + return process + + +colspecs = { + sqltypes.Integer: MaxInteger, + sqltypes.Smallinteger: MaxSmallInteger, + sqltypes.Numeric: MaxNumeric, + sqltypes.Float: MaxFloat, + sqltypes.DateTime: MaxTimestamp, + sqltypes.Date: MaxDate, + sqltypes.Time: MaxTime, + sqltypes.String: MaxString, + sqltypes.Binary: MaxBlob, + sqltypes.Boolean: MaxBoolean, + sqltypes.TEXT: MaxText, + sqltypes.CHAR: MaxChar, + sqltypes.TIMESTAMP: MaxTimestamp, + sqltypes.BLOB: MaxBlob, + sqltypes.Unicode: MaxUnicode, + } + +ischema_names = { + 'boolean': MaxBoolean, + 'int': MaxInteger, + 'integer': MaxInteger, + 'varchar': MaxString, + 'char': MaxChar, + 'character': MaxChar, + 'fixed': MaxNumeric, + 'float': MaxFloat, + 'long': MaxText, + 'long binary': MaxBlob, + 'long unicode': MaxText, + 'long': MaxText, + 'timestamp': MaxTimestamp, + 'date': MaxDate, + 'time': MaxTime + } + + +class MaxDBExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + # DB-API bug: if there were any functions as values, + # then do another select and pull CURRVAL from the + # autoincrement column's implicit sequence... ugh + if self.compiled.isinsert and not self.executemany: + table = self.compiled.statement.table + index, serial_col = _autoserial_column(table) + + if serial_col and (not self.compiled._safeserial or + not(self._last_inserted_ids) or + self._last_inserted_ids[index] in (None, 0)): + if table.schema: + sql = "SELECT %s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + else: + sql = "SELECT CURRENT_SCHEMA.%s.CURRVAL FROM DUAL" % ( + self.compiled.preparer.format_table(table)) + + if self.connection.engine._should_log_info: + self.connection.engine.logger.info(sql) + rs = self.cursor.execute(sql) + id = rs.fetchone()[0] + + if self.connection.engine._should_log_debug: + self.connection.engine.logger.debug([id]) + if not self._last_inserted_ids: + # This shouldn't ever be > 1? Right? + self._last_inserted_ids = \ + [None] * len(table.primary_key.columns) + self._last_inserted_ids[index] = id + + super(MaxDBExecutionContext, self).post_exec() + + def get_result_proxy(self): + if self.cursor.description is not None: + for column in self.cursor.description: + if column[1] in ('Long Binary', 'Long', 'Long Unicode'): + return MaxDBResultProxy(self) + return engine_base.ResultProxy(self) + + +class MaxDBCachedColumnRow(engine_base.RowProxy): + """A RowProxy that only runs result_processors once per column.""" + + def __init__(self, parent, row): + super(MaxDBCachedColumnRow, self).__init__(parent, row) + self.columns = {} + self._row = row + self._parent = parent + + def _get_col(self, key): + if key not in self.columns: + self.columns[key] = self._parent._get_col(self._row, key) + return self.columns[key] + + def __iter__(self): + for i in xrange(len(self._row)): + yield self._get_col(i) + + def __repr__(self): + return repr(list(self)) + + def __eq__(self, other): + return ((other is self) or + (other == tuple([self._get_col(key) + for key in xrange(len(self._row))]))) + def __getitem__(self, key): + if isinstance(key, slice): + indices = key.indices(len(self._row)) + return tuple([self._get_col(i) for i in xrange(*indices)]) + else: + return self._get_col(key) + + def __getattr__(self, name): + try: + return self._get_col(name) + except KeyError: + raise AttributeError(name) + + +class MaxDBResultProxy(engine_base.ResultProxy): + _process_row = MaxDBCachedColumnRow + + +class MaxDBDialect(default.DefaultDialect): + supports_alter = True + supports_unicode_statements = True + max_identifier_length = 32 + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + preexecute_sequences = True + + # MaxDB-specific + datetimeformat = 'internal' + + def __init__(self, _raise_known_sql_errors=False, **kw): + super(MaxDBDialect, self).__init__(**kw) + self._raise_known = _raise_known_sql_errors + + def dbapi(cls): + from sapdb import dbapi as _dbapi + return _dbapi + dbapi = classmethod(dbapi) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username='user') + opts.update(url.query) + return [], opts + + def type_descriptor(self, typeobj): + if isinstance(typeobj, type): + typeobj = typeobj() + if isinstance(typeobj, sqltypes.Unicode): + return typeobj.adapt(MaxUnicode) + else: + return sqltypes.adapt_type(typeobj, colspecs) + + def dbapi_type_map(self): + if self.dbapi is None: + return {} + else: + return { + 'Long Binary': MaxBlob(), + 'Long byte_t': MaxBlob(), + 'Long Unicode': MaxText(), + 'Timestamp': MaxTimestamp(), + 'Date': MaxDate(), + 'Time': MaxTime(), + datetime.datetime: MaxTimestamp(), + datetime.date: MaxDate(), + datetime.time: MaxTime(), + } + + def create_execution_context(self, connection, **kw): + return MaxDBExecutionContext(self, connection, **kw) + + def do_execute(self, cursor, statement, parameters, context=None): + res = cursor.execute(statement, parameters) + if isinstance(res, int) and context is not None: + context._rowcount = res + + def do_release_savepoint(self, connection, name): + # Does MaxDB truly support RELEASE SAVEPOINT ? All my attempts + # produce "SUBTRANS COMMIT/ROLLBACK not allowed without SUBTRANS + # BEGIN SQLSTATE: I7065" + # Note that ROLLBACK TO works fine. In theory, a RELEASE should + # just free up some transactional resources early, before the overall + # COMMIT/ROLLBACK so omitting it should be relatively ok. + pass + + def get_default_schema_name(self, connection): + try: + return self._default_schema_name + except AttributeError: + name = self.identifier_preparer._normalize_name( + connection.execute('SELECT CURRENT_SCHEMA FROM DUAL').scalar()) + self._default_schema_name = name + return name + + def has_table(self, connection, table_name, schema=None): + denormalize = self.identifier_preparer._denormalize_name + bind = [denormalize(table_name)] + if schema is None: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME=? AND" + " TABLES.SCHEMANAME=CURRENT_SCHEMA ") + else: + sql = ("SELECT tablename FROM TABLES " + "WHERE TABLES.TABLENAME = ? AND" + " TABLES.SCHEMANAME=? ") + bind.append(denormalize(schema)) + + rp = connection.execute(sql, bind) + found = bool(rp.fetchone()) + rp.close() + return found + + def table_names(self, connection, schema): + if schema is None: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=CURRENT_SCHEMA ") + rs = connection.execute(sql) + else: + sql = (" SELECT TABLENAME FROM TABLES WHERE " + " SCHEMANAME=? ") + matchname = self.identifier_preparer._denormalize_name(schema) + rs = connection.execute(sql, matchname) + normalize = self.identifier_preparer._normalize_name + return [normalize(row[0]) for row in rs] + + def reflecttable(self, connection, table, include_columns): + denormalize = self.identifier_preparer._denormalize_name + normalize = self.identifier_preparer._normalize_name + + st = ('SELECT COLUMNNAME, MODE, DATATYPE, CODETYPE, LEN, DEC, ' + ' NULLABLE, "DEFAULT", DEFAULTFUNCTION ' + 'FROM COLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY POS') + + fk = ('SELECT COLUMNNAME, FKEYNAME, ' + ' REFSCHEMANAME, REFTABLENAME, REFCOLUMNNAME, RULE, ' + ' (CASE WHEN REFSCHEMANAME = CURRENT_SCHEMA ' + ' THEN 1 ELSE 0 END) AS in_schema ' + 'FROM FOREIGNKEYCOLUMNS ' + 'WHERE TABLENAME=? AND SCHEMANAME=%s ' + 'ORDER BY FKEYNAME ') + + params = [denormalize(table.name)] + if not table.schema: + st = st % 'CURRENT_SCHEMA' + fk = fk % 'CURRENT_SCHEMA' + else: + st = st % '?' + fk = fk % '?' + params.append(denormalize(table.schema)) + + rows = connection.execute(st, params).fetchall() + if not rows: + raise exceptions.NoSuchTableError(table.fullname) + + include_columns = util.Set(include_columns or []) + + for row in rows: + (name, mode, col_type, encoding, length, precision, + nullable, constant_def, func_def) = row + + name = normalize(name) + + if include_columns and name not in include_columns: + continue + + type_args, type_kw = [], {} + if col_type == 'FIXED': + type_args = length, precision + elif col_type in 'FLOAT': + type_args = length, + elif col_type in ('CHAR', 'VARCHAR'): + type_args = length, + type_kw['encoding'] = encoding + elif col_type == 'LONG': + type_kw['encoding'] = encoding + + try: + type_cls = ischema_names[col_type.lower()] + type_instance = type_cls(*type_args, **type_kw) + except KeyError: + warnings.warn(RuntimeWarning( + "Did not recognize type '%s' of column '%s'" % + (col_type, name))) + type_instance = sqltypes.NullType + + col_kw = {'autoincrement': False} + col_kw['nullable'] = (nullable == 'YES') + col_kw['primary_key'] = (mode == 'KEY') + + if func_def is not None: + if func_def.startswith('SERIAL'): + # strip current numbering + col_kw['default'] = schema.PassiveDefault( + sql.text('SERIAL')) + col_kw['autoincrement'] = True + else: + col_kw['default'] = schema.PassiveDefault( + sql.text(func_def)) + elif constant_def is not None: + col_kw['default'] = schema.PassiveDefault(sql.text( + "'%s'" % constant_def.replace("'", "''"))) + + table.append_column(schema.Column(name, type_instance, **col_kw)) + + fk_sets = itertools.groupby(connection.execute(fk, params), + lambda row: row.FKEYNAME) + for fkeyname, fkey in fk_sets: + fkey = list(fkey) + if include_columns: + key_cols = util.Set([r.COLUMNNAME for r in fkey]) + if key_cols != include_columns: + continue + + columns, referants = [], [] + quote = self.identifier_preparer._maybe_quote_identifier + + for row in fkey: + columns.append(normalize(row.COLUMNNAME)) + if table.schema or not row.in_schema: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFSCHEMANAME', 'REFTABLENAME', + 'REFCOLUMNNAME')])) + else: + referants.append('.'.join( + [quote(normalize(row[c])) + for c in ('REFTABLENAME', 'REFCOLUMNNAME')])) + + constraint_kw = {'name': fkeyname.lower()} + if fkey[0].RULE is not None: + rule = fkey[0].RULE + if rule.startswith('DELETE '): + rule = rule[7:] + constraint_kw['ondelete'] = rule + + table_kw = {} + if table.schema or not row.in_schema: + table_kw['schema'] = normalize(fkey[0].REFSCHEMANAME) + + ref_key = schema._get_table_key(normalize(fkey[0].REFTABLENAME), + table_kw.get('schema')) + if ref_key not in table.metadata.tables: + schema.Table(normalize(fkey[0].REFTABLENAME), + table.metadata, + autoload=True, autoload_with=connection, + **table_kw) + + constraint = schema.ForeignKeyConstraint(columns, referants, + **constraint_kw) + table.append_constraint(constraint) + + def has_sequence(self, connection, name): + # [ticket:726] makes this schema-aware. + denormalize = self.identifier_preparer._denormalize_name + sql = ("SELECT sequence_name FROM SEQUENCES " + "WHERE SEQUENCE_NAME=? ") + + rp = connection.execute(sql, denormalize(name)) + found = bool(rp.fetchone()) + rp.close() + return found + + +class MaxDBCompiler(compiler.DefaultCompiler): + operators = compiler.DefaultCompiler.operators.copy() + operators[sql_operators.mod] = lambda x, y: 'mod(%s, %s)' % (x, y) + + function_conversion = { + 'CURRENT_DATE': 'DATE', + 'CURRENT_TIME': 'TIME', + 'CURRENT_TIMESTAMP': 'TIMESTAMP', + } + + # These functions must be written without parens when called with no + # parameters. e.g. 'SELECT DATE FROM DUAL' not 'SELECT DATE() FROM DUAL' + bare_functions = util.Set([ + 'CURRENT_SCHEMA', 'DATE', 'TIME', 'TIMESTAMP', 'TIMEZONE', + 'TRANSACTION', 'USER', 'UID', 'USERGROUP', 'UTCDATE']) + + def default_from(self): + return ' FROM DUAL' + + def for_update_clause(self, select): + clause = select.for_update + if clause is True: + return " WITH LOCK EXCLUSIVE" + elif clause is None: + return "" + elif clause == "read": + return " WITH LOCK" + elif clause == "ignore": + return " WITH LOCK (IGNORE) EXCLUSIVE" + elif clause == "nowait": + return " WITH LOCK (NOWAIT) EXCLUSIVE" + elif isinstance(clause, basestring): + return " WITH LOCK %s" % clause.upper() + elif not clause: + return "" + else: + return " WITH LOCK EXCLUSIVE" + + def apply_function_parens(self, func): + if func.name.upper() in self.bare_functions: + return len(func.clauses) > 0 + else: + return True + + def visit_function(self, fn, **kw): + transform = self.function_conversion.get(fn.name.upper(), None) + if transform: + fn = fn._clone() + fn.name = transform + return super(MaxDBCompiler, self).visit_function(fn, **kw) + + def visit_cast(self, cast, **kwargs): + # MaxDB only supports casts * to NUMERIC, * to VARCHAR or + # date/time to VARCHAR. Casts of LONGs will fail. + if isinstance(cast.type, (sqltypes.Integer, sqltypes.Numeric)): + return "NUM(%s)" % self.process(cast.clause) + elif isinstance(cast.type, sqltypes.String): + return "CHR(%s)" % self.process(cast.clause) + else: + return self.process(cast.clause) + + def visit_sequence(self, sequence): + if sequence.optional: + return None + else: + return (self.dialect.identifier_preparer.format_sequence(sequence) + + ".NEXTVAL") + + class ColumnSnagger(visitors.ClauseVisitor): + def __init__(self): + self.count = 0 + self.column = None + def visit_column(self, column): + self.column = column + self.count += 1 + + def _find_labeled_columns(self, columns, use_labels=False): + labels = {} + for column in columns: + if isinstance(column, basestring): + continue + snagger = self.ColumnSnagger() + snagger.traverse(column) + if snagger.count == 1: + if isinstance(column, sql_expr._Label): + labels[unicode(snagger.column)] = column.name + elif use_labels: + labels[unicode(snagger.column)] = column._label + + return labels + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # ORDER BY clauses in DISTINCT queries must reference aliased + # inner columns by alias name, not true column name. + if order_by and getattr(select, '_distinct', False): + labels = self._find_labeled_columns(select.inner_columns, + select.use_labels) + if labels: + for needs_alias in labels.keys(): + r = re.compile(r'(^| )(%s)(,| |$)' % + re.escape(needs_alias)) + order_by = r.sub((r'\1%s\3' % labels[needs_alias]), + order_by) + + # No ORDER BY in subqueries. + if order_by: + if self.is_subquery(select): + # It's safe to simply drop the ORDER BY if there is no + # LIMIT. Right? Other dialects seem to get away with + # dropping order. + if select._limit: + raise exceptions.InvalidRequestError( + "MaxDB does not support ORDER BY in subqueries") + else: + return "" + return " ORDER BY " + order_by + else: + return "" + + def get_select_precolumns(self, select): + # Convert a subquery's LIMIT to TOP + sql = select._distinct and 'DISTINCT ' or '' + if self.is_subquery(select) and select._limit: + if select._offset: + raise exceptions.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + sql += 'TOP %s ' % select._limit + return sql + + def limit_clause(self, select): + # The docs say offsets are supported with LIMIT. But they're not. + # TODO: maybe emulate by adding a ROWNO/ROWNUM predicate? + if self.is_subquery(select): + # sub queries need TOP + return '' + elif select._offset: + raise exceptions.InvalidRequestError( + 'MaxDB does not support LIMIT with an offset.') + else: + return ' \n LIMIT %s' % (select._limit,) + + def visit_insert(self, insert): + self.isinsert = True + self._safeserial = True + + colparams = self._get_colparams(insert) + for value in (insert.parameters or {}).itervalues(): + if isinstance(value, sql_expr._Function): + self._safeserial = False + break + + return ''.join(('INSERT INTO ', + self.preparer.format_table(insert.table), + ' (', + ', '.join([self.preparer.format_column(c[0]) + for c in colparams]), + ') VALUES (', + ', '.join([c[1] for c in colparams]), + ')')) + + +class MaxDBDefaultRunner(engine_base.DefaultRunner): + def visit_sequence(self, seq): + if seq.optional: + return None + return self.execute_string("SELECT %s.NEXTVAL FROM DUAL" % ( + self.dialect.identifier_preparer.format_sequence(seq))) + + +class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = util.Set([ + 'abs', 'absolute', 'acos', 'adddate', 'addtime', 'all', 'alpha', + 'alter', 'any', 'ascii', 'asin', 'atan', 'atan2', 'avg', 'binary', + 'bit', 'boolean', 'byte', 'case', 'ceil', 'ceiling', 'char', + 'character', 'check', 'chr', 'column', 'concat', 'constraint', 'cos', + 'cosh', 'cot', 'count', 'cross', 'curdate', 'current', 'curtime', + 'database', 'date', 'datediff', 'day', 'dayname', 'dayofmonth', + 'dayofweek', 'dayofyear', 'dec', 'decimal', 'decode', 'default', + 'degrees', 'delete', 'digits', 'distinct', 'double', 'except', + 'exists', 'exp', 'expand', 'first', 'fixed', 'float', 'floor', 'for', + 'from', 'full', 'get_objectname', 'get_schema', 'graphic', 'greatest', + 'group', 'having', 'hex', 'hextoraw', 'hour', 'ifnull', 'ignore', + 'index', 'initcap', 'inner', 'insert', 'int', 'integer', 'internal', + 'intersect', 'into', 'join', 'key', 'last', 'lcase', 'least', 'left', + 'length', 'lfill', 'list', 'ln', 'locate', 'log', 'log10', 'long', + 'longfile', 'lower', 'lpad', 'ltrim', 'makedate', 'maketime', + 'mapchar', 'max', 'mbcs', 'microsecond', 'min', 'minute', 'mod', + 'month', 'monthname', 'natural', 'nchar', 'next', 'no', 'noround', + 'not', 'now', 'null', 'num', 'numeric', 'object', 'of', 'on', + 'order', 'packed', 'pi', 'power', 'prev', 'primary', 'radians', + 'real', 'reject', 'relative', 'replace', 'rfill', 'right', 'round', + 'rowid', 'rowno', 'rpad', 'rtrim', 'second', 'select', 'selupd', + 'serial', 'set', 'show', 'sign', 'sin', 'sinh', 'smallint', 'some', + 'soundex', 'space', 'sqrt', 'stamp', 'statistics', 'stddev', + 'subdate', 'substr', 'substring', 'subtime', 'sum', 'sysdba', + 'table', 'tan', 'tanh', 'time', 'timediff', 'timestamp', 'timezone', + 'to', 'toidentifier', 'transaction', 'translate', 'trim', 'trunc', + 'truncate', 'ucase', 'uid', 'unicode', 'union', 'update', 'upper', + 'user', 'usergroup', 'using', 'utcdate', 'utcdiff', 'value', 'values', + 'varchar', 'vargraphic', 'variance', 'week', 'weekofyear', 'when', + 'where', 'with', 'year', 'zoned' ]) + + def _normalize_name(self, name): + if name is None: + return None + if name.isupper(): + lc_name = name.lower() + if not self._requires_quotes(lc_name): + return lc_name + return name + + def _denormalize_name(self, name): + if name is None: + return None + elif (name.islower() and + not self._requires_quotes(name)): + return name.upper() + else: + return name + + def _maybe_quote_identifier(self, name): + if self._requires_quotes(name): + return self.quote_identifier(name) + else: + return name + + +class MaxDBSchemaGenerator(compiler.SchemaGenerator): + def get_column_specification(self, column, **kw): + colspec = [self.preparer.format_column(column), + column.type.dialect_impl(self.dialect).get_col_spec()] + + if not column.nullable: + colspec.append('NOT NULL') + + default = column.default + default_str = self.get_column_default_string(column) + + # No DDL default for columns specified with non-optional sequence- + # this defaulting behavior is entirely client-side. (And as a + # consequence, non-reflectable.) + if (default and isinstance(default, schema.Sequence) and + not default.optional): + pass + # Regular default + elif default_str is not None: + colspec.append('DEFAULT %s' % default_str) + # Assign DEFAULT SERIAL heuristically + elif column.primary_key and column.autoincrement: + # For SERIAL on a non-primary key member, use + # PassiveDefault(text('SERIAL')) + try: + first = [c for c in column.table.primary_key.columns + if (c.autoincrement and + (isinstance(c.type, sqltypes.Integer) or + (isinstance(c.type, MaxNumeric) and + c.type.precision)) and + not c.foreign_keys)].pop(0) + if column is first: + colspec.append('DEFAULT SERIAL') + except IndexError: + pass + + return ' '.join(colspec) + + def get_column_default_string(self, column): + if isinstance(column.default, schema.PassiveDefault): + if isinstance(column.default.arg, basestring): + if isinstance(column.type, sqltypes.Integer): + return str(column.default.arg) + else: + return "'%s'" % column.default.arg + else: + return unicode(self._compile(column.default.arg, None)) + else: + return None + + def visit_sequence(self, sequence): + """Creates a SEQUENCE. + + TODO: move to module doc? + + start + With an integer value, set the START WITH option. + + increment + An integer value to increment by. Default is the database default. + + maxdb_minvalue + maxdb_maxvalue + With an integer value, sets the corresponding sequence option. + + maxdb_no_minvalue + maxdb_no_maxvalue + Defaults to False. If true, sets the corresponding sequence option. + + maxdb_cycle + Defaults to False. If true, sets the CYCLE option. + + maxdb_cache + With an integer value, sets the CACHE option. + + maxdb_no_cache + Defaults to False. If true, sets NOCACHE. + """ + + if (not sequence.optional and + (not self.checkfirst or + not self.dialect.has_sequence(self.connection, sequence.name))): + + ddl = ['CREATE SEQUENCE', + self.preparer.format_sequence(sequence)] + + sequence.increment = 1 + + if sequence.increment is not None: + ddl.extend(('INCREMENT BY', str(sequence.increment))) + + if sequence.start is not None: + ddl.extend(('START WITH', str(sequence.start))) + + opts = dict([(pair[0][6:].lower(), pair[1]) + for pair in sequence.kwargs.items() + if pair[0].startswith('maxdb_')]) + + if 'maxvalue' in opts: + ddl.extend(('MAXVALUE', str(opts['maxvalue']))) + elif opts.get('no_maxvalue', False): + ddl.append('NOMAXVALUE') + if 'minvalue' in opts: + ddl.extend(('MINVALUE', str(opts['minvalue']))) + elif opts.get('no_minvalue', False): + ddl.append('NOMINVALUE') + + if opts.get('cycle', False): + ddl.append('CYCLE') + + if 'cache' in opts: + ddl.extend(('CACHE', str(opts['cache']))) + elif opts.get('no_cache', False): + ddl.append('NOCACHE') + + self.append(' '.join(ddl)) + self.execute() + + +class MaxDBSchemaDropper(compiler.SchemaDropper): + def visit_sequence(self, sequence): + if (not sequence.optional and + (not self.checkfirst or + self.dialect.has_sequence(self.connection, sequence.name))): + self.append("DROP SEQUENCE %s" % + self.preparer.format_sequence(sequence)) + self.execute() + + +def _autoserial_column(table): + """Finds the effective DEFAULT SERIAL column of a Table, if any.""" + + for index, col in enumerate(table.primary_key.columns): + if (isinstance(col.type, (sqltypes.Integer, sqltypes.Numeric)) and + col.autoincrement): + if isinstance(col.default, schema.Sequence): + if col.default.optional: + return index, col + elif (col.default is None or + (not isinstance(col.default, schema.PassiveDefault))): + return index, col + + return None, None + +def descriptor(): + return {'name': 'maxdb', + 'description': 'MaxDB', + 'arguments': [ + ('user', "Database Username", None), + ('password', "Database Password", None), + ('database', "Database Name", None), + ('host', "Hostname", None)]} + +dialect = MaxDBDialect +dialect.preparer = MaxDBIdentifierPreparer +dialect.statement_compiler = MaxDBCompiler +dialect.schemagenerator = MaxDBSchemaGenerator +dialect.schemadropper = MaxDBSchemaDropper +dialect.defaultrunner = MaxDBDefaultRunner + diff --git a/test/dialect/maxdb.py b/test/dialect/maxdb.py new file mode 100644 index 0000000000..551c26b374 --- /dev/null +++ b/test/dialect/maxdb.py @@ -0,0 +1,92 @@ +"""MaxDB-specific tests.""" + +import testbase +import StringIO, sys +from sqlalchemy import * +from sqlalchemy import sql +from sqlalchemy.databases import maxdb +from testlib import * + + +# TODO +# - add "Database" test, a quick check for join behavior on different max versions +# - full max-specific reflection suite +# - datetime tests +# - decimal etc. tests +# - the orm/query 'test_has' destabilizes the server- cover here + +class BasicTest(AssertMixin): + def test_import(self): + return True + +class DBAPITest(AssertMixin): + """Asserts quirks in the native Python DB-API driver. + + If any of these fail, that's good- the bug is fixed! + """ + + @testing.supported('maxdb') + def test_dbapi_breaks_sequences(self): + con = testbase.db.connect().connection + + cr = con.cursor() + cr.execute('CREATE SEQUENCE busto START WITH 1 INCREMENT BY 1') + try: + vals = [] + for i in xrange(3): + cr.execute('SELECT busto.NEXTVAL FROM DUAL') + vals.append(cr.fetchone()[0]) + + # should be 1,2,3, but no... + self.assert_(vals != [1,2,3]) + # ...we get: + self.assert_(vals == [2,4,6]) + finally: + cr.execute('DROP SEQUENCE busto') + + @testing.supported('maxdb') + def test_dbapi_breaks_mod_binds(self): + con = testbase.db.connect().connection + + cr = con.cursor() + # OK + cr.execute('SELECT MOD(3, 2) FROM DUAL') + + # Broken! + try: + cr.execute('SELECT MOD(3, ?) FROM DUAL', [2]) + self.assert_(False) + except: + self.assert_(True) + + # OK + cr.execute('SELECT MOD(?, 2) FROM DUAL', [3]) + + @testing.supported('maxdb') + def test_dbapi_breaks_close(self): + dialect = testbase.db.dialect + cargs, ckw = dialect.create_connect_args(testbase.db.url) + + # There doesn't seem to be a way to test for this as it occurs in + # regular usage- the warning doesn't seem to go through 'warnings'. + con = dialect.dbapi.connect(*cargs, **ckw) + con.close() + del con # <-- exception during __del__ + + # But this does the same thing. + con = dialect.dbapi.connect(*cargs, **ckw) + self.assert_(con.close == con.__del__) + con.close() + try: + con.close() + self.assert_(False) + except dialect.dbapi.DatabaseError: + self.assert_(True) + + @testing.supported('maxdb') + def test_modulo_operator(self): + st = str(select([sql.column('col') % 5]).compile(testbase.db)) + self.assertEquals(st, 'SELECT mod(col, ?) FROM DUAL') + +if __name__ == "__main__": + testbase.main() diff --git a/test/engine/execute.py b/test/engine/execute.py index 28faf1102e..6cf3cccd97 100644 --- a/test/engine/execute.py +++ b/test/engine/execute.py @@ -18,7 +18,7 @@ class ExecuteTest(PersistTest): def tearDownAll(self): metadata.drop_all() - @testing.supported('sqlite') + @testing.supported('sqlite', 'maxdb') def test_raw_qmark(self): for conn in (testbase.db, testbase.db.connect()): conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack")) diff --git a/test/engine/reflection.py b/test/engine/reflection.py index af190649cc..534cdd2c1b 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -11,9 +11,10 @@ class ReflectionTest(PersistTest): @testing.exclude('mysql', '<', (4, 1, 1)) def testbasic(self): - use_function_defaults = testbase.db.engine.name == 'postgres' or testbase.db.engine.name == 'oracle' + use_function_defaults = testing.against('postgres', 'oracle', 'maxdb') - use_string_defaults = use_function_defaults or testbase.db.engine.__module__.endswith('sqlite') + use_string_defaults = (use_function_defaults or + testbase.db.engine.__module__.endswith('sqlite')) if use_function_defaults: defval = func.current_date() @@ -25,12 +26,11 @@ class ReflectionTest(PersistTest): if use_string_defaults: deftype2 = String defval2 = "im a default" - #deftype3 = DateTime - # the colon thing isnt working out for PG reflection just yet - #defval3 = '1999-09-09 00:00:00' deftype3 = Date - if testbase.db.engine.name == 'oracle': + if testing.against('oracle'): defval3 = text("to_date('09-09-1999', 'MM-DD-YYYY')") + elif testing.against('maxdb'): + defval3 = '19990909' else: defval3 = '1999-09-09' else: @@ -520,7 +520,7 @@ class ReflectionTest(PersistTest): # There's currently no way to calculate identifier case normalization # in isolation, so... - if testbase.db.engine.name in ('firebird', 'oracle'): + if testing.against('firebird', 'oracle', 'maxdb'): check_col = 'TRUE' else: check_col = 'true' @@ -689,7 +689,7 @@ class CreateDropTest(PersistTest): metadata.drop_all(bind=testbase.db) class UnicodeTest(PersistTest): - @testing.unsupported('sybase') + @testing.unsupported('sybase', 'maxdb') def test_basic(self): try: # the 'convert_unicode' should not get in the way of the reflection @@ -747,16 +747,16 @@ class SchemaTest(PersistTest): assert buf.index("CREATE TABLE someschema.table1") > -1 assert buf.index("CREATE TABLE someschema.table2") > -1 - @testing.supported('mysql','postgres') + @testing.supported('maxdb', 'mysql', 'postgres') def test_explicit_default_schema(self): engine = testbase.db schema = engine.dialect.get_default_schema_name(engine) - #engine.echo = True - if testbase.db.name == 'mysql': + if testing.against('mysql'): schema = testbase.db.url.database - else: + elif testing.against('postgres'): schema = 'public' + metadata = MetaData(testbase.db) table1 = Table('table1', metadata, Column('col1', Integer, primary_key=True), @@ -768,6 +768,7 @@ class SchemaTest(PersistTest): metadata.create_all() metadata.create_all(checkfirst=True) metadata.clear() + table1 = Table('table1', metadata, autoload=True, schema=schema) table2 = Table('table2', metadata, autoload=True, schema=schema) metadata.drop_all() diff --git a/test/engine/transaction.py b/test/engine/transaction.py index 4c7c5ec040..b11065933a 100644 --- a/test/engine/transaction.py +++ b/test/engine/transaction.py @@ -160,7 +160,7 @@ class TransactionTest(PersistTest): connection.close() - @testing.supported('postgres', 'mysql', 'oracle') + @testing.supported('postgres', 'mysql', 'oracle', 'maxdb') @testing.exclude('mysql', '<', (5, 0, 3)) def testnestedsubtransactionrollback(self): connection = testbase.db.connect() @@ -178,7 +178,7 @@ class TransactionTest(PersistTest): ) connection.close() - @testing.supported('postgres', 'mysql', 'oracle') + @testing.supported('postgres', 'mysql', 'oracle', 'maxdb') @testing.exclude('mysql', '<', (5, 0, 3)) def testnestedsubtransactioncommit(self): connection = testbase.db.connect() @@ -196,7 +196,7 @@ class TransactionTest(PersistTest): ) connection.close() - @testing.supported('postgres', 'mysql', 'oracle') + @testing.supported('postgres', 'mysql', 'oracle', 'maxdb') @testing.exclude('mysql', '<', (5, 0, 3)) def testrollbacktosubtransaction(self): connection = testbase.db.connect() @@ -636,7 +636,7 @@ class ForUpdateTest(PersistTest): break con.close() - @testing.supported('mysql', 'oracle', 'postgres') + @testing.supported('mysql', 'oracle', 'postgres', 'maxdb') def testqueued_update(self): """Test SELECT FOR UPDATE with concurrent modifications. @@ -698,7 +698,7 @@ class ForUpdateTest(PersistTest): return errors - @testing.supported('mysql', 'oracle', 'postgres') + @testing.supported('mysql', 'oracle', 'postgres', 'maxdb') def testqueued_select(self): """Simple SELECT FOR UPDATE conflict test""" @@ -707,7 +707,7 @@ class ForUpdateTest(PersistTest): sys.stderr.write("Failure: %s\n" % e) self.assert_(len(errors) == 0) - @testing.supported('oracle', 'postgres') + @testing.supported('oracle', 'postgres', 'maxdb') def testnowait_select(self): """Simple SELECT FOR UPDATE NOWAIT conflict test""" diff --git a/test/orm/assorted_eager.py b/test/orm/assorted_eager.py index 353560826e..1eae7a5454 100644 --- a/test/orm/assorted_eager.py +++ b/test/orm/assorted_eager.py @@ -13,10 +13,14 @@ class EagerTest(AssertMixin): dbmeta = MetaData(testbase.db) # determine a literal value for "false" based on the dialect - false = False + # FIXME: this PassiveDefault setup is bogus. bp = Boolean().dialect_impl(testbase.db.dialect).bind_processor(testbase.db.dialect) if bp: - false = bp(false) + false = str(bp(False)) + elif testing.against('maxdb'): + false = text('FALSE') + else: + false = str(False) owners = Table ( 'owners', dbmeta , Column ( 'id', Integer, primary_key=True, nullable=False ), @@ -31,7 +35,7 @@ class EagerTest(AssertMixin): options = Table ( 'options', dbmeta , Column ( 'test_id', Integer, ForeignKey ( 'tests.id' ), primary_key=True, nullable=False ), Column ( 'owner_id', Integer, ForeignKey ( 'owners.id' ), primary_key=True, nullable=False ), - Column ( 'someoption', Boolean, PassiveDefault(str(false)), nullable=False ) ) + Column ( 'someoption', Boolean, PassiveDefault(false), nullable=False ) ) dbmeta.create_all() diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 7ecae957ee..a30c4d7faa 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -244,7 +244,7 @@ class EagerTest(QueryTest): noeagers = create_session().query(User).from_statement("select * from users").all() assert 'orders' not in noeagers[0].__dict__ assert 'addresses' not in noeagers[0].__dict__ - + def test_limit(self): """test limit operations combined with lazy-load relationships.""" @@ -260,7 +260,7 @@ class EagerTest(QueryTest): sess = create_session() q = sess.query(User) - if testbase.db.engine.name == 'mssql': + if testing.against('mysql'): l = q.limit(2).all() assert fixtures.user_all_result[:2] == l else: @@ -317,7 +317,7 @@ class EagerTest(QueryTest): q = sess.query(User) - if testbase.db.engine.name != 'mssql': + if not testing.against('maxdb', 'mssql'): l = q.join('orders').order_by(Order.user_id.desc()).limit(2).offset(1) assert [ User(id=9, diff --git a/test/orm/entity.py b/test/orm/entity.py index ef0932032f..5ef01b8829 100644 --- a/test/orm/entity.py +++ b/test/orm/entity.py @@ -117,8 +117,8 @@ class EntityTest(AssertMixin): u2.addresses.append(a2) sess.save(u2, entity_name='user2') print u2.__dict__ - - sess.flush() + + sess.flush() assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')] diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index a033d61eac..000fddc453 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -11,7 +11,8 @@ class O2MTest(ORMTest): global foo, bar, blub # the 'data' columns are to appease SQLite which cant handle a blank INSERT foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), Column('data', String(20))) bar = Table('bar', metadata, @@ -68,7 +69,8 @@ class GetTest(ORMTest): def define_tables(self, metadata): global foo, bar, blub foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), Column('type', String(30)), Column('data', String(20))) @@ -199,14 +201,17 @@ class EagerLazyTest(ORMTest): LazyLoader constructs the right query condition.""" def define_tables(self, metadata): global foo, bar, bar_foo - foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq'), primary_key=True), - Column('data', String(30))) - bar = Table('bar', metadata, Column('id', Integer, ForeignKey('foo.id'), primary_key=True), - Column('data', String(30))) + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), + Column('data', String(30))) + bar = Table('bar', metadata, + Column('id', Integer, ForeignKey('foo.id'), primary_key=True), + Column('data', String(30))) bar_foo = Table('bar_foo', metadata, - Column('bar_id', Integer, ForeignKey('bar.id')), - Column('foo_id', Integer, ForeignKey('foo.id')) + Column('bar_id', Integer, ForeignKey('bar.id')), + Column('foo_id', Integer, ForeignKey('foo.id')) ) def testbasic(self): diff --git a/test/orm/inheritance/manytomany.py b/test/orm/inheritance/manytomany.py index df00f39d0b..343345aa36 100644 --- a/test/orm/inheritance/manytomany.py +++ b/test/orm/inheritance/manytomany.py @@ -12,35 +12,28 @@ class InheritTest(ORMTest): global groups global user_group_map - principals = Table( - 'principals', - metadata, - Column('principal_id', Integer, Sequence('principal_id_seq', optional=False), primary_key=True), - Column('name', String(50), nullable=False), - ) - - users = Table( - 'prin_users', - metadata, - Column('principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), + principals = Table('principals', metadata, + Column('principal_id', Integer, + Sequence('principal_id_seq', optional=False), + primary_key=True), + Column('name', String(50), nullable=False)) + + users = Table('prin_users', metadata, + Column('principal_id', Integer, + ForeignKey('principals.principal_id'), primary_key=True), Column('password', String(50), nullable=False), Column('email', String(50), nullable=False), - Column('login_id', String(50), nullable=False), - - ) + Column('login_id', String(50), nullable=False)) - groups = Table( - 'prin_groups', - metadata, - Column( 'principal_id', Integer, ForeignKey('principals.principal_id'), primary_key=True), + groups = Table('prin_groups', metadata, + Column('principal_id', Integer, + ForeignKey('principals.principal_id'), primary_key=True)) - ) - - user_group_map = Table( - 'prin_user_group_map', - metadata, - Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), primary_key=True ), - Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), primary_key=True ), + user_group_map = Table('prin_user_group_map', metadata, + Column('user_id', Integer, ForeignKey( "prin_users.principal_id"), + primary_key=True ), + Column('group_id', Integer, ForeignKey( "prin_groups.principal_id"), + primary_key=True ), ) def testbasic(self): @@ -56,18 +49,12 @@ class InheritTest(ORMTest): pass mapper(Principal, principals) - mapper( - User, - users, - inherits=Principal - ) + mapper(User, users, inherits=Principal) - mapper( - Group, - groups, - inherits=Principal, - properties=dict( users = relation(User, secondary=user_group_map, lazy=True, backref="groups") ) - ) + mapper(Group, groups, inherits=Principal, properties={ + 'users': relation(User, secondary=user_group_map, + lazy=True, backref="groups") + }) g = Group(name="group1") g.users.append(User(name="user1", password="pw", email="foo@bar.com", login_id="lg1")) @@ -81,7 +68,8 @@ class InheritTest2(ORMTest): def define_tables(self, metadata): global foo, bar, foo_bar foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_id_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_id_seq', optional=True), + primary_key=True), Column('data', String(20)), ) @@ -155,7 +143,8 @@ class InheritTest3(ORMTest): # the 'data' columns are to appease SQLite which cant handle a blank INSERT foo = Table('foo', metadata, - Column('id', Integer, Sequence('foo_seq'), primary_key=True), + Column('id', Integer, Sequence('foo_seq', optional=True), + primary_key=True), Column('data', String(20))) bar = Table('bar', metadata, diff --git a/test/orm/lazy_relations.py b/test/orm/lazy_relations.py index 0440d11a39..b8e92c1637 100644 --- a/test/orm/lazy_relations.py +++ b/test/orm/lazy_relations.py @@ -146,7 +146,7 @@ class LazyTest(QueryTest): sess = create_session() q = sess.query(User) - if testbase.db.engine.name == 'mssql': + if testing.against('maxdb', 'mssql'): l = q.limit(2).all() assert fixtures.user_all_result[:2] == l else: diff --git a/test/orm/query.py b/test/orm/query.py index f96d6fc435..775f7357e0 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -6,6 +6,7 @@ from sqlalchemy.sql import compiler from sqlalchemy.engine import default from sqlalchemy.orm import * from testlib import * +from testlib import engines from testlib.fixtures import * class QueryTest(FixtureTest): @@ -107,20 +108,26 @@ class GetTest(QueryTest): assert u2.name =='jack' assert a not in u2.addresses - @testing.exclude('mysql', '<', (5, 0)) # fixme + @testing.exclude('mysql', '<', (4, 1)) def test_unicode(self): """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail on postgres, mysql and oracle unless it is converted to an encoded string""" - table = Table('unicode_data', users.metadata, + metadata = MetaData(engines.utf8_engine()) + table = Table('unicode_data', metadata, Column('id', Unicode(40), primary_key=True), Column('data', Unicode(40))) - table.create() - ustring = 'petit voix m\xe2\x80\x99a '.decode('utf-8') - table.insert().execute(id=ustring, data=ustring) - class LocalFoo(Base):pass - mapper(LocalFoo, table) - assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring) + try: + metadata.create_all() + ustring = 'petit voix m\xe2\x80\x99a'.decode('utf-8') + table.insert().execute(id=ustring, data=ustring) + class LocalFoo(Base): + pass + mapper(LocalFoo, table) + self.assertEquals(create_session().query(LocalFoo).get(ustring), + LocalFoo(id=ustring, data=ustring)) + finally: + metadata.drop_all() def test_populate_existing(self): s = create_session() @@ -261,6 +268,7 @@ class FilterTest(QueryTest): def test_basic(self): assert [User(id=7), User(id=8), User(id=9),User(id=10)] == create_session().query(User).all() + @testing.unsupported('maxdb') def test_limit(self): assert [User(id=8), User(id=9)] == create_session().query(User).limit(2).offset(1).all() @@ -305,7 +313,9 @@ class FilterTest(QueryTest): filter(User.addresses.any(id=4)).all() assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all() - + + # THIS ONE + @testing.unsupported('maxdb') def test_has(self): sess = create_session() assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 8adf1a980f..44d229985e 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -46,7 +46,8 @@ class VersioningTest(ORMTest): def define_tables(self, metadata): global version_table version_table = Table('version_test', metadata, - Column('id', Integer, Sequence('version_test_seq'), primary_key=True ), + Column('id', Integer, Sequence('version_test_seq', optional=True), + primary_key=True ), Column('version_id', Integer, nullable=False), Column('value', String(40), nullable=False) ) @@ -56,7 +57,7 @@ class VersioningTest(ORMTest): s = Session(scope=None) class Foo(object):pass mapper(Foo, version_table, version_id_col=version_table.c.version_id) - f1 =Foo(value='f1', _sa_session=s) + f1 = Foo(value='f1', _sa_session=s) f2 = Foo(value='f2', _sa_session=s) s.commit() @@ -523,7 +524,7 @@ class ClauseAttributesTest(ORMTest): Column('id', Integer, Sequence('users_id_seq', optional=True), primary_key=True), Column('name', String(30)), Column('counter', Integer, default=1)) - + def test_update(self): class User(object): pass @@ -571,9 +572,8 @@ class ClauseAttributesTest(ORMTest): sess.save(u) sess.flush() assert u.counter == 5 - - + class PassiveDeletesTest(ORMTest): def define_tables(self, metadata): global mytable,myothertable @@ -683,7 +683,7 @@ class ExtraPassiveDeletesTest(ORMTest): try: sess.commit() assert False - except (exceptions.IntegrityError, exceptions.OperationalError): + except exceptions.DBAPIError: assert True diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index 493545b228..7856d758be 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -25,6 +25,7 @@ class CaseTest(PersistTest): def tearDownAll(self): info_table.drop() + @testing.unsupported('maxdb') def testcase(self): inner = select([case([ [info_table.c.pk < 3, diff --git a/test/sql/query.py b/test/sql/query.py index ba29d6a8f6..67384073cb 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -221,7 +221,7 @@ class QueryTest(PersistTest): users.delete(users.c.user_name == 'fred').execute() print repr(users.select().execute().fetchall()) - + def testselectlimit(self): users.insert().execute(user_id=1, user_name='john') users.insert().execute(user_id=2, user_name='jack') @@ -233,7 +233,7 @@ class QueryTest(PersistTest): r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r)) - @testing.unsupported('mssql') + @testing.unsupported('mssql', 'maxdb') def testselectlimitoffset(self): users.insert().execute(user_id=1, user_name='john') users.insert().execute(user_id=2, user_name='jack') @@ -247,8 +247,8 @@ class QueryTest(PersistTest): r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r==[(6, 'ralph'), (7, 'fido')]) - @testing.supported('mssql') - def testselectlimitoffset_mssql(self): + @testing.supported('mssql', 'maxdb') + def test_select_limit_nooffset(self): try: r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall() assert False # InvalidRequestError should have been raised @@ -423,8 +423,12 @@ class QueryTest(PersistTest): assert (x == y == z) is True def test_update_functions(self): - """test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances, - and that column-level defaults get overridden""" + """ + Tests sending functions and SQL expressions to the VALUES and SET + clauses of INSERT/UPDATE instances, and that column-level defaults + get overridden. + """ + meta = MetaData(testbase.db) t = Table('t1', meta, Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True), @@ -444,7 +448,6 @@ class QueryTest(PersistTest): r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() id = r.last_inserted_ids()[0] - assert t.select(t.c.id==id).execute().fetchone()['value'] == 9 t.update(values={t.c.value:func.length("asdf")}).execute() assert t.select().execute().fetchone()['value'] == 4 @@ -453,7 +456,8 @@ class QueryTest(PersistTest): t2.insert(values=dict(value=func.length("one"))).execute() t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi") - assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(7,None), (3,None), (-14,"hi")] + res = exec_sorted(select([t2.c.value, t2.c.stuff])) + self.assertEquals(res, [(-14, 'hi'), (3, None), (7, None)]) t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff") assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] @@ -506,7 +510,7 @@ class QueryTest(PersistTest): self.assertEqual([x.lower() for x in r.keys()], ['user_name', 'user_id']) self.assertEqual(r.values(), ['foo', 1]) - @testing.unsupported('oracle', 'firebird') + @testing.unsupported('oracle', 'firebird', 'maxdb') def test_column_accessor_shadow(self): meta = MetaData(testbase.db) shadowed = Table('test_shadowed', meta, @@ -590,6 +594,7 @@ class QueryTest(PersistTest): finally: table.drop() + @testing.unsupported('maxdb') def test_in_filtering(self): """test the behavior of the in_() function.""" @@ -717,6 +722,7 @@ class CompoundTest(PersistTest): ('ccc', 'aaa')] self.assertEquals(u.execute().fetchall(), wanted) + @testing.unsupported('maxdb') def test_union_ordered_alias(self): (s1, s2) = ( select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], @@ -1125,9 +1131,11 @@ class OperatorTest(PersistTest): def tearDownAll(self): metadata.drop_all() + @testing.unsupported('maxdb') def test_modulo(self): self.assertEquals( - select([flds.c.intcol % 3], order_by=flds.c.idcol).execute().fetchall(), + select([flds.c.intcol % 3], + order_by=flds.c.idcol).execute().fetchall(), [(2,),(1,)] ) diff --git a/test/sql/rowcount.py b/test/sql/rowcount.py index 095f79200d..4bd52b9faa 100644 --- a/test/sql/rowcount.py +++ b/test/sql/rowcount.py @@ -47,6 +47,7 @@ class FoundRowsTest(AssertMixin): # WHERE matches 3, 3 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='Z') + print "expecting 3, dialect reports %s" % r.rowcount if testbase.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 @@ -54,6 +55,7 @@ class FoundRowsTest(AssertMixin): # WHERE matches 3, 0 rows changed department = employees_table.c.department r = employees_table.update(department=='C').execute(department='C') + print "expecting 3, dialect reports %s" % r.rowcount if testbase.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 @@ -61,6 +63,7 @@ class FoundRowsTest(AssertMixin): # WHERE matches 3, 3 rows deleted department = employees_table.c.department r = employees_table.delete(department=='C').execute() + print "expecting 3, dialect reports %s" % r.rowcount if testbase.db.dialect.supports_sane_rowcount: assert r.rowcount == 3 diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 71313ec42d..101efb79ba 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -41,13 +41,14 @@ class MyUnicodeType(types.TypeDecorator): impl = Unicode def bind_processor(self, dialect): - impl_processor = super(MyUnicodeType, self).bind_processor(dialect) + impl_processor = super(MyUnicodeType, self).bind_processor(dialect) or (lambda value:value) + def process(value): return "UNI_BIND_IN"+ impl_processor(value) return process def result_processor(self, dialect): - impl_processor = super(MyUnicodeType, self).result_processor(dialect) + impl_processor = super(MyUnicodeType, self).result_processor(dialect) or (lambda value:value) def process(value): return impl_processor(value) + "UNI_BIND_OUT" return process @@ -264,9 +265,10 @@ class UnicodeTest(AssertMixin): unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - print repr(x['unicode_varchar']) - print repr(x['unicode_text']) - print repr(x['plain_varchar']) + print 0, repr(unicodedata) + print 1, repr(x['unicode_varchar']) + print 2, repr(x['unicode_text']) + print 3, repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) if isinstance(x['plain_varchar'], unicode): @@ -293,9 +295,10 @@ class UnicodeTest(AssertMixin): unicode_text=unicodedata, plain_varchar=rawdata) x = unicode_table.select().execute().fetchone() - print repr(x['unicode_varchar']) - print repr(x['unicode_text']) - print repr(x['plain_varchar']) + print 0, repr(unicodedata) + print 1, repr(x['unicode_varchar']) + print 2, repr(x['unicode_text']) + print 3, repr(x['plain_varchar']) self.assert_(isinstance(x['unicode_varchar'], unicode) and x['unicode_varchar'] == unicodedata) self.assert_(isinstance(x['unicode_text'], unicode) and x['unicode_text'] == unicodedata) self.assert_(isinstance(x['plain_varchar'], unicode) and x['plain_varchar'] == unicodedata) @@ -363,7 +366,7 @@ class DateTest(AssertMixin): global users_with_date, insert_data db = testbase.db - if db.engine.name == 'oracle': + if testing.against('oracle'): import sqlalchemy.databases.oracle as oracle insert_data = [ [7, 'jack', @@ -393,8 +396,11 @@ class DateTest(AssertMixin): time_micro = 999 # Missing or poor microsecond support: - if db.engine.name in ('mssql', 'mysql', 'firebird'): + if testing.against('mssql', 'mysql', 'firebird'): datetime_micro, time_micro = 0, 0 + # No microseconds for TIME + elif testing.against('maxdb'): + time_micro = 0 insert_data = [ [7, 'jack', diff --git a/test/sql/unicode.py b/test/sql/unicode.py index 55f3f6bc56..03673eb4d4 100644 --- a/test/sql/unicode.py +++ b/test/sql/unicode.py @@ -8,7 +8,7 @@ from testlib.engines import utf8_engine class UnicodeSchemaTest(PersistTest): - @testing.unsupported('oracle', 'sybase') + @testing.unsupported('maxdb', 'oracle', 'sybase') def setUpAll(self): global unicode_bind, metadata, t1, t2, t3 @@ -55,20 +55,20 @@ class UnicodeSchemaTest(PersistTest): ) metadata.create_all() - @testing.unsupported('oracle', 'sybase') + @testing.unsupported('maxdb', 'oracle', 'sybase') def tearDown(self): if metadata.tables: t3.delete().execute() t2.delete().execute() t1.delete().execute() - @testing.unsupported('oracle', 'sybase') + @testing.unsupported('maxdb', 'oracle', 'sybase') def tearDownAll(self): global unicode_bind metadata.drop_all() del unicode_bind - @testing.unsupported('oracle', 'sybase') + @testing.unsupported('maxdb', 'oracle', 'sybase') def test_insert(self): t1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5}) t2.insert().execute({'a':1, 'b':1}) @@ -81,7 +81,7 @@ class UnicodeSchemaTest(PersistTest): assert t2.select().execute().fetchall() == [(1, 1)] assert t3.select().execute().fetchall() == [(1, 5, 1, 1)] - @testing.unsupported('oracle', 'sybase') + @testing.unsupported('maxdb', 'oracle', 'sybase') def test_reflect(self): t1.insert().execute({u'méil':2, u'\u6e2c\u8a66':7}) t2.insert().execute({'a':2, 'b':2}) -- 2.47.3