import sqlalchemy.ansisql as ansisql
from sqlalchemy.ansisql import *
+try:
+ import cx_Oracle
+except:
+ cx_Oracle = None
+
+class OracleNumeric(sqltypes.Numeric):
+ def get_col_spec(self):
+ return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+class OracleInteger(sqltypes.Integer):
+ def get_col_spec(self):
+ return "INTEGER"
+class OracleDateTime(sqltypes.DateTime):
+ def get_col_spec(self):
+ return "TIMESTAMP"
+class OracleText(sqltypes.TEXT):
+ def get_col_spec(self):
+ return "TEXT"
+class OracleString(sqltypes.String):
+ def get_col_spec(self):
+ return "VARCHAR(%(length)s)" % {'length' : self.length}
+class OracleChar(sqltypes.CHAR):
+ def get_col_spec(self):
+ return "CHAR(%(length)s)" % {'length' : self.length}
+class OracleBinary(sqltypes.Binary):
+ def get_col_spec(self):
+ return "BLOB"
+class OracleBoolean(sqltypes.Boolean):
+ def get_col_spec(self):
+ return "BOOLEAN"
+
+colspecs = {
+ sqltypes.Integer : OracleInteger,
+ sqltypes.Numeric : OracleNumeric,
+ sqltypes.DateTime : OracleDateTime,
+ sqltypes.String : OracleString,
+ sqltypes.Binary : OracleBinary,
+ sqltypes.Boolean : OracleBoolean,
+ sqltypes.TEXT : OracleText,
+ sqltypes.CHAR: OracleChar,
+}
def engine(**params):
return OracleSQLEngine(**params)
class OracleSQLEngine(ansisql.ANSISQLEngine):
- def __init__(self, use_ansi = True, **params):
+ def __init__(self, opts, use_ansi = True, module = None, **params):
self._use_ansi = use_ansi
ansisql.ANSISQLEngine.__init__(self, **params)
+ self.opts = {}
+ if module is None:
+ self.module = cx_Oracle
+ else:
+ self.module = module
+ def dbapi(self):
+ return self.module
+
+ def connect_args(self):
+ return [[], self.opts]
+
def compile(self, statement, bindparams):
compiler = OracleCompiler(self, statement, bindparams, use_ansi = self._use_ansi)
statement.accept_visitor(compiler)
return compiler
-
- def create_connection(self):
- raise NotImplementedError()
+
+ def type_descriptor(self, typeobj):
+ return sqltypes.adapt_type(typeobj, colspecs)
+
+ def last_inserted_ids(self):
+ return self.context.last_inserted_ids
+
+ def compiler(self, statement, bindparams):
+ return OracleCompiler(self, statement, bindparams)
+
+ def schemagenerator(self, proxy, **params):
+ return OracleSchemaGenerator(proxy, **params)
+
+ def reflecttable(self, table):
+ raise "not implemented"
+
+ def last_inserted_ids(self):
+ table = self.context.last_inserted_table
+ if self.context.lastrowid is not None and table is not None and len(table.primary_keys):
+ row = sql.select(table.primary_keys, table.rowid_column == self.context.lastrowid).execute().fetchone()
+ return [v for v in row]
+ else:
+ return None
+
+ def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+ pass
+ def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+ pass
+
+ def _executemany(self, c, statement, parameters):
+ rowcount = 0
+ for param in parameters:
+ c.execute(statement, param)
+ rowcount += c.rowcount
+ self.context.rowcount = rowcount
class OracleCompiler(ansisql.ANSICompiler):
"""oracle compiler modifies the lexical structure of Select statements to work under
self.strings[column] = "%s.%s" % (column.table.name, column.name)
+class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
+ def get_column_specification(self, column):
+ colspec = column.name
+ colspec += " " + column.type.get_col_spec()
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+ if column.primary_key:
+ colspec += " PRIMARY KEY"
+ if column.foreign_key:
+ colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name)
+ return colspec
+
+ def visit_sequence(self, sequence):
+ self.append("CREATE SEQUENCE %s" % sequence.name)
+ self.execute()
\ No newline at end of file
def connect_args(self):
return [[], self.opts]
-
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
def schemagenerator(self, proxy, **params):
return PGSchemaGenerator(proxy, **params)
+ def schemadropper(self, proxy, **params):
+ return PGSchemaDropper(proxy, **params)
+
def reflecttable(self, table):
raise "not implemented"
return None
def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
- if True: return
# if a sequence was explicitly defined we do it here
if compiled is None: return
if getattr(compiled, "isinsert", False):
- last_inserted_ids = []
for primary_key in compiled.statement.table.primary_keys:
- # pseudocode
- if parameters[primary_key.key] is None:
- if echo is True:
- self.log(primary_key.sequence.text)
- res = cursor.execute(primary_key.sequence.text)
- newid = res.fetchrow()[0]
+ if primary_key.sequence is not None and not primary_key.sequence.optional and parameters[primary_key.key] is None:
+ if echo is True or self.echo:
+ self.log("select nextval('%s')" % primary_key.sequence.name)
+ cursor.execute("select nextval('%s')" % primary_key.sequence.name)
+ newid = cursor.fetchone()[0]
parameters[primary_key.key] = newid
- last_inserted_ids.append(newid)
- self.context.last_inserted_ids = last_inserted_ids
def _executemany(self, c, statement, parameters):
"""we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough
def bindparam_string(self, name):
return "%(" + name + ")s"
+ def visit_insert(self, insert):
+ for c in insert.table.primary_keys:
+ if c.sequence is not None and not c.sequence.optional:
+ self.bindparams[c.key] = None
+ #if not insert.parameters.has_key(c.key):
+ # insert.parameters[c.key] = sql.bindparam(c.key)
+ return ansisql.ANSICompiler.visit_insert(self, insert)
+
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column):
colspec = column.name
- if column.primary_key and isinstance(column.type, types.Integer):
+ if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional):
colspec += " SERIAL"
else:
- colspec += " " + column.column.type.get_col_spec()
+ colspec += " " + column.type.get_col_spec()
if not column.nullable:
colspec += " NOT NULL"
if column.foreign_key:
colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name)
return colspec
+
+ def visit_sequence(self, sequence):
+ if not sequence.optional:
+ self.append("CREATE SEQUENCE %s" % sequence.name)
+ self.execute()
+
+class PGSchemaDropper(ansisql.ANSISchemaDropper):
+ def visit_sequence(self, sequence):
+ if not sequence.optional:
+ self.append("DROP SEQUENCE %s" % sequence.name)
+ self.execute()
connection.commit()
def proxy(self):
- return lambda s, p = None: self.execute(s, p)
+ return lambda s, p = None: self.execute(s, p, commit=True)
def connection(self):
return self._pool.connect()
def post_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
pass
- def execute(self, statement, parameters, connection = None, echo = None, typemap = None, **kwargs):
+ def execute(self, statement, parameters, connection = None, echo = None, typemap = None, commit=False, **kwargs):
if parameters is None:
parameters = {}
if echo is True or self.echo:
self.log(repr(parameters))
if connection is None:
- poolconn = self.connection()
- c = poolconn.cursor()
+ connection = self.connection()
+ c = connection.cursor()
else:
c = connection.cursor()
else:
self._execute(c, statement, parameters)
self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs)
+ if commit:
+ connection.commit()
return ResultProxy(c, self, typemap = typemap)
def _execute(self, c, statement, parameters):
c._impl = self.engine.columnimpl(c)
return c
- def accept_visitor(self, visitor):
- return visitor.visit_column(self)
+ def accept_visitor(self, visitor):
+ if self.sequence is not None:
+ self.sequence.accept_visitor(visitor)
+ if self.foreign_key is not None:
+ self.foreign_key.accept_visitor(visitor)
+ visitor.visit_column(self)
def __lt__(self, other): return self._impl.__lt__(other)
def __le__(self, other): return self._impl.__le__(other)
return self._column
column = property(lambda s: s._init_column())
-
+
+ def accept_visitor(self, visitor):
+ visitor.visit_foreign_key(self)
+
def _set_parent(self, column):
self.parent = column
self.parent.foreign_key = self
class Sequence(SchemaItem):
"""represents a sequence, which applies to Oracle and Postgres databases."""
- def __init__(self, name, start = None, increment = None):
+ def __init__(self, name, start = None, increment = None, optional=False):
self.name = name
self.start = start
self.increment = increment
- def _set_parent(self, column, key):
+ self.optional=optional
+ def _set_parent(self, column):
self.column = column
self.column.sequence = self
def accept_visitor(self, visitor):
raise NotImplementedError()
class SchemaVisitor(object):
- """base class for an object that traverses across Schema objects"""
-
- def visit_schema(self, schema):pass
- def visit_table(self, table):pass
- def visit_column(self, column):pass
- def visit_foreign_key(self, join):pass
- def visit_index(self, index):pass
- def visit_sequence(self, sequence):pass
+ """base class for an object that traverses across Schema objects"""
+
+ def visit_schema(self, schema):pass
+ def visit_table(self, table):pass
+ def visit_column(self, column):pass
+ def visit_foreign_key(self, join):pass
+ def visit_index(self, index):pass
+ def visit_sequence(self, sequence):pass
\ No newline at end of file
item.keywords.append(ik)
objectstore.uow().commit()
-
+ objectstore.clear()
l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name])
self.assert_result(l, *data)
db = testbase.EngineAssert(db)
users = Table('users', db,
- Column('user_id', Integer, primary_key = True),
+ Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
Column('user_name', String(40)),
)
addresses = Table('email_addresses', db,
- Column('address_id', Integer, primary_key = True),
+ Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
Column('user_id', Integer, ForeignKey(users.c.user_id)),
Column('email_address', String(40)),
)
orderitems.create()
keywords.create()
itemkeywords.create()
- db.commit()
def drop():
itemkeywords.drop()
orders.drop()
addresses.drop()
users.drop()
- db.commit()
def delete():
itemkeywords.delete().execute()
orders.delete().execute()
addresses.delete().execute()
users.delete().execute()
+ db.commit()
def data():
delete()