]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 16:56:34 +0000 (16:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 16:56:34 +0000 (16:56 +0000)
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/engine.py

index 5922863b44409dfcdf24082bacb78326448a9fb9..c9955611dee8802bc7f15767d40eaa2f88b77591 100644 (file)
@@ -22,6 +22,7 @@ import sqlalchemy.sql as sql
 import sqlalchemy.schema as schema
 import sqlalchemy.ansisql as ansisql
 from sqlalchemy.ansisql import *
+import sqlalchemy.types as sqltypes
 
 try:
     import cx_Oracle
@@ -64,18 +65,18 @@ colspecs = {
     sqltypes.CHAR: OracleChar,
 }
 
-def engine(**params):
-    return OracleSQLEngine(**params)
+def engine(*args, **params):
+    return OracleSQLEngine(*args, **params)
     
 class OracleSQLEngine(ansisql.ANSISQLEngine):
     def __init__(self, opts, use_ansi = True, module = None, **params):
         self._use_ansi = use_ansi
-        ansisql.ANSISQLEngine.__init__(self, **params)
-        self.opts = {}
+        self.opts = opts or {}
         if module is None:
             self.module = cx_Oracle
         else:
             self.module = module
+        ansisql.ANSISQLEngine.__init__(self, **params)
 
     def dbapi(self):
         return self.module
@@ -83,11 +84,6 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
     def connect_args(self):
         return [[], self.opts]
         
-    def compile(self, statement, bindparams):
-        compiler = OracleCompiler(self, statement, bindparams, use_ansi = self._use_ansi)
-        statement.accept_visitor(compiler)
-        return compiler
-
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
@@ -112,9 +108,25 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
             return None
 
     def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
-        pass
+        # if a sequence was explicitly defined we do it here
+       if True: return
+        if compiled is None: return
+        if getattr(compiled, "isinsert", False):
+            for primary_key in compiled.statement.table.primary_keys:
+                if primary_key.sequence is not None and not primary_key.sequence.optional and parameters[primary_key.key] is None:
+                    if echo is True or self.echo:
+                        self.log("select nextval('%s')" % primary_key.sequence.name)
+                    cursor.execute("select nextval('%s')" % primary_key.sequence.name)
+                    newid = cursor.fetchone()[0]
+                    parameters[primary_key.key] = newid
+
     def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
-        pass
+       if True: return
+        if compiled is None: return
+        if getattr(compiled, "isinsert", False):
+            table = compiled.statement.table
+            self.context.last_inserted_table = table
+            self.context.lastrowid = cursor.lastrowid
 
     def _executemany(self, c, statement, parameters):
         rowcount = 0
@@ -159,6 +171,13 @@ class OracleCompiler(ansisql.ANSICompiler):
         else:
             self.strings[column] = "%s.%s" % (column.table.name, column.name)
         
+    def visit_insert(self, insert):
+        for c in insert.table.primary_keys:
+            if c.sequence is not None and not c.sequence.optional:
+                self.bindparams[c.key] = None
+                #if not insert.parameters.has_key(c.key):
+                 #   insert.parameters[c.key] = sql.bindparam(c.key)
+        return ansisql.ANSICompiler.visit_insert(self, insert)
 
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column):
@@ -175,4 +194,6 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
 
     def visit_sequence(self, sequence):
         self.append("CREATE SEQUENCE %s" % sequence.name)
-        self.execute()
\ No newline at end of file
+       print "HI"
+        self.execute()
+       print "THERE"
index 9957fe315e1f6952f9a3385089cc745eeb9f82a7..10d4d8fff3254f42f89cc031d9634cd13528f299 100644 (file)
@@ -193,26 +193,34 @@ class SQLEngine(schema.SchemaEngine):
         pass
 
     def execute(self, statement, parameters, connection = None, echo = None, typemap = None, commit=False, **kwargs):
+       print "I AM HERE"
         if parameters is None:
             parameters = {}
         if echo is True or self.echo:
             self.log(statement)
             self.log(repr(parameters))
 
+       print "LOGGED"
         if connection is None:
             connection = self.connection()
+            #connection = self.dbapi().connect(**self.connect_args()[1])
             c = connection.cursor()
         else:
             c = connection.cursor()
 
+       print "CONNECTION"
         self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs)
+       print "LALA"
         if isinstance(parameters, list):
             self._executemany(c, statement, parameters)
         else:
             self._execute(c, statement, parameters)
+       print "FOO"
         self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs)
         if commit:
+            print "HOHO"
             connection.commit()
+        print "EEP"
         return ResultProxy(c, self, typemap = typemap)
 
     def _execute(self, c, statement, parameters):