def schemadropper(self, proxy, **params):
return ANSISchemaDropper(proxy, **params)
- def compiler(self, statement, bindparams, **kwargs):
- return ANSICompiler(self, statement, bindparams, **kwargs)
+ def compiler(self, statement, parameters, **kwargs):
+ return ANSICompiler(self, statement, parameters, **kwargs)
def connect_args(self):
return ([],{})
return None
class ANSICompiler(sql.Compiled):
- def __init__(self, engine, statement, bindparams, typemap=None, paramstyle=None,**kwargs):
- sql.Compiled.__init__(self, engine, statement, bindparams)
+ """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
+ def __init__(self, engine, statement, parameters=None, typemap=None, **kwargs):
+ """constructs a new ANSICompiler object.
+
+ engine - SQLEngine to compile against
+
+ statement - ClauseElement to be compiled
+
+ parameters - optional dictionary indicating a set of bind parameters
+ specified with this Compiled object. These parameters are the "default"
+ key/value pairs when the Compiled is executed, and also may affect the
+ actual compilation, as in the case of an INSERT where the actual columns
+ inserted will correspond to the keys present in the parameters."""
+ sql.Compiled.__init__(self, engine, statement, parameters)
self.binds = {}
self.froms = {}
self.wheres = {}
self.typemap = typemap or {}
self.isinsert = False
- if paramstyle is None:
- db = self.engine.dbapi()
- if db is not None:
- paramstyle = db.paramstyle
- else:
- paramstyle = 'named'
-
- if paramstyle == 'named':
- self.bindtemplate = ':%s'
- self.positional=False
- elif paramstyle =='pyformat':
- self.bindtemplate = "%%(%s)s"
- self.positional=False
- else:
- # for positional, use pyformat until the end
- self.bindtemplate = "%%(%s)s"
- self.positional=True
- self.paramstyle=paramstyle
-
def after_compile(self):
- if self.positional:
+ if self.engine.positional:
self.positiontup = []
match = r'%\(([\w_]+)\)s'
params = re.finditer(match, self.strings[self.statement])
for p in params:
self.positiontup.append(p.group(1))
- if self.paramstyle=='qmark':
+ if self.engine.paramstyle=='qmark':
self.strings[self.statement] = re.sub(match, '?', self.strings[self.statement])
- elif self.paramstyle=='format':
+ elif self.engine.paramstyle=='format':
self.strings[self.statement] = re.sub(match, '%s', self.strings[self.statement])
- elif self.paramstyle=='numeric':
+ elif self.engine.paramstyle=='numeric':
i = 0
def getnum(x):
i += 1
for an executemany style of call, this method should be called for each element
in the list of parameter groups that will ultimately be executed.
"""
- d = {}
- if self.bindparams is not None:
- bindparams = self.bindparams.copy()
+ if self.parameters is not None:
+ bindparams = self.parameters.copy()
else:
bindparams = {}
bindparams.update(params)
- # TODO: cant we make "d" an ordereddict and add params in
- # positional order
+
+ if self.engine.positional:
+ d = OrderedDict()
+ for k in self.positiontup:
+ b = self.binds[k]
+ d[k] = b.typeprocess(b.value)
+ else:
+ d = {}
+ for b in self.binds.values():
+ d[b.key] = b.typeprocess(b.value)
+
for key, value in bindparams.iteritems():
try:
b = self.binds[key]
continue
d[b.key] = b.typeprocess(value)
- for b in self.binds.values():
- d.setdefault(b.key, b.typeprocess(b.value))
-
- if self.positional:
- return [d[key] for key in self.positiontup]
+ return d
+ if self.engine.positional:
+ return d.values()
else:
return d
same dictionary. For a positional paramstyle, the given parameters are
assumed to be in list format and are converted back to a dictionary.
"""
- if self.positional:
+# return parameters
+ if self.engine.positional:
p = {}
for i in range(0, len(self.positiontup)):
p[self.positiontup[i]] = parameters[i]
self.strings[bindparam] = self.bindparam_string(key)
def bindparam_string(self, name):
- return self.bindtemplate % name
+ return self.engine.bindtemplate % name
def visit_alias(self, alias):
self.froms[alias] = self.get_from_text(alias.selectable) + " AS " + alias.name
text = "SELECT "
if select.distinct:
text += "DISTINCT "
- text += collist + " \nFROM "
+ text += collist
whereclause = select.whereclause
t = self.get_from_text(f)
if t is not None:
froms.append(t)
-
- text += string.join(froms, ', ')
+
+ if len(froms):
+ text += " \nFROM "
+ text += string.join(froms, ', ')
if whereclause is not None:
t = self.get_str(whereclause)
self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
" ON " + self.get_str(join.onclause))
self.strings[join] = self.froms[join]
+
+ def visit_insert_column_default(self, column, default):
+ """called when visiting an Insert statement, for each column in the table that
+ contains a ColumnDefault object."""
+ self.parameters.setdefault(column.key, None)
+
+ def visit_insert_sequence(self, column, sequence):
+ """called when visiting an Insert statement, for each column in the table that
+ contains a Sequence object."""
+ pass
def visit_insert(self, insert_stmt):
+ # set up a call for the defaults and sequences inside the table
+ class DefaultVisitor(schema.SchemaVisitor):
+ def visit_column_default(s, cd):
+ self.visit_insert_column_default(c, cd)
+ def visit_sequence(s, seq):
+ self.visit_insert_sequence(c, seq)
+ vis = DefaultVisitor()
+ for c in insert_stmt.table.c:
+ if self.parameters.get(c.key, None) is None and c.default is not None:
+ c.default.accept_visitor(vis)
+
self.isinsert = True
- colparams = insert_stmt.get_colparams(self.bindparams)
+ colparams = insert_stmt.get_colparams(self.parameters)
for c in colparams:
b = c[1]
self.binds[b.key] = b
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
- colparams = update_stmt.get_colparams(self.bindparams)
+ colparams = update_stmt.get_colparams(self.parameters)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self.binds[p.key] = p
def last_inserted_ids(self):
return self.context.last_inserted_ids
- def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
- if compiled is None: return
+ def post_exec(self, proxy, compiled, parameters, **kwargs):
if getattr(compiled, "isinsert", False):
self.context.last_inserted_ids = [proxy().lastrowid]
def last_inserted_ids(self):
return self.context.last_inserted_ids
- def pre_exec(self, proxy, statement, parameters, compiled=None, **kwargs):
- if compiled is None: return
+ def pre_exec(self, proxy, compiled, parameters, **kwargs):
# this is just an assertion that all the primary key columns in an insert statement
# have a value set up, or have a default generator ready to go
if getattr(compiled, "isinsert", False):
return PGSchemaDropper(proxy, **params)
def defaultrunner(self, proxy):
- return PGDefaultRunner(proxy)
+ return PGDefaultRunner(self, proxy)
def get_default_schema_name(self):
if not hasattr(self, '_default_schema_name'):
def pre_exec(self, proxy, statement, parameters, **kwargs):
return
- def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
- if compiled is None: return
+ def post_exec(self, proxy, compiled, parameters, **kwargs):
if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None:
table = compiled.statement.table
cursor = proxy()
ischema.reflecttable(self, table, ischema_names)
class PGCompiler(ansisql.ANSICompiler):
- def visit_insert(self, insert):
- """inserts are required to have the primary keys be explicitly present.
- mapper will by default not put them in the insert statement to comply
- with autoincrement fields that require they not be present. so,
- put them all in for columns where sequence usage is defined."""
- for c in insert.table.primary_key:
- if self.bindparams.get(c.key, None) is None and c.default is not None and not c.default.optional:
- self.bindparams[c.key] = None
- return ansisql.ANSICompiler.visit_insert(self, insert)
+
+ def visit_insert_sequence(self, column, sequence):
+ if self.parameters.get(column.key, None) is None and not sequence.optional:
+ self.parameters[column.key] = None
def limit_clause(self, select):
text = ""
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or column.default.optional):
+ if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
colspec += " " + column.type.get_col_spec()
params['poolclass'] = sqlalchemy.pool.SingletonThreadPool
ansisql.ANSISQLEngine.__init__(self, **params)
- def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
- if compiled is None: return
+ def post_exec(self, proxy, compiled, parameters, **kwargs):
if getattr(compiled, "isinsert", False):
self.context.last_inserted_ids = [proxy().lastrowid]
self.buffer.truncate(0)
class DefaultRunner(schema.SchemaVisitor):
- def __init__(self, proxy):
+ def __init__(self, engine, proxy):
self.proxy = proxy
+ self.engine = engine
def visit_sequence(self, seq):
"""sequences are not supported by default"""
return None
+ def exec_default_sql(self, default):
+ c = sql.select([default.arg], engine=self.engine).compile()
+ return self.proxy(str(c), c.get_params()).fetchone()[0]
+
def visit_column_default(self, default):
- if isinstance(default.arg, ClauseElement):
- c = default.arg.compile()
- return proxy.execute(str(c), c.get_params())
+ if isinstance(default.arg, sql.ClauseElement):
+ return self.exec_default_sql(default)
elif callable(default.arg):
return default.arg()
else:
self.context = util.ThreadLocal(raiseerror=False)
self.tables = {}
self.notes = {}
+ self._figure_paramstyle()
if logger is None:
self.logger = sys.stdout
else:
self.logger = logger
-
+
+ def _figure_paramstyle(self):
+ db = self.dbapi()
+ if db is not None:
+ self.paramstyle = db.paramstyle
+ else:
+ self.paramstyle = 'named'
+
+ if self.paramstyle == 'named':
+ self.bindtemplate = ':%s'
+ self.positional=False
+ elif self.paramstyle =='pyformat':
+ self.bindtemplate = "%%(%s)s"
+ self.positional=False
+ else:
+ # for positional, use pyformat until the end
+ self.bindtemplate = "%%(%s)s"
+ self.positional=True
def type_descriptor(self, typeobj):
if type(typeobj) is type:
raise NotImplementedError()
def defaultrunner(self, proxy):
- return DefaultRunner(proxy)
+ return DefaultRunner(self, proxy)
- def compiler(self, statement, bindparams):
+ def compiler(self, statement, parameters):
raise NotImplementedError()
def rowid_column_name(self):
"""drops a table given a schema.Table object."""
table.accept_visitor(self.schemadropper(self.proxy(), **params))
- def compile(self, statement, bindparams, **kwargs):
+ def compile(self, statement, parameters, **kwargs):
"""given a sql.ClauseElement statement plus optional bind parameters, creates a new
instance of this engine's SQLCompiler, compiles the ClauseElement, and returns the
newly compiled object."""
- compiler = self.compiler(statement, bindparams, **kwargs)
+ compiler = self.compiler(statement, parameters, **kwargs)
statement.accept_visitor(compiler)
compiler.after_compile()
return compiler
self.context.transaction = None
self.context.tcount = None
- def _process_defaults(self, proxy, statement, parameters, compiled=None, **kwargs):
+ def _process_defaults(self, proxy, compiled, parameters, **kwargs):
if compiled is None: return
if getattr(compiled, "isinsert", False):
- # TODO: this sucks. we have to get "parameters" to be a more standardized object
- if isinstance(parameters, list) and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
+ if isinstance(parameters, list):
plist = parameters
else:
plist = [parameters]
- # inserts are usually one at a time. but if we got a list of parameters,
- # it will calculate last_inserted_ids for just the last row in the list.
- # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
- # it or post-select anyway
drunner = self.defaultrunner(proxy)
for param in plist:
- # the parameters might be positional, so convert them
- # to a dictionary
- # TODO: this is stupid. or, is this stupid ?
- # any way we can just have an OrderedDict so we have the
- # dictionary + postional version each time ?
- param = compiled.get_named_params(param)
last_inserted_ids = []
need_lastrowid=False
for c in compiled.statement.table.c:
self.context.last_inserted_ids = last_inserted_ids
- def pre_exec(self, proxy, statement, parameters, **kwargs):
+ def pre_exec(self, proxy, compiled, parameters, **kwargs):
pass
- def post_exec(self, proxy, statement, parameters, **kwargs):
+ def post_exec(self, proxy, compiled, parameters, **kwargs):
pass
- def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs):
+ def execute_compiled(self, compiled, parameters, connection=None, cursor=None, echo=None, **kwargs):
"""executes the given string-based SQL statement with the given parameters.
The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
on the paramstyle of the DBAPI.
-
+
If the current thread has specified a transaction begin() for this engine, the
statement will be executed in the context of the current transactional connection.
Otherwise, a commit() will be performed immediately after execution, since the local
def proxy(statement=None, parameters=None):
if statement is None:
return cursor
+
+ executemany = parameters is not None and isinstance(parameters, list)
+
+ if self.positional:
+ if executemany:
+ parameters = [p.values() for p in parameters]
+ else:
+ parameters = parameters.values()
+
+ self.execute(statement, parameters, connection=connection, cursor=cursor)
+ return cursor
+
+ self.pre_exec(proxy, compiled, parameters, **kwargs)
+ self._process_defaults(proxy, compiled, parameters, **kwargs)
+ proxy(str(compiled), parameters)
+ self.post_exec(proxy, compiled, parameters, **kwargs)
+ return ResultProxy(cursor, self, typemap=compiled.typemap)
+
+ def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs):
+ """executes the given string-based SQL statement with the given parameters.
+
+ The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
+ on the paramstyle of the DBAPI.
+
+ If the current thread has specified a transaction begin() for this engine, the
+ statement will be executed in the context of the current transactional connection.
+ Otherwise, a commit() will be performed immediately after execution, since the local
+ pooled connection is returned to the pool after execution without a transaction set
+ up.
+
+ In all error cases, a rollback() is immediately performed on the connection before
+ propigating the exception outwards.
+
+ Other options include:
+
+ connection - a DBAPI connection to use for the execute. If None, a connection is
+ pulled from this engine's connection pool.
+
+ echo - enables echo for this execution, which causes all SQL and parameters
+ to be dumped to the engine's logging output before execution.
+
+ typemap - a map of column names mapped to sqlalchemy.types.TypeEngine objects.
+ These will be passed to the created ResultProxy to perform
+ post-processing on result-set values.
+
+ commit - if True, will automatically commit the statement after completion. """
+ if parameters is None:
+ parameters = {}
+
+ if connection is None:
+ connection = self.connection()
+
+ if cursor is None:
+ cursor = connection.cursor()
+
+ try:
if echo is True or self.echo is not False:
self.log(statement)
self.log(repr(parameters))
self._executemany(cursor, statement, parameters)
else:
self._execute(cursor, statement, parameters)
- return cursor
-
- try:
- self.pre_exec(proxy, statement, parameters, **kwargs)
- self._process_defaults(proxy, statement, parameters, **kwargs)
- proxy(statement, parameters)
- self.post_exec(proxy, statement, parameters, **kwargs)
- if commit or self.context.transaction is None:
+ if self.context.transaction is None:
self.do_commit(connection)
except:
self.do_rollback(connection)
- # TODO: wrap DB exceptions ?
raise
return ResultProxy(cursor, self, typemap = typemap)
self._impl = self.table.engine.columnimpl(self)
if self.default is not None:
+ self.default = ColumnDefault(self.default)
self._init_items(self.default)
self._init_items(*self.args)
self.args = None
object be dependent on the actual values of those bind parameters, even though it may
reference those values as defaults."""
- def __init__(self, engine, statement, bindparams):
+ def __init__(self, engine, statement, parameters):
+ """constructs a new Compiled object.
+
+ engine - SQLEngine to compile against
+
+ statement - ClauseElement to be compiled
+
+ parameters - optional dictionary indicating a set of bind parameters
+ specified with this Compiled object. These parameters are the "default"
+ values corresponding to the ClauseElement's BindParamClauses when the Compiled
+ is executed. In the case of an INSERT or UPDATE statement, these parameters
+ will also result in the creation of new BindParamClause objects for each key
+ and will also affect the generated column list in an INSERT statement and the SET
+ clauses of an UPDATE statement. The keys of the parameter dictionary can
+ either be the string names of columns or actual sqlalchemy.schema.Column objects."""
self.engine = engine
- self.bindparams = bindparams
+ self.parameters = parameters
self.statement = statement
def __str__(self):
"""returns the string text of the generated SQL statement."""
raise NotImplementedError()
def get_params(self, **params):
- """returns the bind params for this compiled object, with values overridden by
- those given in the **params dictionary"""
+ """returns the bind params for this compiled object.
+
+ Will start with the default parameters specified when this Compiled object
+ was first constructed, and will override those values with those sent via
+ **params, which are key/value pairs. Each key should match one of the
+ BindParamClause objects compiled into this object; either the "key" or
+ "shortname" property of the BindParamClause.
+ """
raise NotImplementedError()
def execute(self, *multiparams, **params):
else:
params = self.get_params(**params)
- return self.engine.execute(str(self), params, compiled=self, typemap=self.typemap)
+ return self.engine.execute_compiled(self, params)
def scalar(self, *multiparams, **params):
"""executes this compiled object via the execute() method, then
return [self]
columns = property(lambda s: s._get_columns())
- def compile(self, engine = None, bindparams = None, typemap=None):
+ def compile(self, engine = None, parameters = None, typemap=None):
"""compiles this SQL expression using its underlying SQLEngine to produce
a Compiled object. If no engine can be found, an ansisql engine is used.
bindparams is a dictionary representing the default bind parameters to be used with
if engine is None:
raise "no SQLEngine could be located within this ClauseElement."
- return engine.compile(self, bindparams = bindparams, typemap=typemap)
+ return engine.compile(self, parameters=parameters, typemap=typemap)
def __str__(self):
e = self.engine
bindparams = multiparams[0]
else:
bindparams = params
- c = self.compile(e, bindparams = bindparams)
+ c = self.compile(e, parameters=bindparams)
return c.execute(*multiparams, **params)
def scalar(self, *multiparams, **params):
print repr(users_with_date.select().execute().fetchall())
users_with_date.drop()
+ def testdefaults(self):
+ x = {'x':0}
+ def mydefault():
+ x['x'] += 1
+ return x['x']
+
+ t = Table('default_test1', db,
+ Column('col1', Integer, primary_key=True, default=mydefault),
+ Column('col2', String(20), default="imthedefault"),
+ Column('col3', String(20), default=func.count(1)),
+ )
+ t.create()
+ t.insert().execute()
+ t.insert().execute()
+ t.insert().execute()
+ t.drop()
+
def testdelete(self):
c = db.connection()
# check that the bind params sent along with a compile() call
# get preserved when the params are retreived later
s = select([table], table.c.id == bindparam('test'))
- c = s.compile(bindparams = {'test' : 7})
+ c = s.compile(parameters = {'test' : 7})
self.assert_(c.get_params() == {'test' : 7})
def testcorrelatedsubquery(self):
self.runtest(update(table, table.c.id == 12, values = {table.c.id : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'})
s = table.update(table.c.id == 12, values = {table.c.name : 'lala'})
print str(s)
- c = s.compile(bindparams = {'mytable_id':9,'name':'h0h0'})
+ c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'})
print str(c)
self.assert_(str(s) == str(c))
users = Table('users', db,
- Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+ Column('user_id', Integer, Sequence('user_id_seq', optional=False), primary_key = True),
Column('user_name', String(40)),
)
"""decorates a SQLEngine object to match the incoming queries against a set of assertions."""
def __init__(self, engine):
self.engine = engine
- self.realexec = engine.execute
- engine.execute = self.execute
+ self.realexec = engine.execute_compiled
+ engine.execute_compiled = self.execute_compiled
self.echo = engine.echo
self.logger = engine.logger
self.set_assert_list(None, None)
self.assert_list = list
if list is not None:
self.assert_list.reverse()
- def execute(self, statement, parameters, **kwargs):
+ def execute_compiled(self, compiled, parameters, **kwargs):
self.engine.echo = self.echo
self.engine.logger = self.logger
+ statement = str(compiled)
if self.assert_list is not None:
item = self.assert_list.pop()
params = params()
# deal with paramstyles of different engines
- if isinstance(self.engine, sqlite.SQLiteSQLEngine):
- paramstyle = 'named'
- else:
- db = self.engine.dbapi()
- if db is not None:
- paramstyle = db.paramstyle
- else:
- paramstyle = 'named'
+ paramstyle = self.engine.paramstyle
if paramstyle == 'named':
pass
elif paramstyle =='pyformat':
elif paramstyle=='numeric':
repl = None
counter = 0
- def append_arg(match):
- names.append(match.group(1))
- if repl is None:
- counter += 1
- return counter
- else:
- return repl
- # substitute bind string in query, translate bind param
- # dict to a list (or a list of dicts to a list of lists)
- query = re.sub(r':([\w_]+)', append_arg, query)
- if isinstance(params, list):
- args = []
- for p in params:
- l = []
- args.append(l)
- for n in names:
- l.append(p[n])
- else:
- args = []
- for n in names:
- args.append(params[n])
- params = args
+ query = re.sub(r':([\w_]+)', repl, query)
self.unittest.assert_(statement == query and params == parameters, "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
- return self.realexec(statement, parameters, **kwargs)
+ return self.realexec(compiled, parameters, **kwargs)
class TTestSuite(unittest.TestSuite):