]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
sequences, oracle
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 16:36:31 +0000 (16:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 16:36:31 +0000 (16:36 +0000)
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/schema.py
test/objectstore.py
test/tables.py

index 6b6fea83077e245637961d5cc5d6e2e2d95fc41d..5922863b44409dfcdf24082bacb78326448a9fb9 100644 (file)
@@ -23,22 +23,105 @@ import sqlalchemy.schema as schema
 import sqlalchemy.ansisql as ansisql
 from sqlalchemy.ansisql import *
 
+try:
+    import cx_Oracle
+except:
+    cx_Oracle = None
+        
+class OracleNumeric(sqltypes.Numeric):
+    def get_col_spec(self):
+        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+class OracleInteger(sqltypes.Integer):
+    def get_col_spec(self):
+        return "INTEGER"
+class OracleDateTime(sqltypes.DateTime):
+    def get_col_spec(self):
+        return "TIMESTAMP"
+class OracleText(sqltypes.TEXT):
+    def get_col_spec(self):
+        return "TEXT"
+class OracleString(sqltypes.String):
+    def get_col_spec(self):
+        return "VARCHAR(%(length)s)" % {'length' : self.length}
+class OracleChar(sqltypes.CHAR):
+    def get_col_spec(self):
+        return "CHAR(%(length)s)" % {'length' : self.length}
+class OracleBinary(sqltypes.Binary):
+    def get_col_spec(self):
+        return "BLOB"
+class OracleBoolean(sqltypes.Boolean):
+    def get_col_spec(self):
+        return "BOOLEAN"
+        
+colspecs = {
+    sqltypes.Integer : OracleInteger,
+    sqltypes.Numeric : OracleNumeric,
+    sqltypes.DateTime : OracleDateTime,
+    sqltypes.String : OracleString,
+    sqltypes.Binary : OracleBinary,
+    sqltypes.Boolean : OracleBoolean,
+    sqltypes.TEXT : OracleText,
+    sqltypes.CHAR: OracleChar,
+}
 
 def engine(**params):
     return OracleSQLEngine(**params)
     
 class OracleSQLEngine(ansisql.ANSISQLEngine):
-    def __init__(self, use_ansi = True, **params):
+    def __init__(self, opts, use_ansi = True, module = None, **params):
         self._use_ansi = use_ansi
         ansisql.ANSISQLEngine.__init__(self, **params)
+        self.opts = {}
+        if module is None:
+            self.module = cx_Oracle
+        else:
+            self.module = module
 
+    def dbapi(self):
+        return self.module
+
+    def connect_args(self):
+        return [[], self.opts]
+        
     def compile(self, statement, bindparams):
         compiler = OracleCompiler(self, statement, bindparams, use_ansi = self._use_ansi)
         statement.accept_visitor(compiler)
         return compiler
-        
-    def create_connection(self):
-        raise NotImplementedError()
+
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, colspecs)
+
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+
+    def compiler(self, statement, bindparams):
+        return OracleCompiler(self, statement, bindparams)
+
+    def schemagenerator(self, proxy, **params):
+        return OracleSchemaGenerator(proxy, **params)
+
+    def reflecttable(self, table):
+        raise "not implemented"
+
+    def last_inserted_ids(self):
+        table = self.context.last_inserted_table
+        if self.context.lastrowid is not None and table is not None and len(table.primary_keys):
+            row = sql.select(table.primary_keys, table.rowid_column == self.context.lastrowid).execute().fetchone()
+            return [v for v in row]
+        else:
+            return None
+
+    def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+        pass
+    def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+        pass
+
+    def _executemany(self, c, statement, parameters):
+        rowcount = 0
+        for param in parameters:
+            c.execute(statement, param)
+            rowcount += c.rowcount
+        self.context.rowcount = rowcount
 
 class OracleCompiler(ansisql.ANSICompiler):
     """oracle compiler modifies the lexical structure of Select statements to work under 
@@ -77,3 +160,19 @@ class OracleCompiler(ansisql.ANSICompiler):
             self.strings[column] = "%s.%s" % (column.table.name, column.name)
         
 
+class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
+    def get_column_specification(self, column):
+        colspec = column.name
+        colspec += " " + column.type.get_col_spec()
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        if column.primary_key:
+            colspec += " PRIMARY KEY"
+        if column.foreign_key:
+            colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) 
+        return colspec
+
+    def visit_sequence(self, sequence):
+        self.append("CREATE SEQUENCE %s" % sequence.name)
+        self.execute()
\ No newline at end of file
index 547a2c0b78c9bbe61469c3f283d32edf143e2ca6..76d5248f771908b08a552558c5b53d7315d1fbab 100644 (file)
@@ -84,7 +84,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
     def connect_args(self):
         return [[], self.opts]
 
-
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
@@ -97,6 +96,9 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
     def schemagenerator(self, proxy, **params):
         return PGSchemaGenerator(proxy, **params)
 
+    def schemadropper(self, proxy, **params):
+        return PGSchemaDropper(proxy, **params)
+        
     def reflecttable(self, table):
         raise "not implemented"
         
@@ -109,21 +111,16 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
             return None
             
     def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
-        if True: return
         # if a sequence was explicitly defined we do it here
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
-            last_inserted_ids = []
             for primary_key in compiled.statement.table.primary_keys:
-                # pseudocode
-                if parameters[primary_key.key] is None:
-                    if echo is True:
-                        self.log(primary_key.sequence.text)
-                    res = cursor.execute(primary_key.sequence.text)
-                    newid = res.fetchrow()[0]
+                if primary_key.sequence is not None and not primary_key.sequence.optional and parameters[primary_key.key] is None:
+                    if echo is True or self.echo:
+                        self.log("select nextval('%s')" % primary_key.sequence.name)
+                    cursor.execute("select nextval('%s')" % primary_key.sequence.name)
+                    newid = cursor.fetchone()[0]
                     parameters[primary_key.key] = newid
-                    last_inserted_ids.append(newid)
-            self.context.last_inserted_ids = last_inserted_ids
 
     def _executemany(self, c, statement, parameters):
         """we need accurate rowcounts for updates, inserts and deletes.  psycopg2 is not nice enough
@@ -151,13 +148,21 @@ class PGCompiler(ansisql.ANSICompiler):
     def bindparam_string(self, name):
         return "%(" + name + ")s"
 
+    def visit_insert(self, insert):
+        for c in insert.table.primary_keys:
+            if c.sequence is not None and not c.sequence.optional:
+                self.bindparams[c.key] = None
+                #if not insert.parameters.has_key(c.key):
+                 #   insert.parameters[c.key] = sql.bindparam(c.key)
+        return ansisql.ANSICompiler.visit_insert(self, insert)
+        
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column):
         colspec = column.name
-        if column.primary_key and isinstance(column.type, types.Integer):
+        if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional):
             colspec += " SERIAL"
         else:
-            colspec += " " + column.column.type.get_col_spec()
+            colspec += " " + column.type.get_col_spec()
 
         if not column.nullable:
             colspec += " NOT NULL"
@@ -166,3 +171,14 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
         if column.foreign_key:
             colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) 
         return colspec
+
+    def visit_sequence(self, sequence):
+        if not sequence.optional:
+            self.append("CREATE SEQUENCE %s" % sequence.name)
+            self.execute()
+            
+class PGSchemaDropper(ansisql.ANSISchemaDropper):
+    def visit_sequence(self, sequence):
+        if not sequence.optional:
+            self.append("DROP SEQUENCE %s" % sequence.name)
+            self.execute()
index 4e01d1684d5672bb0986b976c22ef375ebed83ca..9957fe315e1f6952f9a3385089cc745eeb9f82a7 100644 (file)
@@ -124,7 +124,7 @@ class SQLEngine(schema.SchemaEngine):
         connection.commit()
 
     def proxy(self):
-        return lambda s, p = None: self.execute(s, p)
+        return lambda s, p = None: self.execute(s, p, commit=True)
 
     def connection(self):
         return self._pool.connect()
@@ -192,7 +192,7 @@ class SQLEngine(schema.SchemaEngine):
     def post_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
         pass
 
-    def execute(self, statement, parameters, connection = None, echo = None, typemap = None, **kwargs):
+    def execute(self, statement, parameters, connection = None, echo = None, typemap = None, commit=False, **kwargs):
         if parameters is None:
             parameters = {}
         if echo is True or self.echo:
@@ -200,8 +200,8 @@ class SQLEngine(schema.SchemaEngine):
             self.log(repr(parameters))
 
         if connection is None:
-            poolconn = self.connection()
-            c = poolconn.cursor()
+            connection = self.connection()
+            c = connection.cursor()
         else:
             c = connection.cursor()
 
@@ -211,6 +211,8 @@ class SQLEngine(schema.SchemaEngine):
         else:
             self._execute(c, statement, parameters)
         self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs)
+        if commit:
+            connection.commit()
         return ResultProxy(c, self, typemap = typemap)
 
     def _execute(self, c, statement, parameters):
index 68593e810b49cd3ff8a56efc513e3f77fb9f6103..967fa213ad7a69658bcc163e5c156508835bfa9d 100644 (file)
@@ -165,8 +165,12 @@ class Column(SchemaItem):
         c._impl = self.engine.columnimpl(c)
         return c
 
-    def accept_visitor(self, visitor): 
-        return visitor.visit_column(self)
+    def accept_visitor(self, visitor):
+        if self.sequence is not None:
+            self.sequence.accept_visitor(visitor)
+        if self.foreign_key is not None:
+            self.foreign_key.accept_visitor(visitor)
+        visitor.visit_column(self)
 
     def __lt__(self, other): return self._impl.__lt__(other)
     def __le__(self, other): return self._impl.__le__(other)
@@ -208,7 +212,10 @@ class ForeignKey(SchemaItem):
         return self._column
             
     column = property(lambda s: s._init_column())
-    
+
+    def accept_visitor(self, visitor):
+        visitor.visit_foreign_key(self)
+        
     def _set_parent(self, column):
         self.parent = column
         self.parent.foreign_key = self
@@ -216,11 +223,12 @@ class ForeignKey(SchemaItem):
         
 class Sequence(SchemaItem):
     """represents a sequence, which applies to Oracle and Postgres databases."""
-    def __init__(self, name, start = None, increment = None):
+    def __init__(self, name, start = None, increment = None, optional=False):
         self.name = name
         self.start = start
         self.increment = increment
-    def _set_parent(self, column, key):
+        self.optional=optional
+    def _set_parent(self, column):
         self.column = column
         self.column.sequence = self
     def accept_visitor(self, visitor):
@@ -236,14 +244,14 @@ class SchemaEngine(object):
         raise NotImplementedError()
         
 class SchemaVisitor(object):
-        """base class for an object that traverses across Schema objects"""
-
-        def visit_schema(self, schema):pass
-        def visit_table(self, table):pass
-        def visit_column(self, column):pass
-        def visit_foreign_key(self, join):pass
-        def visit_index(self, index):pass
-        def visit_sequence(self, sequence):pass
+    """base class for an object that traverses across Schema objects"""
+
+    def visit_schema(self, schema):pass
+    def visit_table(self, table):pass
+    def visit_column(self, column):pass
+    def visit_foreign_key(self, join):pass
+    def visit_index(self, index):pass
+    def visit_sequence(self, sequence):pass
 
             
             
\ No newline at end of file
index 5b5889695d48013108160086823734713262e364..f8aaf497e12a3c932f68b77ea9629ef617e52683 100644 (file)
@@ -637,7 +637,7 @@ class SaveTest(AssertMixin):
                 item.keywords.append(ik)
 
         objectstore.uow().commit()
-
+        objectstore.clear()
         l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name])
         self.assert_result(l, *data)
 
index 845077716cfa24191189707d3e39d0fb30a221bf..aceed904b6c5242c335a294a0869e0f6ee67a31c 100644 (file)
@@ -25,12 +25,12 @@ elif DBTYPE == 'postgres':
 db = testbase.EngineAssert(db)
 
 users = Table('users', db,
-    Column('user_id', Integer, primary_key = True),
+    Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
     Column('user_name', String(40)),
 )
 
 addresses = Table('email_addresses', db,
-    Column('address_id', Integer, primary_key = True),
+    Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
     Column('user_id', Integer, ForeignKey(users.c.user_id)),
     Column('email_address', String(40)),
 )
@@ -65,7 +65,6 @@ def create():
     orderitems.create()
     keywords.create()
     itemkeywords.create()
-    db.commit()
     
 def drop():
     itemkeywords.drop()
@@ -74,7 +73,6 @@ def drop():
     orders.drop()
     addresses.drop()
     users.drop()
-    db.commit()
     
 def delete():
     itemkeywords.delete().execute()
@@ -83,6 +81,7 @@ def delete():
     orders.delete().execute()
     addresses.delete().execute()
     users.delete().execute()
+    db.commit()
     
 def data():
     delete()