From 5512e6add17aae602147614b454ae190caf1f704 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 23 Oct 2005 16:36:31 +0000 Subject: [PATCH] sequences, oracle --- lib/sqlalchemy/databases/oracle.py | 107 ++++++++++++++++++++++++++- lib/sqlalchemy/databases/postgres.py | 42 +++++++---- lib/sqlalchemy/engine.py | 10 ++- lib/sqlalchemy/schema.py | 34 +++++---- test/objectstore.py | 2 +- test/tables.py | 7 +- 6 files changed, 163 insertions(+), 39 deletions(-) diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 6b6fea8307..5922863b44 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -23,22 +23,105 @@ import sqlalchemy.schema as schema import sqlalchemy.ansisql as ansisql from sqlalchemy.ansisql import * +try: + import cx_Oracle +except: + cx_Oracle = None + +class OracleNumeric(sqltypes.Numeric): + def get_col_spec(self): + return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} +class OracleInteger(sqltypes.Integer): + def get_col_spec(self): + return "INTEGER" +class OracleDateTime(sqltypes.DateTime): + def get_col_spec(self): + return "TIMESTAMP" +class OracleText(sqltypes.TEXT): + def get_col_spec(self): + return "TEXT" +class OracleString(sqltypes.String): + def get_col_spec(self): + return "VARCHAR(%(length)s)" % {'length' : self.length} +class OracleChar(sqltypes.CHAR): + def get_col_spec(self): + return "CHAR(%(length)s)" % {'length' : self.length} +class OracleBinary(sqltypes.Binary): + def get_col_spec(self): + return "BLOB" +class OracleBoolean(sqltypes.Boolean): + def get_col_spec(self): + return "BOOLEAN" + +colspecs = { + sqltypes.Integer : OracleInteger, + sqltypes.Numeric : OracleNumeric, + sqltypes.DateTime : OracleDateTime, + sqltypes.String : OracleString, + sqltypes.Binary : OracleBinary, + sqltypes.Boolean : OracleBoolean, + sqltypes.TEXT : OracleText, + sqltypes.CHAR: OracleChar, +} def engine(**params): return OracleSQLEngine(**params) class OracleSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, use_ansi = True, **params): + def __init__(self, opts, use_ansi = True, module = None, **params): self._use_ansi = use_ansi ansisql.ANSISQLEngine.__init__(self, **params) + self.opts = {} + if module is None: + self.module = cx_Oracle + else: + self.module = module + def dbapi(self): + return self.module + + def connect_args(self): + return [[], self.opts] + def compile(self, statement, bindparams): compiler = OracleCompiler(self, statement, bindparams, use_ansi = self._use_ansi) statement.accept_visitor(compiler) return compiler - - def create_connection(self): - raise NotImplementedError() + + def type_descriptor(self, typeobj): + return sqltypes.adapt_type(typeobj, colspecs) + + def last_inserted_ids(self): + return self.context.last_inserted_ids + + def compiler(self, statement, bindparams): + return OracleCompiler(self, statement, bindparams) + + def schemagenerator(self, proxy, **params): + return OracleSchemaGenerator(proxy, **params) + + def reflecttable(self, table): + raise "not implemented" + + def last_inserted_ids(self): + table = self.context.last_inserted_table + if self.context.lastrowid is not None and table is not None and len(table.primary_keys): + row = sql.select(table.primary_keys, table.rowid_column == self.context.lastrowid).execute().fetchone() + return [v for v in row] + else: + return None + + def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): + pass + def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): + pass + + def _executemany(self, c, statement, parameters): + rowcount = 0 + for param in parameters: + c.execute(statement, param) + rowcount += c.rowcount + self.context.rowcount = rowcount class OracleCompiler(ansisql.ANSICompiler): """oracle compiler modifies the lexical structure of Select statements to work under @@ -77,3 +160,19 @@ class OracleCompiler(ansisql.ANSICompiler): self.strings[column] = "%s.%s" % (column.table.name, column.name) +class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): + def get_column_specification(self, column): + colspec = column.name + colspec += " " + column.type.get_col_spec() + + if not column.nullable: + colspec += " NOT NULL" + if column.primary_key: + colspec += " PRIMARY KEY" + if column.foreign_key: + colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) + return colspec + + def visit_sequence(self, sequence): + self.append("CREATE SEQUENCE %s" % sequence.name) + self.execute() \ No newline at end of file diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 547a2c0b78..76d5248f77 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -84,7 +84,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def connect_args(self): return [[], self.opts] - def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) @@ -97,6 +96,9 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def schemagenerator(self, proxy, **params): return PGSchemaGenerator(proxy, **params) + def schemadropper(self, proxy, **params): + return PGSchemaDropper(proxy, **params) + def reflecttable(self, table): raise "not implemented" @@ -109,21 +111,16 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return None def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): - if True: return # if a sequence was explicitly defined we do it here if compiled is None: return if getattr(compiled, "isinsert", False): - last_inserted_ids = [] for primary_key in compiled.statement.table.primary_keys: - # pseudocode - if parameters[primary_key.key] is None: - if echo is True: - self.log(primary_key.sequence.text) - res = cursor.execute(primary_key.sequence.text) - newid = res.fetchrow()[0] + if primary_key.sequence is not None and not primary_key.sequence.optional and parameters[primary_key.key] is None: + if echo is True or self.echo: + self.log("select nextval('%s')" % primary_key.sequence.name) + cursor.execute("select nextval('%s')" % primary_key.sequence.name) + newid = cursor.fetchone()[0] parameters[primary_key.key] = newid - last_inserted_ids.append(newid) - self.context.last_inserted_ids = last_inserted_ids def _executemany(self, c, statement, parameters): """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough @@ -151,13 +148,21 @@ class PGCompiler(ansisql.ANSICompiler): def bindparam_string(self, name): return "%(" + name + ")s" + def visit_insert(self, insert): + for c in insert.table.primary_keys: + if c.sequence is not None and not c.sequence.optional: + self.bindparams[c.key] = None + #if not insert.parameters.has_key(c.key): + # insert.parameters[c.key] = sql.bindparam(c.key) + return ansisql.ANSICompiler.visit_insert(self, insert) + class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column): colspec = column.name - if column.primary_key and isinstance(column.type, types.Integer): + if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional): colspec += " SERIAL" else: - colspec += " " + column.column.type.get_col_spec() + colspec += " " + column.type.get_col_spec() if not column.nullable: colspec += " NOT NULL" @@ -166,3 +171,14 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): if column.foreign_key: colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) return colspec + + def visit_sequence(self, sequence): + if not sequence.optional: + self.append("CREATE SEQUENCE %s" % sequence.name) + self.execute() + +class PGSchemaDropper(ansisql.ANSISchemaDropper): + def visit_sequence(self, sequence): + if not sequence.optional: + self.append("DROP SEQUENCE %s" % sequence.name) + self.execute() diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 4e01d1684d..9957fe315e 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -124,7 +124,7 @@ class SQLEngine(schema.SchemaEngine): connection.commit() def proxy(self): - return lambda s, p = None: self.execute(s, p) + return lambda s, p = None: self.execute(s, p, commit=True) def connection(self): return self._pool.connect() @@ -192,7 +192,7 @@ class SQLEngine(schema.SchemaEngine): def post_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs): pass - def execute(self, statement, parameters, connection = None, echo = None, typemap = None, **kwargs): + def execute(self, statement, parameters, connection = None, echo = None, typemap = None, commit=False, **kwargs): if parameters is None: parameters = {} if echo is True or self.echo: @@ -200,8 +200,8 @@ class SQLEngine(schema.SchemaEngine): self.log(repr(parameters)) if connection is None: - poolconn = self.connection() - c = poolconn.cursor() + connection = self.connection() + c = connection.cursor() else: c = connection.cursor() @@ -211,6 +211,8 @@ class SQLEngine(schema.SchemaEngine): else: self._execute(c, statement, parameters) self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs) + if commit: + connection.commit() return ResultProxy(c, self, typemap = typemap) def _execute(self, c, statement, parameters): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 68593e810b..967fa213ad 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -165,8 +165,12 @@ class Column(SchemaItem): c._impl = self.engine.columnimpl(c) return c - def accept_visitor(self, visitor): - return visitor.visit_column(self) + def accept_visitor(self, visitor): + if self.sequence is not None: + self.sequence.accept_visitor(visitor) + if self.foreign_key is not None: + self.foreign_key.accept_visitor(visitor) + visitor.visit_column(self) def __lt__(self, other): return self._impl.__lt__(other) def __le__(self, other): return self._impl.__le__(other) @@ -208,7 +212,10 @@ class ForeignKey(SchemaItem): return self._column column = property(lambda s: s._init_column()) - + + def accept_visitor(self, visitor): + visitor.visit_foreign_key(self) + def _set_parent(self, column): self.parent = column self.parent.foreign_key = self @@ -216,11 +223,12 @@ class ForeignKey(SchemaItem): class Sequence(SchemaItem): """represents a sequence, which applies to Oracle and Postgres databases.""" - def __init__(self, name, start = None, increment = None): + def __init__(self, name, start = None, increment = None, optional=False): self.name = name self.start = start self.increment = increment - def _set_parent(self, column, key): + self.optional=optional + def _set_parent(self, column): self.column = column self.column.sequence = self def accept_visitor(self, visitor): @@ -236,14 +244,14 @@ class SchemaEngine(object): raise NotImplementedError() class SchemaVisitor(object): - """base class for an object that traverses across Schema objects""" - - def visit_schema(self, schema):pass - def visit_table(self, table):pass - def visit_column(self, column):pass - def visit_foreign_key(self, join):pass - def visit_index(self, index):pass - def visit_sequence(self, sequence):pass + """base class for an object that traverses across Schema objects""" + + def visit_schema(self, schema):pass + def visit_table(self, table):pass + def visit_column(self, column):pass + def visit_foreign_key(self, join):pass + def visit_index(self, index):pass + def visit_sequence(self, sequence):pass \ No newline at end of file diff --git a/test/objectstore.py b/test/objectstore.py index 5b5889695d..f8aaf497e1 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -637,7 +637,7 @@ class SaveTest(AssertMixin): item.keywords.append(ik) objectstore.uow().commit() - + objectstore.clear() l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name]) self.assert_result(l, *data) diff --git a/test/tables.py b/test/tables.py index 845077716c..aceed904b6 100644 --- a/test/tables.py +++ b/test/tables.py @@ -25,12 +25,12 @@ elif DBTYPE == 'postgres': db = testbase.EngineAssert(db) users = Table('users', db, - Column('user_id', Integer, primary_key = True), + Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), Column('user_name', String(40)), ) addresses = Table('email_addresses', db, - Column('address_id', Integer, primary_key = True), + Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True), Column('user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(40)), ) @@ -65,7 +65,6 @@ def create(): orderitems.create() keywords.create() itemkeywords.create() - db.commit() def drop(): itemkeywords.drop() @@ -74,7 +73,6 @@ def drop(): orders.drop() addresses.drop() users.drop() - db.commit() def delete(): itemkeywords.delete().execute() @@ -83,6 +81,7 @@ def delete(): orders.delete().execute() addresses.delete().execute() users.delete().execute() + db.commit() def data(): delete() -- 2.47.2