return self.wheres.get(obj, None)
def get_params(self, **params):
- """returns the bind params for this compiled object, with values overridden by
- those given in the **params dictionary"""
+ """returns a structure of bind parameters for this compiled object.
+ This includes bind parameters that might be compiled in via the "values"
+ argument of an Insert or Update statement object, and also the given **params.
+ The keys inside of **params can be any key that matches the BindParameterClause
+ objects compiled within this object. The output is dependent on the paramstyle
+ of the DBAPI being used; if a named style, the return result will be a dictionary
+ with keynames matching the compiled statement. If a positional style, the output
+ will be a list corresponding to the bind positions in the compiled statement.
+
+ for an executemany style of call, this method should be called for each element
+ in the list of parameter groups that will ultimately be executed.
+ """
d = {}
if self.bindparams is not None:
bindparams = self.bindparams.copy()
else:
bindparams = {}
bindparams.update(params)
+ # TODO: cant we make "d" an ordereddict and add params in
+ # positional order
for key, value in bindparams.iteritems():
try:
b = self.binds[key]
else:
return d
+ def get_named_params(self, parameters):
+ """given the results of the get_params method, returns the parameters
+ in dictionary format. For a named paramstyle, this just returns the
+ same dictionary. For a positional paramstyle, the given parameters are
+ assumed to be in list format and are converted back to a dictionary.
+ """
+ if self.positional:
+ p = {}
+ for i in range(0, len(self.positiontup)):
+ p[self.positiontup[i]] = parameters[i]
+ return p
+ else:
+ return parameters
+
def visit_column(self, column):
if len(self.select_stack):
# if we are within a visit to a Select, set up the "typemap"
self.execute()
+class ANSIDefaultRunner(sqlalchemy.engine.DefaultRunner):
+ pass
\ No newline at end of file
def last_inserted_ids(self):
return self.context.last_inserted_ids
- def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+ def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
if compiled is None: return
if getattr(compiled, "isinsert", False):
- self.context.last_inserted_ids = [cursor.lastrowid]
+ self.context.last_inserted_ids = [proxy().lastrowid]
# executemany just runs normally, since we arent using rowcount at all with mysql
# def _executemany(self, c, statement, parameters):
def _rowid_col(self):
if getattr(self, '_mysql_rowid_column', None) is None:
if len(self.table.primary_key) > 0:
- self._mysql_rowid_column = self.table.primary_key[0]
+ c = self.table.primary_key[0]
else:
- self._mysql_rowid_column = self.table.columns[self.table.columns.keys()[0]]
+ c = self.table.columns[self.table.columns.keys()[0]]
+ self._mysql_rowid_column = schema.Column(c.name, c.type, hidden=True)
+ self._mysql_rowid_column._set_parent(self.table)
+
return self._mysql_rowid_column
rowid_column = property(lambda s: s._rowid_col())
if column.primary_key:
if not override_pk:
colspec += " PRIMARY KEY"
- if first_pk:
+ if first_pk and isinstance(column.type, types.Integer):
colspec += " AUTO_INCREMENT"
if column.foreign_key:
colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name)
return OracleSchemaGenerator(proxy, **params)
def schemadropper(self, proxy, **params):
return OracleSchemaDropper(proxy, **params)
-
+ def defaultrunner(self, proxy):
+ return OracleDefaultRunner(proxy)
+
def reflecttable(self, table):
raise "not implemented"
def last_inserted_ids(self):
return self.context.last_inserted_ids
- def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+ def pre_exec(self, proxy, statement, parameters, compiled=None, **kwargs):
if compiled is None: return
+ # this is just an assertion that all the primary key columns in an insert statement
+ # have a value set up, or have a default generator ready to go
if getattr(compiled, "isinsert", False):
if isinstance(parameters, list):
plist = parameters
else:
plist = [parameters]
for param in plist:
- last_inserted_ids = []
for primary_key in compiled.statement.table.primary_key:
if not param.has_key(primary_key.key) or param[primary_key.key] is None:
- if primary_key.sequence is None:
- raise "Column '%s.%s': Oracle primary key columns require schema.Sequence to create ids" % (primary_key.table.name, primary_key.name)
- if echo is True or self.echo:
- self.log("select %s.nextval from dual" % primary_key.sequence.name)
- cursor.execute("select %s.nextval from dual" % primary_key.sequence.name)
- newid = cursor.fetchone()[0]
- param[primary_key.key] = newid
- last_inserted_ids.append(param[primary_key.key])
- self.context.last_inserted_ids = last_inserted_ids
-
- def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
- pass
+ if primary_key.default is None:
+ raise "Column '%s.%s': Oracle primary key columns require a default value or a schema.Sequence to create ids" % (primary_key.table.name, primary_key.name)
def _executemany(self, c, statement, parameters):
rowcount = 0
self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
self.wheres[join] = join.onclause
- print "check1"
if join.isouter:
- print "check2"
# if outer join, push on the right side table as the current "outertable"
outertable = self._outertable
self._outertable = join.right
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
+class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
+ def visit_sequence(self, seq):
+ c = self.proxy("select %s.nextval from dual" % seq.name)
+ return c.fetchone()[0]
def schemadropper(self, proxy, **params):
return PGSchemaDropper(proxy, **params)
+
+ def defaultrunner(self, proxy):
+ return PGDefaultRunner(proxy)
def get_default_schema_name(self):
if not hasattr(self, '_default_schema_name'):
return self._default_schema_name
def last_inserted_ids(self):
- # if we used sequences or already had all values for the last inserted row,
- # return that list
- if self.context.last_inserted_ids is not None:
- return self.context.last_inserted_ids
-
- # else we have to use lastrowid and select the most recently inserted row
- table = self.context.last_inserted_table
- if self.context.lastrowid is not None and table is not None and len(table.primary_key):
- row = sql.select(table.primary_key, table.rowid_column == self.context.lastrowid).execute().fetchone()
- return [v for v in row]
- else:
- return None
-
- def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+ return self.context.last_inserted_ids
+
+ def pre_exec(self, proxy, statement, parameters, **kwargs):
+ return
+
+ def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
if compiled is None: return
- if getattr(compiled, "isinsert", False):
- if isinstance(parameters, list):
- plist = parameters
- else:
- plist = [parameters]
- # inserts are usually one at a time. but if we got a list of parameters,
- # it will calculate last_inserted_ids for just the last row in the list.
- # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
- # it or post-select anyway
- for param in plist:
- last_inserted_ids = []
- need_lastrowid=False
- for primary_key in compiled.statement.table.primary_key:
- if not param.has_key(primary_key.key) or param[primary_key.key] is None:
- if primary_key.sequence is not None and not primary_key.sequence.optional:
- if echo is True or self.echo:
- self.log("select nextval('%s')" % primary_key.sequence.name)
- cursor.execute("select nextval('%s')" % primary_key.sequence.name)
- newid = cursor.fetchone()[0]
- param[primary_key.key] = newid
- last_inserted_ids.append(param[primary_key.key])
- else:
- need_lastrowid = True
- else:
- last_inserted_ids.append(param[primary_key.key])
- if need_lastrowid:
- self.context.last_inserted_ids = None
- else:
- self.context.last_inserted_ids = last_inserted_ids
+ if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None:
+ table = compiled.statement.table
+ cursor = proxy()
+ if cursor.lastrowid is not None and table is not None and len(table.primary_key):
+ s = sql.select(table.primary_key, table.rowid_column == cursor.lastrowid)
+ c = s.compile()
+ cursor = proxy(str(c), c.get_params())
+ row = cursor.fetchone()
+ self.context.last_inserted_ids = [v for v in row]
def _executemany(self, c, statement, parameters):
"""we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough
rowcount += c.rowcount
self.context.rowcount = rowcount
- def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
- if compiled is None: return
- if getattr(compiled, "isinsert", False):
- table = compiled.statement.table
- self.context.last_inserted_table = table
- self.context.lastrowid = cursor.lastrowid
def dbapi(self):
return self.module
with autoincrement fields that require they not be present. so,
put them all in for columns where sequence usage is defined."""
for c in insert.table.primary_key:
- if c.sequence is not None and not c.sequence.optional:
+ if self.bindparams.get(c.key, None) is None and c.default is not None and not c.default.optional:
self.bindparams[c.key] = None
return ansisql.ANSICompiler.visit_insert(self, insert)
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional):
+ if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or column.default.optional):
colspec += " SERIAL"
else:
colspec += " " + column.type.get_col_spec()
if not sequence.optional:
self.append("DROP SEQUENCE %s" % sequence.name)
self.execute()
+
+class PGDefaultRunner(ansisql.ANSIDefaultRunner):
+ def visit_sequence(self, seq):
+ if not seq.optional:
+ c = self.proxy("select nextval('%s')" % seq.name)
+ return c.fetchone()[0]
+ else:
+ return None
\ No newline at end of file
params['poolclass'] = sqlalchemy.pool.SingletonThreadPool
ansisql.ANSISQLEngine.__init__(self, **params)
- def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+ def post_exec(self, proxy, statement, parameters, compiled = None, **kwargs):
if compiled is None: return
if getattr(compiled, "isinsert", False):
- self.context.last_inserted_ids = [cursor.lastrowid]
+ self.context.last_inserted_ids = [proxy().lastrowid]
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
self.append("\n)\n\n")
self.execute()
+
finally:
self.buffer.truncate(0)
+class DefaultRunner(schema.SchemaVisitor):
+ def __init__(self, proxy):
+ self.proxy = proxy
+
+ def visit_sequence(self, seq):
+ """sequences are not supported by default"""
+ return None
+
+ def visit_column_default(self, default):
+ if isinstance(default.arg, ClauseElement):
+ c = default.arg.compile()
+ return proxy.execute(str(c), c.get_params())
+ elif callable(default.arg):
+ return default.arg()
+ else:
+ return default.arg
+
+
class SQLEngine(schema.SchemaEngine):
"""base class for a series of database-specific engines. serves as an abstract factory
for implementation objects as well as database connections, transactions, SQL generators,
def schemadropper(self, proxy, **params):
raise NotImplementedError()
+ def defaultrunner(self, proxy):
+ return DefaultRunner(proxy)
+
def compiler(self, statement, bindparams):
raise NotImplementedError()
self.context.transaction = None
self.context.tcount = None
-
- def _process_sequences(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
+ def _process_defaults(self, proxy, statement, parameters, compiled=None, **kwargs):
if compiled is None: return
if getattr(compiled, "isinsert", False):
- if isinstance(parameters, list):
+ # TODO: this sucks. we have to get "parameters" to be a more standardized object
+ if isinstance(parameters, list) and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
plist = parameters
else:
plist = [parameters]
# it will calculate last_inserted_ids for just the last row in the list.
# TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
# it or post-select anyway
+ drunner = self.defaultrunner(proxy)
for param in plist:
+ # the parameters might be positional, so convert them
+ # to a dictionary
+ # TODO: this is stupid. or, is this stupid ?
+ # any way we can just have an OrderedDict so we have the
+ # dictionary + postional version each time ?
+ param = compiled.get_named_params(param)
last_inserted_ids = []
need_lastrowid=False
for c in compiled.statement.table.c:
if not param.has_key(c.key) or param[c.key] is None:
- if c.sequence is not None:
- newid = self.exec_sequence(c.sequence)
+ if c.default is not None:
+ newid = c.default.accept_visitor(drunner)
else:
newid = None
self.context.last_inserted_ids = None
else:
self.context.last_inserted_ids = last_inserted_ids
-
- def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
+
+
+ def pre_exec(self, proxy, statement, parameters, **kwargs):
pass
- def post_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
+ def post_exec(self, proxy, statement, parameters, **kwargs):
pass
- def execute(self, statement, parameters, connection = None, echo = None, typemap = None, commit=False, **kwargs):
- """executes the given string-based SQL statement with the given parameters. This is
- a direct interface to a DBAPI connection object. The parameters may be a dictionary,
- or an array of dictionaries. If an array of dictionaries is sent, executemany will
- be performed on the cursor instead of execute.
+ def execute(self, statement, parameters, connection=None, cursor=None, echo = None, typemap = None, commit=False, **kwargs):
+ """executes the given string-based SQL statement with the given parameters.
+ The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
+ on the paramstyle of the DBAPI.
+
If the current thread has specified a transaction begin() for this engine, the
statement will be executed in the context of the current transactional connection.
Otherwise, a commit() will be performed immediately after execution, since the local
if connection is None:
connection = self.connection()
- c = connection.cursor()
- else:
- c = connection.cursor()
- try:
- self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs)
- #self._process_sequences(connection, c, statement, parameters, echo = echo, **kwargs)
-
+ if cursor is None:
+ cursor = connection.cursor()
+
+ def proxy(statement=None, parameters=None):
+ if statement is None:
+ return cursor
if echo is True or self.echo is not False:
self.log(statement)
self.log(repr(parameters))
- if isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
- self._executemany(c, statement, parameters)
+ if parameters is not None and isinstance(parameters, list) and len(parameters) > 0 and (isinstance(parameters[0], list) or isinstance(parameters[0], dict)):
+ self._executemany(cursor, statement, parameters)
else:
- self._execute(c, statement, parameters)
- self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs)
+ self._execute(cursor, statement, parameters)
+ return cursor
+
+ try:
+ self.pre_exec(proxy, statement, parameters, **kwargs)
+ self._process_defaults(proxy, statement, parameters, **kwargs)
+ proxy(statement, parameters)
+ self.post_exec(proxy, statement, parameters, **kwargs)
if commit or self.context.transaction is None:
self.do_commit(connection)
except:
self.do_rollback(connection)
# TODO: wrap DB exceptions ?
raise
- return ResultProxy(c, self, typemap = typemap)
+ return ResultProxy(cursor, self, typemap = typemap)
def _execute(self, c, statement, parameters):
c.execute(statement, parameters)
class TableSingleton(type):
def __call__(self, name, engine, *args, **kwargs):
try:
+ name = str(name) # in case of incoming unicode
schema = kwargs.get('schema', None)
autoload = kwargs.pop('autoload', False)
redefine = kwargs.pop('redefine', False)
class Column(SchemaItem):
"""represents a column in a database table."""
def __init__(self, name, type, *args, **kwargs):
- self.name = name
+ self.name = str(name) # in case of incoming unicode
self.type = type
self.args = args
self.key = kwargs.pop('key', name)
self.primary_key = kwargs.pop('primary_key', False)
self.nullable = kwargs.pop('nullable', not self.primary_key)
self.hidden = kwargs.pop('hidden', False)
+ self.default = kwargs.pop('default', None)
self.foreign_key = None
- self.sequence = None
self._orig = None
if len(kwargs):
raise "Unknown arguments passed to Column: " + repr(kwargs.keys())
self._impl = self.table.engine.columnimpl(self)
+ if self.default is not None:
+ self._init_items(self.default)
self._init_items(*self.args)
self.args = None
fk = None
else:
fk = self.foreign_key.copy()
- return Column(self.name, self.type, fk, self.sequence, key = self.key, primary_key = self.primary_key)
+ return Column(self.name, self.type, fk, self.default, key = self.key, primary_key = self.primary_key)
def _make_proxy(self, selectable, name = None):
"""creates a copy of this Column, initialized the way this Column is"""
fk = None
else:
fk = self.foreign_key.copy()
- c = Column(name or self.name, self.type, fk, self.sequence, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden)
+ c = Column(name or self.name, self.type, fk, self.default, key = name or self.key, primary_key = self.primary_key, hidden=self.hidden)
c.table = selectable
c._orig = self.original
if not c.hidden:
return c
def accept_visitor(self, visitor):
- if self.sequence is not None:
- self.sequence.accept_visitor(visitor)
+ if self.default is not None:
+ self.default.accept_visitor(visitor)
if self.foreign_key is not None:
self.foreign_key.accept_visitor(visitor)
visitor.visit_column(self)
visitor.visit_foreign_key(self)
def _set_parent(self, column):
- if not isinstance(column, Column):
- raise "hi" + repr(type(column))
self.parent = column
self.parent.foreign_key = self
self.parent.table.foreign_keys.append(self)
+
+class DefaultGenerator(SchemaItem):
+ """represents a "default value generator" for a particular column in a particular
+ table. This could correspond to a constant, a callable function, or a SQL clause."""
+ def _set_parent(self, column):
+ self.column = column
+ self.column.default = self
+ def accept_visitor(self, visitor):
+ pass
+
+class ColumnDefault(DefaultGenerator):
+ def __init__(self, arg):
+ self.arg = arg
+ def accept_visitor(self, visitor):
+ return visitor.visit_column_default(self)
-class Sequence(SchemaItem):
+class Sequence(DefaultGenerator):
"""represents a sequence, which applies to Oracle and Postgres databases."""
- def __init__(self, name, func = None, start = None, increment = None, optional=False):
+ def __init__(self, name, start = None, increment = None, optional=False):
self.name = name
- self.func = func
self.start = start
self.increment = increment
self.optional=optional
- def _set_parent(self, column):
- self.column = column
- self.column.sequence = self
def accept_visitor(self, visitor):
return visitor.visit_sequence(self)
def visit_column(self, column):pass
def visit_foreign_key(self, join):pass
def visit_index(self, index):pass
+ def visit_column_default(self, default):pass
def visit_sequence(self, sequence):pass
params = [self.get_params(**m) for m in multiparams]
else:
params = self.get_params(**params)
- return self.engine.execute(str(self), params, compiled = self, typemap = self.typemap)
+
+ return self.engine.execute(str(self), params, compiled=self, typemap=self.typemap)
def scalar(self, *multiparams, **params):
"""executes this compiled object via the execute() method, then
return self
def adapt_type(typeobj, colspecs):
+ """given a generic type from this package, and a dictionary of
+ "conversion" specs from a DB-specific package, adapts the type
+ to a correctly-configured type instance from the DB-specific package."""
if type(typeobj) is type:
typeobj = typeobj()
typeobj = typeobj.adapt_args()
return c
class String(NullTypeEngine):
- def __init__(self, length = None):
+ def __init__(self, length = None, is_unicode=False):
self.length = length
+ self.is_unicode = is_unicode
def adapt(self, typeobj):
return typeobj(self.length)
def adapt_args(self):
if self.length is None:
- return TEXT()
+ return TEXT(is_unicode=self.is_unicode)
else:
return self
+
+class Unicode(String):
+ def __init__(self, length=None):
+ String.__init__(self, length, is_unicode=True)
+ def adapt(self, typeobj):
+ return typeobj(self.length, is_unicode=True)
class Integer(NullTypeEngine):
"""integer datatype"""