style of Hibernate
- sql
+ - all "type" keyword arguments, such as those to bindparam(), column(),
+ Column(), and func.<something>(), renamed to "type_". those objects
+ still name their "type" attribute as "type".
- transactions:
- added context manager (with statement) support for transactions
- added support for two phase commit, works with mysql and postgres so far.
- MetaData:
- DynamicMetaData has been renamed to ThreadLocalMetaData
- BoundMetaData has been removed- regular MetaData is equivalent
+ - new SQL operator implementation which removes all hardcoded operators
+ from expression structures and moves them into compilation;
+ allows greater flexibility of operator compilation; for example, "+"
+ compiles to "||" when used in a string context, or "concat(a,b)" on
+ MySQL; whereas in a numeric context it compiles to "+". fixes [ticket:475].
- "anonymous" alias and label names are now generated at SQL compilation
time in a completely deterministic fashion...no more random hex IDs
- significant architectural overhaul to SQL elements (ClauseElement).
from sqlalchemy import schema, sql, engine, util, sql_util, exceptions
from sqlalchemy.engine import default
-import string, re, sets, weakref, random
+import string, re, sets, random, operator
ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
BIND_PARAMS_ESC = re.compile(r'\x5c(:\w+)(?!:)', re.UNICODE)
+OPERATORS = {
+ operator.and_ : 'AND',
+ operator.or_ : 'OR',
+ operator.inv : 'NOT',
+ operator.add : '+',
+ operator.mul : '*',
+ operator.sub : '-',
+ operator.div : '/',
+ operator.mod : '%',
+ operator.truediv : '/',
+ operator.lt : '<',
+ operator.le : '<=',
+ operator.ne : '!=',
+ operator.gt : '>',
+ operator.ge : '>=',
+ operator.eq : '=',
+ sql.ColumnOperators.concat_op : '||',
+ sql.ColumnOperators.like_op : 'LIKE',
+ sql.ColumnOperators.notlike_op : 'NOT LIKE',
+ sql.ColumnOperators.ilike_op : 'ILIKE',
+ sql.ColumnOperators.notilike_op : 'NOT ILIKE',
+ sql.ColumnOperators.between_op : 'BETWEEN',
+ sql.ColumnOperators.in_op : 'IN',
+ sql.ColumnOperators.notin_op : 'NOT IN',
+ sql.ColumnOperators.comma_op : ', ',
+ sql.Operators.from_ : 'FROM',
+ sql.Operators.as_ : 'AS',
+ sql.Operators.exists : 'EXISTS',
+ sql.Operators.is_ : 'IS',
+ sql.Operators.isnot : 'IS NOT'
+}
+
class ANSIDialect(default.DefaultDialect):
def __init__(self, cache_identifiers=True, **kwargs):
super(ANSIDialect,self).__init__(**kwargs)
__traverse_options__ = {'column_collections':False, 'entry':True}
+ operators = OPERATORS
+
def __init__(self, dialect, statement, parameters=None, **kwargs):
"""Construct a new ``ANSICompiler`` object.
if isinstance(label.obj, sql._ColumnClause):
self.column_labels[label.obj._label] = labelname
self.column_labels[label.name] = labelname
- self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname)
+ self.strings[label] = " ".join([self.strings[label.obj], self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
def visit_column(self, column):
# there is actually somewhat of a ruleset when you would *not* necessarily
def visit_null(self, null):
self.strings[null] = 'NULL'
- def visit_clauselist(self, list):
- sep = list.operator
- if sep == ',':
- sep = ', '
- elif sep is None or sep == " ":
+ def visit_clauselist(self, clauselist):
+ sep = clauselist.operator
+ if sep is None:
sep = " "
+ elif sep == sql.ColumnOperators.comma_op:
+ sep = ', '
else:
- sep = " " + sep + " "
- self.strings[list] = string.join([s for s in [self.strings[c] for c in list.clauses] if s is not None], sep)
+ sep = " " + self.operator_string(clauselist.operator) + " "
+ self.strings[clauselist] = string.join([s for s in [self.strings[c] for c in clauselist.clauses] if s is not None], sep)
def apply_function_parens(self, func):
return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
def visit_unary(self, unary):
s = self.strings[unary.element]
if unary.operator:
- s = unary.operator + " " + s
+ s = self.operator_string(unary.operator) + " " + s
if unary.modifier:
s = s + " " + unary.modifier
self.strings[unary] = s
def visit_binary(self, binary):
- result = self.strings[binary.left]
- if binary.operator is not None:
- result += " " + self.binary_operator_string(binary)
- result += " " + self.strings[binary.right]
- self.strings[binary] = result
-
- def binary_operator_string(self, binary):
- return binary.operator
+ op = self.operator_string(binary.operator)
+ if callable(op):
+ self.strings[binary] = op(binary.left, binary.right)
+ else:
+ self.strings[binary] = self.strings[binary.left] + " " + op + " " + self.strings[binary.right]
+
+ def operator_string(self, operator):
+ return self.operators.get(operator, str(operator))
def visit_bindparam(self, bindparam):
# apply truncation to the ultimate generated name
" ON " + self.strings[join.onclause])
self.strings[join] = self.froms[join]
- def visit_insert_column_default(self, column, default, parameters):
- """Called when visiting an ``Insert`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object, add a blank *placeholder* parameter so the ``Insert``
- gets compiled with this column's name in its column and
- ``VALUES`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_update_column_default(self, column, default, parameters):
- """Called when visiting an ``Update`` statement.
-
- For each column in the table that contains a ``ColumnDefault``
- object as an onupdate, add a blank *placeholder* parameter so
- the ``Update`` gets compiled with this column's name as one of
- its ``SET`` clauses.
- """
-
- parameters.setdefault(column.key, None)
-
- def visit_insert_sequence(self, column, sequence, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden compilers that support sequences to
- place a blank *placeholder* parameter for each column in the
- table that contains a Sequence object, so the Insert gets
- compiled with this column's name in its column and ``VALUES``
- clauses.
- """
-
- pass
-
- def visit_insert_column(self, column, parameters):
- """Called when visiting an ``Insert`` statement.
-
- This may be overridden by compilers who disallow NULL columns
- being set in an ``Insert`` where there is a default value on
- the column (i.e. postgres), to remove the column for which
- there is a NULL insert from the parameter list.
- """
-
- pass
-
+ def uses_sequences_for_inserts(self):
+ return False
+
def visit_insert(self, insert_stmt):
- # scan the table's columns for defaults that have to be pre-set for an INSERT
- # add these columns to the parameter list via visit_insert_XXX methods
- default_params = {}
+
+ # search for columns who will be required to have an explicit bound value.
+ # for inserts, this includes Python-side defaults, columns with sequences for dialects
+ # that support sequences, and primary key columns for dialects that explicitly insert
+ # pre-generated primary key values
+ required_cols = util.Set()
class DefaultVisitor(schema.SchemaVisitor):
- def visit_column(s, c):
- self.visit_insert_column(c, default_params)
+ def visit_column(s, cd):
+ if c.primary_key and self.uses_sequences_for_inserts():
+ required_cols.add(c)
def visit_column_default(s, cd):
- self.visit_insert_column_default(c, cd, default_params)
+ required_cols.add(c)
def visit_sequence(s, seq):
- self.visit_insert_sequence(c, seq, default_params)
+ if self.uses_sequences_for_inserts():
+ required_cols.add(c)
vis = DefaultVisitor()
for c in insert_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isinsert = True
- colparams = self._get_colparams(insert_stmt, default_params)
-
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- if p.shortname is not None:
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.inline_params.add(col)
- self.traverse(p)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.strings[p] + ")"
- else:
- return self.strings[p]
+ colparams = self._get_colparams(insert_stmt, required_cols)
text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
- " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")")
+ " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
- # scan the table's columns for onupdates that have to be pre-set for an UPDATE
- # add these columns to the parameter list via visit_update_XXX methods
- default_params = {}
+
+ # search for columns who will be required to have an explicit bound value.
+ # for updates, this includes Python-side "onupdate" defaults.
+ required_cols = util.Set()
class OnUpdateVisitor(schema.SchemaVisitor):
def visit_column_onupdate(s, cd):
- self.visit_update_column_default(c, cd, default_params)
+ required_cols.add(c)
vis = OnUpdateVisitor()
for c in update_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
vis.traverse(c)
self.isupdate = True
- colparams = self._get_colparams(update_stmt, default_params)
-
- self.inline_params = util.Set()
- def create_param(col, p):
- if isinstance(p, sql._BindParamClause):
- self.binds[p.key] = p
- self.binds[p.shortname] = p
- return self.bindparam_string(self._truncate_bindparam(p))
- else:
- self.traverse(p)
- self.inline_params.add(col)
- if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
- return "(" + self.strings[p] + ")"
- else:
- return self.strings[p]
+ colparams = self._get_colparams(update_stmt, required_cols)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
+ text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
if update_stmt._whereclause:
text += " WHERE " + self.strings[update_stmt._whereclause]
self.strings[update_stmt] = text
-
- def _get_colparams(self, stmt, default_params):
- """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples.
-
- Each tuple will contain the ``Column`` and a ``ClauseElement``
- representing the value to be set (usually a ``_BindParamClause``,
- but could also be other SQL expressions.)
-
- The list of tuples will determine the columns that are
- actually rendered into the ``SET``/``VALUES`` clause of the
- rendered ``UPDATE``/``INSERT`` statement. It will also
- determine how to generate the list/dictionary of bind
- parameters at execution time (i.e. ``get_params()``).
-
- This list takes into account the `values` keyword specified
- to the statement, the parameters sent to this Compiled
- instance, and the default bind parameter values corresponding
- to the dialect's behavior for otherwise unspecified primary
- key columns.
+ def _get_colparams(self, stmt, required_cols):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ This method may generate new bind params within this compiled
+ based on the given set of "required columns", which are required
+ to have a value set in the statement.
"""
+ def create_bind_param(col, value):
+ bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True)
+ self.binds[col.key] = bindparam
+ return self.bindparam_string(self._truncate_bindparam(bindparam))
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.parameters is None and stmt.parameters is None:
- return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns]
+ return [(c, create_bind_param(c, None)) for c in stmt.table.columns]
+
+ def create_clause_param(col, value):
+ self.traverse(value)
+ self.inline_params.add(col)
+ return self.strings[value]
+
+ self.inline_params = util.Set()
def to_col(key):
if not isinstance(key, sql._ColumnClause):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(to_col(k), v)
- for k, v in default_params.iteritems():
- parameters.setdefault(to_col(k), v)
+ for col in required_cols:
+ parameters.setdefault(col, None)
# create a list of column assignment clauses as tuples
values = []
for c in stmt.table.columns:
- if parameters.has_key(c):
+ if c in parameters:
value = parameters[c]
if sql._is_literal(value):
- value = sql.bindparam(c.key, value, type=c.type, unique=True)
+ value = create_bind_param(c, value)
+ else:
+ value = create_clause_param(c, value)
values.append((c, value))
-
+
return values
def visit_delete(self, delete_stmt):
for column in table.columns:
if column.default is not None:
self.traverse_single(column.default)
- #if column.onupdate is not None:
- # column.onupdate.accept_visitor(visitor)
self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import re, datetime, inspect, warnings, weakref
+import re, datetime, inspect, warnings, weakref, operator
from sqlalchemy import sql, schema, ansisql
from sqlalchemy.engine import default
class MySQLCompiler(ansisql.ANSICompiler):
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ sql.ColumnOperators.concat_op : lambda x, y:"concat(%s, %s)" % (x, y),
+ operator.mod : '%%'
+ }
+ )
+
def visit_cast(self, cast):
if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
return super(MySQLCompiler, self).visit_cast(cast)
text += " OFFSET " + str(select._offset)
return text
- def binary_operator_string(self, binary):
- if binary.operator == '%':
- return '%%'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, re, warnings
+import sys, StringIO, string, re, warnings, operator
from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
from sqlalchemy.engine import default, base
the use_ansi flag is False.
"""
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : lambda x, y:"mod(%s, %s)" % (x, y)
+ }
+ )
+
def default_from(self):
"""Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
self.traverse_single(self.wheres[join])
- def visit_insert_sequence(self, column, sequence, parameters):
- """This is the `sequence` equivalent to ``ANSICompiler``'s
- `visit_insert_column_default` which ensures that the column is
- present in the generated column list.
- """
-
- parameters.setdefault(column.key, None)
+ def uses_sequences_for_inserts(self):
+ return True
def visit_alias(self, alias):
"""Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
else:
return super(OracleCompiler, self).for_update_clause(select)
- def visit_binary(self, binary):
- if binary.operator == '%':
- self.strings[binary] = ("MOD(%s,%s)"%(self.strings[binary.left], self.strings[binary.right]))
- else:
- return ansisql.ANSICompiler.visit_binary(self, binary)
-
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import datetime, string, types, re, random, warnings
+import datetime, string, types, re, random, warnings, operator
from sqlalchemy import util, sql, schema, ansisql, exceptions
from sqlalchemy.engine import base, default
def get_col_spec(self):
return "BOOLEAN"
-class PGArray(sqltypes.TypeEngine):
+class PGArray(sqltypes.TypeEngine, sqltypes.Concatenable):
def __init__(self, item_type):
if isinstance(item_type, type):
item_type = item_type()
ORDER BY a.attnum
""" % schema_where_clause
- s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type=sqltypes.Unicode), sql.bindparam('schema', type=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
+ s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode})
c = connection.execute(s, table_name=table.name,
schema=table.schema)
rows = c.fetchall()
class PGCompiler(ansisql.ANSICompiler):
- def visit_insert_column(self, column, parameters):
- # all column primary key inserts must be explicitly present
- if column.primary_key:
- parameters[column.key] = None
+ operators = ansisql.ANSICompiler.operators.copy()
+ operators.update(
+ {
+ operator.mod : '%%'
+ }
+ )
- def visit_insert_sequence(self, column, sequence, parameters):
- """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures
- that the column is present in the generated column list"""
- parameters.setdefault(column.key, None)
+ def uses_sequences_for_inserts(self):
+ return True
def limit_clause(self, select):
text = ""
else:
return super(PGCompiler, self).for_update_clause(select)
- def binary_operator_string(self, binary):
- if isinstance(binary.type, (sqltypes.String, PGArray)) and binary.operator == '+':
- return '||'
- elif binary.operator == '%':
- return '%%'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
# sqlite has no "FOR UPDATE" AFAICT
return ''
- def binary_operator_string(self, binary):
- if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
- return '||'
- else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
statement.
"""
+ # TODO: this calculation of defaults is one of the places SA slows down inserts.
+ # look into optimizing this for a list of params where theres no defaults defined
+ # (i.e. analyze the first batch of params).
if self.compiled.isinsert:
if isinstance(self.compiled_parameters, list):
plist = self.compiled_parameters
self._lastrow_has_defaults = True
newid = drunner.get_column_default(c)
if newid is not None:
+ print "GOT GENERATED DEFAULT", c, repr(newid)
param.set_value(c.key, newid)
if c.primary_key:
last_inserted_ids.append(param.get_processed(c.key))
if len(secondary_delete):
secondary_delete.sort()
# TODO: precompile the delete/insert queries?
- statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type=c.type) for c in self.secondary.c if c.key in associationrow]))
+ statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
result = connection.execute(statement, secondary_delete)
if result.supports_sane_rowcount() and result.rowcount != len(secondary_delete):
raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (result.rowcount, len(secondary_delete)))
return operator(self.comparator, value)
-class PropComparator(sql.Comparator):
+class PropComparator(sql.ColumnOperators):
"""defines comparison operations for MapperProperty objects"""
def contains_op(a, b):
from sqlalchemy.orm.util import ExtensionCarrier
from sqlalchemy.orm import sync
from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, EXT_PASS, MapperExtension, SynonymProperty
-import weakref, warnings
+import weakref, warnings, operator
__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
_get_clause = sql.and_()
for primary_key in self.primary_key:
- _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
+ _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type_=primary_key.type, unique=True))
self._get_clause = _get_clause
def _get_equivalent_columns(self):
result = {}
def visit_binary(binary):
- if binary.operator == '=':
+ if binary.operator == operator.eq:
if binary.left in result:
result[binary.left].add(binary.right)
else:
mapper = table_to_mapper[table]
clause = sql.and_()
for col in mapper.pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label, type=col.type, unique=True))
+ clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
if mapper.version_id_col is not None:
- clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type=col.type, unique=True))
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
statement = table.update(clause)
rows = 0
supports_sane_rowcount = True
delete.sort(comparator)
clause = sql.and_()
for col in mapper.pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col.key, type=col.type, unique=True))
+ clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
if mapper.version_id_col is not None:
- clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type=mapper.version_id_col.type, unique=True))
+ clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
statement = table.delete(clause)
c = connection.execute(statement, delete)
if c.supports_sane_rowcount() and c.rowcount != len(delete):
if leftcol is None or rightcol is None:
return
if leftcol.table not in needs_tables:
- binary.left = sql.bindparam(leftcol.name, None, type=binary.right.type, unique=True)
+ binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True)
param_names.append(leftcol)
elif rightcol not in needs_tables:
- binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True)
+ binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True)
param_names.append(rightcol)
cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
return cond, param_names
if len(self.foreign_keys):
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
if binary.left in self.foreign_keys:
self._opposite_side.add(binary.right)
self.foreign_keys = util.Set()
self._opposite_side = util.Set()
def visit_binary(binary):
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
# this check is for when the user put the "view_only" flag on and has tables that have nothing
if should_bind(leftcol, rightcol):
col = leftcol
binary.left = binds.setdefault(leftcol,
- sql.bindparam(None, None, shortname=leftcol.name, type=binary.right.type, unique=True))
+ sql.bindparam(None, None, shortname=leftcol.name, type_=binary.right.type, unique=True))
reverse[rightcol] = binds[col]
# the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
if leftcol is not rightcol and should_bind(rightcol, leftcol):
col = rightcol
binary.right = binds.setdefault(rightcol,
- sql.bindparam(None, None, shortname=rightcol.name, type=binary.left.type, unique=True))
+ sql.bindparam(None, None, shortname=rightcol.name, type_=binary.left.type, unique=True))
reverse[leftcol] = binds[col]
lazywhere = primaryjoin
from sqlalchemy import sql, schema, exceptions
from sqlalchemy import logging
from sqlalchemy.orm import util as mapperutil
+import operator
ONETOMANY = 0
MANYTOONE = 1
def compile_binary(binary):
"""Assemble a SyncRule given a single binary condition."""
- if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+ if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
source_column = None
``TableClause``/``Table``.
"""
- def __init__(self, name, type, *args, **kwargs):
+ def __init__(self, name, type_, *args, **kwargs):
"""Construct a new ``Column`` object.
Arguments are:
The name of this column. This should be the identical name
as it appears, or will appear, in the database.
- type
+ type_
The ``TypeEngine`` for this column. This can be any
subclass of ``types.AbstractType``, including the
database-agnostic types defined in the types module,
identifier contains mixed case.
"""
- super(Column, self).__init__(name, None, type)
+ super(Column, self).__init__(name, None, type_)
self.args = args
self.key = kwargs.pop('key', name)
self._primary_key = kwargs.pop('primary_key', False)
'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join',
'Select', 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc',
- 'between_', 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
+ 'between', 'bindparam', 'case', 'cast', 'column', 'delete',
'desc', 'distinct', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
'insert', 'intersect', 'intersect_all', 'join', 'literal',
'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
'subquery', 'table', 'text', 'union', 'union_all', 'update',]
-# precedence ordering for common operators. if an operator is not present in this list,
-# it will be parenthesized when grouped against other operators
-PRECEDENCE = {
- 'FROM':15,
- '*':7,
- '/':7,
- '%':7,
- '+':6,
- '-':6,
- 'ILIKE':5,
- 'NOT ILIKE':5,
- 'LIKE':5,
- 'NOT LIKE':5,
- 'IN':5,
- 'NOT IN':5,
- 'IS':5,
- 'IS NOT':5,
- '=':5,
- '!=':5,
- '>':5,
- '<':5,
- '>=':5,
- '<=':5,
- 'BETWEEN':5,
- 'NOT':4,
- 'AND':3,
- 'OR':2,
- ',':-1,
- 'AS':-1,
- 'EXISTS':0,
- '_smallest': -1000,
- '_largest': 1000
-}
BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
def desc(column):
"""
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='AND', *clauses)
+ return ClauseList(operator=operator.and_, *clauses)
def or_(*clauses):
"""Join a list of clauses together using the ``OR`` operator.
if len(clauses) == 1:
return clauses[0]
- return ClauseList(operator='OR', *clauses)
+ return ClauseList(operator=operator.or_, *clauses)
def not_(clause):
"""Return a negation of the given clause, i.e. ``NOT(clause)``.
subclasses to produce the same result.
"""
- return clause._negate()
+ return operator.inv(clause)
def distinct(expr):
"""return a ``DISTINCT`` clause."""
provides similar functionality.
"""
- return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type=ctest.type), _literal_as_binds(cright, type=ctest.type), operator='AND', group=False), 'BETWEEN')
+ return _BinaryExpression(ctest, ClauseList(_literal_as_binds(cleft, type_=ctest.type), _literal_as_binds(cright, type_=ctest.type), operator=operator.and_, group=False), ColumnOperators.between_op)
-def between_(*args, **kwargs):
- """synonym for [sqlalchemy.sql#between()] (deprecated)."""
-
- return between(*args, **kwargs)
def case(whens, value=None, else_=None):
"""Produce a ``CASE`` statement.
type = list(whenlist[-1])[-1].type
else:
type = None
- cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END'])
+ cc = _CalculatedClause(None, 'CASE', value, type_=type, operator=None, group_contents=False, *whenlist + ['END'])
return cc
def cast(clause, totype, **kwargs):
def extract(field, expr):
"""Return the clause ``extract(field FROM expr)``."""
- expr = _BinaryExpression(text(field), expr, "FROM")
+ expr = _BinaryExpression(text(field), expr, Operators.from_)
return func.extract(expr)
def exists(*args, **kwargs):
return Alias(selectable, alias=alias)
-def literal(value, type=None):
+def literal(value, type_=None):
"""Return a literal clause, bound to a bind parameter.
Literal clauses are created automatically when non-
"""
- return _BindParamClause('literal', value, type=type, unique=True)
+ return _BindParamClause('literal', value, type_=type_, unique=True)
def label(name, obj):
"""Return a [sqlalchemy.sql#_Label] object for the given [sqlalchemy.sql#ColumnElement].
return _Label(name, obj)
-def column(text, type=None):
+def column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
"""
- return _ColumnClause(text, type=type)
+ return _ColumnClause(text, type_=type_)
-def literal_column(text, type=None):
+def literal_column(text, type_=None):
"""Return a textual column clause, as would be in the columns
clause of a ``SELECT`` statement.
"""
- return _ColumnClause(text, type=type, is_literal=True)
+ return _ColumnClause(text, type_=type_, is_literal=True)
def table(name, *columns):
"""Return a [sqlalchemy.sql#Table] object.
return TableClause(name, *columns)
-def bindparam(key, value=None, type=None, shortname=None, unique=False):
+def bindparam(key, value=None, type_=None, shortname=None, unique=False):
"""Create a bind parameter clause with the given key.
value
"""
if isinstance(key, _ColumnClause):
- return _BindParamClause(key.name, value, type=key.type, shortname=shortname, unique=unique)
+ return _BindParamClause(key.name, value, type_=key.type, shortname=shortname, unique=unique)
else:
- return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique)
+ return _BindParamClause(key, value, type_=type_, shortname=shortname, unique=unique)
def text(text, bind=None, *args, **kwargs):
"""Create literal text to be inserted into a query.
return not isinstance(element, ClauseElement)
def _literal_as_text(element):
- if isinstance(element, Comparator):
+ if isinstance(element, Operators):
return element.clause_element()
elif _is_literal(element):
return _TextClause(unicode(element))
else:
return element
-def _literal_as_binds(element, name='literal', type=None):
+def _literal_as_binds(element, name='literal', type_=None):
if _is_literal(element):
if element is None:
return null()
else:
- return _BindParamClause(name, element, shortname=name, type=type, unique=True)
+ return _BindParamClause(name, element, shortname=name, type_=type_, unique=True)
else:
return element
if hasattr(self, 'negation_clause'):
return self.negation_clause
else:
- return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
+ return _UnaryExpression(self.self_group(against=operator.inv), operator=operator.inv, negate=None)
+
+
+class Operators(object):
+ def from_():
+ raise NotImplementedError()
+ from_ = staticmethod(from_)
+
+ def as_():
+ raise NotImplementedError()
+ as_ = staticmethod(as_)
+
+ def exists():
+ raise NotImplementedError()
+ exists = staticmethod(exists)
+
+ def is_():
+ raise NotImplementedError()
+ is_ = staticmethod(is_)
+
+ def isnot():
+ raise NotImplementedError()
+ isnot = staticmethod(isnot)
+
+ def __and__(self, other):
+ return self.operate(operator.and_, other)
+ def __or__(self, other):
+ return self.operate(operator.or_, other)
+
+ def __invert__(self):
+ return self.operate(operator.inv)
+
+ def clause_element(self):
+ raise NotImplementedError()
+
+ def operate(self, op, *other):
+ raise NotImplementedError()
+
+ def reverse_operate(self, op, *other):
+ raise NotImplementedError()
-class Comparator(object):
+class ColumnOperators(Operators):
"""defines comparison and math operations"""
def like_op(a, b):
return a.like(b)
like_op = staticmethod(like_op)
+ def notlike_op(a, b):
+ raise NotImplementedError()
+ notlike_op = staticmethod(notlike_op)
+
+ def ilike_op(a, b):
+ return a.ilike(b)
+ ilike_op = staticmethod(ilike_op)
+
+ def notilike_op(a, b):
+ raise NotImplementedError()
+ notilike_op = staticmethod(notilike_op)
+
def between_op(a, b):
return a.between(b)
between_op = staticmethod(between_op)
def in_op(a, b):
return a.in_(*b)
in_op = staticmethod(in_op)
+
+ def notin_op(a, b):
+ raise NotImplementedError()
+ notin_op = staticmethod(notin_op)
def startswith_op(a, b):
return a.startswith(b)
def endswith_op(a, b):
return a.endswith(b)
endswith_op = staticmethod(endswith_op)
-
- def clause_element(self):
- raise NotImplementedError()
-
- def operate(self, op, other):
- raise NotImplementedError()
- def reverse_operate(self, op, other):
+ def comma_op(a, b):
raise NotImplementedError()
+ comma_op = staticmethod(comma_op)
+
+ def concat_op(a, b):
+ return a.concat(b)
+ concat_op = staticmethod(concat_op)
def __lt__(self, other):
return self.operate(operator.lt, other)
def __ge__(self, other):
return self.operate(operator.ge, other)
+ def concat(self, other):
+ return self.operate(ColumnOperators.concat_op, other)
+
def like(self, other):
- return self.operate(Comparator.like_op, other)
-
+ return self.operate(ColumnOperators.like_op, other)
+
def in_(self, *other):
- return self.operate(Comparator.in_op, other)
-
+ return self.operate(ColumnOperators.in_op, other)
+
def startswith(self, other):
- return self.operate(Comparator.startswith_op, other)
+ return self.operate(ColumnOperators.startswith_op, other)
def endswith(self, other):
- return self.operate(Comparator.endswith_op, other)
+ return self.operate(ColumnOperators.endswith_op, other)
def __radd__(self, other):
return self.reverse_operate(operator.add, other)
return self.reverse_operate(operator.div, other)
def between(self, cleft, cright):
- return self.operate(Comparator.between_op, (cleft, cright))
+ return self.operate(Operators.between_op, (cleft, cright))
def __add__(self, other):
return self.operate(operator.add, other)
def __truediv__(self, other):
return self.operate(operator.truediv, other)
-class _CompareMixin(Comparator):
+# precedence ordering for common operators. if an operator is not present in this list,
+# it will be parenthesized when grouped against other operators
+_smallest = object()
+_largest = object()
+
+PRECEDENCE = {
+ Operators.from_:15,
+ operator.mul:7,
+ operator.div:7,
+ operator.mod:7,
+ operator.add:6,
+ operator.sub:6,
+ ColumnOperators.concat_op:6,
+ ColumnOperators.ilike_op:5,
+ ColumnOperators.notilike_op:5,
+ ColumnOperators.like_op:5,
+ ColumnOperators.notlike_op:5,
+ ColumnOperators.in_op:5,
+ ColumnOperators.notin_op:5,
+ Operators.is_:5,
+ Operators.isnot:5,
+ operator.eq:5,
+ operator.ne:5,
+ operator.gt:5,
+ operator.lt:5,
+ operator.ge:5,
+ operator.le:5,
+ ColumnOperators.between_op:5,
+ operator.inv:4,
+ operator.and_:3,
+ operator.or_:2,
+ ColumnOperators.comma_op:-1,
+ Operators.as_:-1,
+ Operators.exists:0,
+ _smallest: -1000,
+ _largest: 1000
+}
+
+class _CompareMixin(ColumnOperators):
"""Defines comparison and math operations for ``ClauseElement`` instances."""
- def __compare(self, operator, obj, negate=None):
+ def __compare(self, op, obj, negate=None):
if obj is None or isinstance(obj, _Null):
- if operator == '=':
- return _BinaryExpression(self.clause_element(), null(), 'IS', negate='IS NOT')
- elif operator == '!=':
- return _BinaryExpression(self.clause_element(), null(), 'IS NOT', negate='IS')
+ if op == operator.eq:
+ return _BinaryExpression(self.clause_element(), null(), Operators.is_, negate=Operators.isnot)
+ elif op == operator.ne:
+ return _BinaryExpression(self.clause_element(), null(), Operators.isnot, negate=Operators.is_)
else:
raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
else:
obj = self._check_literal(obj)
- return _BinaryExpression(self.clause_element(), obj, operator, type=sqltypes.Boolean, negate=negate)
+
+ return _BinaryExpression(self.clause_element(), obj, op, type_=sqltypes.Boolean, negate=negate)
- def __operate(self, operator, obj):
+ def __operate(self, op, obj):
obj = self._check_literal(obj)
- return _BinaryExpression(self.clause_element(), obj, operator, type=self._compare_type(obj))
+
+ type_ = self._compare_type(obj)
+ if op == operator.add and isinstance(type_, (sqltypes.Concatenable)):
+ op = ColumnOperators.concat_op
+
+ return _BinaryExpression(self.clause_element(), obj, op, type_=type_)
operators = {
- operator.add : (__operate, '+'),
- operator.mul : (__operate, '*'),
- operator.sub : (__operate, '-'),
- operator.div : (__operate, '/'),
- operator.mod : (__operate, '%'),
- operator.truediv : (__operate, '/'),
- operator.lt : (__compare, '<', '=>'),
- operator.le : (__compare, '<=', '>'),
- operator.ne : (__compare, '!=', '='),
- operator.gt : (__compare, '>', '<='),
- operator.ge : (__compare, '>=', '<'),
- operator.eq : (__compare, '=', '!='),
- Comparator.like_op : (__compare, 'LIKE', 'NOT LIKE'),
+ operator.add : (__operate,),
+ operator.mul : (__operate,),
+ operator.sub : (__operate,),
+ operator.div : (__operate,),
+ operator.mod : (__operate,),
+ operator.truediv : (__operate,),
+ operator.lt : (__compare, operator.ge),
+ operator.le : (__compare, operator.gt),
+ operator.ne : (__compare, operator.eq),
+ operator.gt : (__compare, operator.le),
+ operator.ge : (__compare, operator.lt),
+ operator.eq : (__compare, operator.ne),
+ ColumnOperators.like_op : (__compare, ColumnOperators.notlike_op),
}
def operate(self, op, other):
o = _CompareMixin.operators[op]
- return o[0](self, o[1], other, *o[2:])
+ return o[0](self, op, other, *o[1:])
def reverse_operate(self, op, other):
return self._bind_param(other).operate(op, self)
def in_(self, *other):
- """produce an ``IN`` clause."""
+ return self._in_impl(ColumnOperators.in_op, ColumnOperators.notin_op, *other)
+
+ def _in_impl(self, op, negate_op, *other):
if len(other) == 0:
return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1')))
elif len(other) == 1:
return self.__eq__( o) #single item -> ==
else:
assert hasattr( o, '_selectable') #better check?
- return self.__compare( 'IN', o, negate='NOT IN') #single selectable
+ return self.__compare( op, o, negate=negate_op) #single selectable
args = []
for o in other:
else:
o = self._bind_param(o)
args.append(o)
- return self.__compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
+ return self.__compare(op, ClauseList(*args).self_group(against=op), negate=negate_op)
def startswith(self, other):
"""produce the clause ``LIKE '<other>%'``"""
- perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String)
+
+ perc = isinstance(other,(str,unicode)) and '%' or literal('%',type_= sqltypes.String)
return self.__compare('LIKE', other + perc)
def endswith(self, other):
"""produce the clause ``LIKE '%<other>'``"""
+
if isinstance(other,(str,unicode)): po = '%' + other
else:
- po = literal('%', type= sqltypes.String) + other
- po.type = sqltypes.to_instance( sqltypes.String) #force!
+ po = literal('%', type_=sqltypes.String) + other
+ po.type = sqltypes.to_instance(sqltypes.String) #force!
return self.__compare('LIKE', po)
def label(self, name):
def between(self, cleft, cright):
"""produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
- return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator='AND', group=False), 'BETWEEN')
+ return _BinaryExpression(self, ClauseList(self._check_literal(cleft), self._check_literal(cright), operator=operator.and_, group=False), 'BETWEEN')
def op(self, operator):
"""produce a generic operator function.
return lambda other: self.__operate(operator, other)
def _bind_param(self, obj):
- return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
+ return _BindParamClause('literal', obj, shortname=None, type_=self.type, unique=True)
def _check_literal(self, other):
- if isinstance(other, Comparator):
+ if isinstance(other, Operators):
return other.clause_element()
elif _is_literal(other):
return self._bind_param(other)
__visit_name__ = 'bindparam'
- def __init__(self, key, value, shortname=None, type=None, unique=False):
+ def __init__(self, key, value, shortname=None, type_=None, unique=False):
"""Construct a _BindParamClause.
key
execution may match either the key or the shortname of the
corresponding ``_BindParamClause`` objects.
- type
+ type_
A ``TypeEngine`` object that will be used to pre-process the
value corresponding to this ``_BindParamClause`` at
execution time.
self.value = value
self.shortname = shortname or key
self.unique = unique
- self.type = sqltypes.to_instance(type)
-
+ type_ = sqltypes.to_instance(type_)
+ if isinstance(type_, sqltypes.NullType) and type(value) in _BindParamClause.type_map:
+ self.type = sqltypes.to_instance(_BindParamClause.type_map[type(value)])
+ else:
+ self.type = type_
+
+ # TODO: move to types module, obviously
+ type_map = {
+ str : sqltypes.String,
+ unicode : sqltypes.Unicode,
+ int : sqltypes.Integer,
+ float : sqltypes.Numeric
+ }
+
def _get_from_objects(self, **modifiers):
return []
return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
def __repr__(self):
- return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type))
+ return "_BindParamClause(%s, %s, type_=%s)" % (repr(self.key), repr(self.value), repr(self.type))
class _TypeClause(ClauseElement):
"""Handle a type keyword in a SQL statement.
def __init__(self, *clauses, **kwargs):
self.clauses = []
- self.operator = kwargs.pop('operator', ',')
+ self.operator = kwargs.pop('operator', ColumnOperators.comma_op)
self.group = kwargs.pop('group', True)
self.group_contents = kwargs.pop('group_contents', True)
- self.negate_operator = kwargs.pop('negate', None)
for c in clauses:
if c is None:
continue
def _copy_internals(self):
self.clauses = [clause._clone() for clause in self.clauses]
- def _negate(self):
- if hasattr(self, 'negation_clause'):
- return self.negation_clause
- elif self.negate_operator is None:
- return super(ClauseList, self)._negate()
- else:
- return ClauseList(operator=self.negate_operator, negate=self.operator, *(not_(c) for c in self.clauses))
-
def get_children(self, **kwargs):
return self.clauses
return f
def self_group(self, against=None):
- if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.group and self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
def __init__(self, name, *clauses, **kwargs):
self.name = name
- self.type = sqltypes.to_instance(kwargs.get('type', None))
+ self.type = sqltypes.to_instance(kwargs.get('type_', None))
self._bind = kwargs.get('bind', None)
self.group = kwargs.pop('group', True)
self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
return self.clauses._get_from_objects(**modifiers)
def _bind_param(self, obj):
- return _BindParamClause(self.name, obj, type=self.type, unique=True)
+ return _BindParamClause(self.name, obj, type_=self.type, unique=True)
def select(self):
return select([self])
"""
def __init__(self, name, *clauses, **kwargs):
- self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
- kwargs['operator'] = ','
- self._bind = kwargs.get('bind', None)
+ kwargs['operator'] = ColumnOperators.comma_op
_CalculatedClause.__init__(self, name, **kwargs)
for c in clauses:
self.append(c)
def _make_proxy(self, selectable, name=None):
if name is not None:
- co = _ColumnClause(name, selectable, type=self.type)
+ co = _ColumnClause(name, selectable, type_=self.type)
co._distance = self._distance + 1
co.orig_set = self.orig_set
selectable.columns[name]= co
class _UnaryExpression(ColumnElement):
- def __init__(self, element, operator=None, modifier=None, type=None, negate=None):
+ def __init__(self, element, operator=None, modifier=None, type_=None, negate=None):
self.operator = operator
self.modifier = modifier
self.element = _literal_as_text(element).self_group(against=self.operator or self.modifier)
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
def _get_from_objects(self, **modifiers):
def _negate(self):
if self.negate is not None:
- return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
+ return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type_=self.type)
else:
return super(_UnaryExpression, self)._negate()
def self_group(self, against):
- if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+ if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest]):
return _Grouping(self)
else:
return self
class _BinaryExpression(ColumnElement):
"""Represent an expression that is ``LEFT <operator> RIGHT``."""
- def __init__(self, left, right, operator, type=None, negate=None):
+ def __init__(self, left, right, operator, type_=None, negate=None):
self.left = _literal_as_text(left).self_group(against=operator)
self.right = _literal_as_text(right).self_group(against=operator)
self.operator = operator
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self.negate = negate
def _get_from_objects(self, **modifiers):
(
self.left.compare(other.left) and self.right.compare(other.right)
or (
- self.operator in ['=', '!=', '+', '*'] and
+ self.operator in [operator.eq, operator.ne, operator.add, operator.mul] and
self.left.compare(other.right) and self.right.compare(other.left)
)
)
def self_group(self, against=None):
# use small/large defaults for comparison so that unknown operators are always parenthesized
- if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])):
+ if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE[_smallest]) <= PRECEDENCE.get(against, PRECEDENCE[_largest])):
return _Grouping(self)
else:
return self
def _negate(self):
if self.negate is not None:
- return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type)
+ return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type_=self.type)
else:
return super(_BinaryExpression, self)._negate()
def __init__(self, *args, **kwargs):
kwargs['correlate'] = True
s = select(*args, **kwargs).self_group()
- _UnaryExpression.__init__(self, s, operator="EXISTS")
+ _UnaryExpression.__init__(self, s, operator=Operators.exists)
def _hide_froms(self, **modifiers):
return self._get_from_objects(**modifiers)
class BinaryVisitor(ClauseVisitor):
def visit_binary(self, binary):
- if binary.operator == '=':
+ if binary.operator == operator.eq:
add_equiv(binary.left, binary.right)
BinaryVisitor().traverse(self.onclause)
equivs = util.Set()
class LocateEquivs(NoColumnVisitor):
def visit_binary(self, binary):
- if binary.operator == '=' and binary.left.name == binary.right.name:
+ if binary.operator == operator.eq and binary.left.name == binary.right.name:
equivs.add(binary.right)
equivs.add(binary.left)
LocateEquivs().traverse(self.onclause)
"""
- def __init__(self, name, obj, type=None):
+ def __init__(self, name, obj, type_=None):
while isinstance(obj, _Label):
obj = obj.obj
self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
- self.obj = obj.self_group(against='AS')
+ self.obj = obj.self_group(against=Operators.as_)
self.case_sensitive = getattr(obj, "case_sensitive", True)
- self.type = sqltypes.to_instance(type or getattr(obj, 'type', None))
+ self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
key = property(lambda s: s.name)
_label = property(lambda s: s.name)
"""
- def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
+ def __init__(self, text, selectable=None, type_=None, _is_oid=False, case_sensitive=True, is_literal=False):
self.key = self.name = text
self.encodedname = isinstance(self.name, unicode) and self.name.encode('ascii', 'backslashreplace') or self.name
self.table = selectable
- self.type = sqltypes.to_instance(type)
+ self.type = sqltypes.to_instance(type_)
self._is_oid = _is_oid
self._distance = 0
self.__label = None
return []
def _bind_param(self, obj):
- return _BindParamClause(self._label, obj, shortname = self.name, type=self.type, unique=True)
+ return _BindParamClause(self._label, obj, shortname=self.name, type_=self.type, unique=True)
def _make_proxy(self, selectable, name = None):
# propigate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
- c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type=self.type, is_literal=is_literal)
+ c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
c.orig_set = self.orig_set
c._distance = self._distance + 1
if not self._is_oid:
column = literal_column(str(column))
if isinstance(column, Select) and column.is_scalar:
- column = column.self_group(against=',')
+ column = column.self_group(against=ColumnOperators.comma_op)
self._raw_columns.append(column)
for key in parameters.keys():
value = parameters[key]
if isinstance(value, ClauseElement):
- pass
+ parameters[key] = value.self_group()
elif _is_literal(value):
if _is_literal(key):
col = self.table.c[key]
return value
NullTypeEngine = NullType
-class String(TypeEngine):
+class Concatenable(object):
+ """marks a type as supporting 'concatenation'"""
+ pass
+
+class String(TypeEngine, Concatenable):
def __init__(self, length=None, convert_unicode=False):
self.length = length
self.convert_unicode = convert_unicode
assert False
except exceptions.InvalidRequestError, e:
assert str(e) == "This Compiled object is not bound to any Engine or Connection."
-
+
finally:
- bind.close()
+ if isinstance(bind, engine.Connection):
+ bind.close()
metadata.drop_all(bind=testbase.db)
def test_session(self):
mapper(Foo, table)
metadata.create_all(bind=testbase.db)
try:
- for bind in (testbase.db, testbase.db.connect()):
+ for bind in (testbase.db,
+ testbase.db.connect()
+ ):
for args in ({'bind':bind},):
sess = create_session(**args)
assert sess.bind is bind
sess.save(f)
sess.flush()
assert sess.get(Foo, f.foo) is f
+
+ if isinstance(bind, engine.Connection):
+ bind.close()
sess = create_session()
f = Foo()
assert False
except exceptions.InvalidRequestError, e:
assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
+
finally:
- bind.close()
+ if isinstance(bind, engine.Connection):
+ bind.close()
metadata.drop_all(bind=testbase.db)
except TypeError:
assert True
- e = create_engine('sqlite://', echo=True)
e = create_engine('mysql://', module=MockDBAPI(), connect_args={'use_unicode':True}, convert_unicode=True)
e = create_engine('sqlite://', connect_args={'use_unicode':True}, convert_unicode=True)
def define_tables(self, metadata):
global table_Employee, table_Engineer, table_Manager
table_Employee = Table( 'Employee', metadata,
- Column( 'name', type= String(100), ),
- Column( 'id', primary_key= True, type= Integer, ),
- Column( 'atype', type= String(100), ),
+ Column( 'name', type_= String(100), ),
+ Column( 'id', primary_key= True, type_= Integer, ),
+ Column( 'atype', type_= String(100), ),
)
table_Engineer = Table( 'Engineer', metadata,
- Column( 'machine', type= String(100), ),
+ Column( 'machine', type_= String(100), ),
Column( 'id', Integer, ForeignKey( 'Employee.id', ), primary_key= True, ),
)
table_Manager = Table( 'Manager', metadata,
- Column( 'duties', type= String(100), ),
+ Column( 'duties', type_= String(100), ),
Column( 'id', Integer, ForeignKey( 'Engineer.id', ), primary_key= True, ),
)
def test_threelevels(self):
(operator.sub, '-'), (operator.div, '/'),
):
for (lhs, rhs, res) in (
- ('a', User.id, ':users_id %s users.id'),
- ('a', literal('b'), ':literal %s :literal_1'),
- (User.id, 'b', 'users.id %s :users_id'),
+ (5, User.id, ':users_id %s users.id'),
+ (5, literal(6), ':literal %s :literal_1'),
+ (User.id, 5, 'users.id %s :users_id'),
(User.id, literal('b'), 'users.id %s :literal'),
(User.id, User.id, 'users.id %s users.id'),
- (literal('a'), 'b', ':literal %s :literal_1'),
- (literal('a'), User.id, ':literal %s users.id'),
- (literal('a'), literal('b'), ':literal %s :literal_1'),
+ (literal(5), 'b', ':literal %s :literal_1'),
+ (literal(5), User.id, ':literal %s users.id'),
+ (literal(5), literal(6), ':literal %s :literal_1'),
):
self._test(py_op(lhs, rhs), res % sql_op)
l = q.add_column("count").from_statement(s).all()
assert l == expected
- @testbase.unsupported('mysql') # only because of "+" operator requiring "concat" in mysql (fix #475)
def test_two_columns(self):
sess = create_session()
(user7, user8, user9, user10) = sess.query(User).all()
def testcase(self):
inner = select([case([
[info_table.c.pk < 3,
- literal('lessthan3', type=String)],
+ literal('lessthan3', type_=String)],
[and_(info_table.c.pk >= 3, info_table.c.pk < 7),
- literal('gt3', type=String)]]).label('x'),
+ literal('gt3', type_=String)]]).label('x'),
info_table.c.pk, info_table.c.info],
from_obj=[info_table]).alias('q_inner')
w_else = select([case([
[info_table.c.pk < 3,
- literal(3, type=Integer)],
+ literal(3, type_=Integer)],
[and_(info_table.c.pk >= 3, info_table.c.pk < 6),
- literal(6, type=Integer)]],
+ literal(6, type_=Integer)]],
else_ = 0).label('x'),
info_table.c.pk, info_table.c.info],
from_obj=[info_table]).alias('q_inner')
# select "count(1)" returns different results on different DBs
# also correct for "current_date" compatible as column default, value differences
- currenttime = func.current_date(type=Date, bind=db);
+ currenttime = func.current_date(type_=Date, bind=db);
if is_oracle:
ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar()
f = select([func.count(1) + 5], bind=db).scalar()
)
sometable = Table( 'Manager', metadata,
Column( 'obj_id', Integer, Sequence('obj_id_seq'), ),
- Column( 'name', type= String, ),
+ Column( 'name', String, ),
Column( 'id', Integer, primary_key= True, ),
)
y = testbase.db.func.current_date().select().execute().scalar()
z = testbase.db.func.current_date().scalar()
assert x == y == z
+
+ x = testbase.db.func.current_date(type_=Date)
+ assert isinstance(x.type, Date)
+ assert isinstance(x.execute().scalar(), datetime.date)
def test_conn_functions(self):
conn = testbase.db.connect()
w = select(['*'], from_obj=[testbase.db.func.current_date()]).scalar()
# construct a column-based FROM object out of a function, like in [ticket:172]
- s = select([column('date', type=DateTime)], from_obj=[testbase.db.func.current_date()])
+ s = select([column('date', type_=DateTime)], from_obj=[testbase.db.func.current_date()])
q = s.execute().fetchone()[s.c.date]
r = s.alias('datequery').select().scalar()
# so SQLAlchemy's SQL construction engine can be used with no database dependencies at all.
table1 = table('mytable',
- column('myid'),
- column('name'),
- column('description'),
+ column('myid', Integer),
+ column('name', String),
+ column('description', String),
)
table2 = table(
'myothertable',
- column('otherid'),
- column('othername'),
+ column('otherid', Integer),
+ column('othername', String),
)
table3 = table(
'thirdtable',
- column('userid'),
- column('otherstuff'),
+ column('userid', Integer),
+ column('otherstuff', String),
)
metadata = MetaData()
(operator.sub, '-'), (operator.div, '/'),
):
for (lhs, rhs, res) in (
- ('a', table1.c.myid, ':mytable_myid %s mytable.myid'),
- ('a', literal('b'), ':literal %s :literal_1'),
+ (5, table1.c.myid, ':mytable_myid %s mytable.myid'),
+ (5, literal(5), ':literal %s :literal_1'),
(table1.c.myid, 'b', 'mytable.myid %s :mytable_myid'),
- (table1.c.myid, literal('b'), 'mytable.myid %s :literal'),
+ (table1.c.myid, literal(2.7), 'mytable.myid %s :literal'),
(table1.c.myid, table1.c.myid, 'mytable.myid %s mytable.myid'),
- (literal('a'), 'b', ':literal %s :literal_1'),
- (literal('a'), table1.c.myid, ':literal %s mytable.myid'),
- (literal('a'), literal('b'), ':literal %s :literal_1'),
+ (literal(5), 8, ':literal %s :literal_1'),
+ (literal(6), table1.c.myid, ':literal %s mytable.myid'),
+ (literal(7), literal(5.5), ':literal %s :literal_1'),
):
self.runtest(py_op(lhs, rhs), res % sql_op)
)
self.runtest(
- literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
+ literal("a") + literal("b") * literal("c"), ":literal || :literal_1 * :literal_2"
)
# test the op() function, also that its results are further usable in expressions
def testliteral(self):
self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]),
- "SELECT :literal + :literal_1 FROM mytable")
+ "SELECT :literal || :literal_1 FROM mytable")
def testcalculatedcolumns(self):
value_tbl = table('values',
self.runtest(select([table1], table1.c.myid.in_('a', literal('b'))),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)")
- self.runtest(select([table1], table1.c.myid.in_(literal('a') + 'a')),
+ self.runtest(select([table1], table1.c.myid.in_(literal(1) + 'a')),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1")
self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :mytable_myid)")
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :mytable_myid)")
self.runtest(select([table1], table1.c.myid.in_(literal('a') + literal('a'), literal('b'))),
- "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :literal_2)")
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal || :literal_1, :literal_2)")
- self.runtest(select([table1], table1.c.myid.in_('a', literal('b') +'b')),
+ self.runtest(select([table1], table1.c.myid.in_(1, literal(3) + 4)),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal + :literal_1)")
self.runtest(select([table1], table1.c.myid.in_(literal('a') < 'b')),
self.runtest(select([table1], table1.c.myid.in_(literal('a'), table1.c.myid +'a')),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, mytable.myid + :mytable_myid)")
- self.runtest(select([table1], table1.c.myid.in_(literal('a'), 'a' + table1.c.myid)),
+ self.runtest(select([table1], table1.c.myid.in_(literal(1), 'a' + table1.c.myid)),
"SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal, :mytable_myid + mytable.myid)")
self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)),
values = {
table1.c.name : table1.c.name + "lala",
table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho'))
- }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1")
+ }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=(mytable.name || :mytable_name) WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal || mytable.name || :literal_1")
def testcorrelatedupdate(self):
# test against a straight text subquery
- u = update(table1, values = {table1.c.name : text("select name from mytable where id=mytable.id")})
+ u = update(table1, values = {table1.c.name : text("(select name from mytable where id=mytable.id)")})
self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)")
+
+ mt = table1.alias()
+ u = update(table1, values = {table1.c.name : select([mt.c.name], mt.c.myid==table1.c.myid)})
+ self.runtest(u, "UPDATE mytable SET name=(SELECT mytable_1.name FROM mytable AS mytable_1 WHERE mytable_1.myid = mytable.myid)")
# test against a regular constructed subquery
s = select([table2], table2.c.otherid == table1.c.myid)