]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
initial oracle+zxjdbc support
authorPhilip Jenvey <pjenvey@underboss.org>
Mon, 27 Jul 2009 00:43:59 +0000 (00:43 +0000)
committerPhilip Jenvey <pjenvey@underboss.org>
Mon, 27 Jul 2009 00:43:59 +0000 (00:43 +0000)
lib/sqlalchemy/dialects/oracle/__init__.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/oracle/zxjdbc.py [new file with mode: 0644]
test/engine/test_transaction.py

index 7038fb3ec07d713e38fe027a86ba52f4266f7b54..3b4379ab704ebd811fc65f911feb032c0eb07ba7 100644 (file)
@@ -1,3 +1,3 @@
-from sqlalchemy.dialects.oracle import base, cx_oracle
+from sqlalchemy.dialects.oracle import base, cx_oracle, zxjdbc
 
-base.dialect = cx_oracle.dialect
\ No newline at end of file
+base.dialect = cx_oracle.dialect
index cc19541eb11df6c14b76f82d54537ba69f00cfbe..882fc05c71988975ac9d495415546afef8e8dc7b 100644 (file)
@@ -446,7 +446,7 @@ class OracleDefaultRunner(base.DefaultRunner):
     def visit_sequence(self, seq):
         return self.execute_string("SELECT " + 
                     self.dialect.identifier_preparer.format_sequence(seq) + 
-                    ".nextval FROM DUAL", {})
+                    ".nextval FROM DUAL", ())
 
 class OracleIdentifierPreparer(compiler.IdentifierPreparer):
     
@@ -505,19 +505,26 @@ class OracleDialect(default.DefaultDialect):
     def has_table(self, connection, table_name, schema=None):
         if not schema:
             schema = self.get_default_schema_name(connection)
-        cursor = connection.execute("""select table_name from all_tables where table_name=:name and owner=:schema_name""", {'name':self.denormalize_name(table_name), 'schema_name':self.denormalize_name(schema)})
+        cursor = connection.execute(
+            sql.text("SELECT table_name FROM all_tables "
+                     "WHERE table_name = :name AND owner = :schema_name"),
+            name=self.denormalize_name(table_name), schema_name=self.denormalize_name(schema))
         return cursor.fetchone() is not None
 
     def has_sequence(self, connection, sequence_name, schema=None):
         if not schema:
             schema = self.get_default_schema_name(connection)
-        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name and sequence_owner=:schema_name""", {'name':self.denormalize_name(sequence_name), 'schema_name':self.denormalize_name(schema)})
+        cursor = connection.execute(
+            sql.text("SELECT sequence_name FROM all_sequences "
+                     "WHERE sequence_name = :name AND sequence_owner = :schema_name"),
+            name=self.denormalize_name(sequence_name), schema_name=self.denormalize_name(schema))
         return cursor.fetchone() is not None
 
     def normalize_name(self, name):
         if name is None:
             return None
-        elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding)):
+        elif (name.upper() == name and
+              not self.identifier_preparer._requires_quotes(name.lower().decode(self.encoding))):
             return name.lower().decode(self.encoding)
         else:
             return name.decode(self.encoding)
@@ -536,11 +543,15 @@ class OracleDialect(default.DefaultDialect):
     def table_names(self, connection, schema):
         # note that table_names() isnt loading DBLINKed or synonym'ed tables
         if schema is None:
-            s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')"
-            cursor = connection.execute(s)
+            cursor = connection.execute(
+                "SELECT table_name FROM all_tables "
+                "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX')")
         else:
-            s = "select table_name from all_tables where nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner"
-            cursor = connection.execute(s, {'owner': self.denormalize_name(schema)})
+            s = sql.text(
+                "SELECT table_name FROM all_tables "
+                "WHERE nvl(tablespace_name, 'no tablespace') NOT IN ('SYSTEM', 'SYSAUX') "
+                "AND OWNER = :owner")
+            cursor = connection.execute(s, owner=self.denormalize_name(schema))
         return [self.normalize_name(row[0]) for row in cursor]
 
     def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
@@ -551,28 +562,26 @@ class OracleDialect(default.DefaultDialect):
         returns the actual name, owner, dblink name, and synonym name if found.
         """
 
-        sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
-                   from   ALL_SYNONYMS WHERE """
-
+        q = "SELECT owner, table_owner, table_name, db_link, synonym_name FROM all_synonyms WHERE "
         clauses = []
         params = {}
         if desired_synonym:
-            clauses.append("SYNONYM_NAME=:synonym_name")
+            clauses.append("synonym_name = :synonym_name")
             params['synonym_name'] = desired_synonym
         if desired_owner:
-            clauses.append("TABLE_OWNER=:desired_owner")
+            clauses.append("table_owner = :desired_owner")
             params['desired_owner'] = desired_owner
         if desired_table:
-            clauses.append("TABLE_NAME=:tname")
+            clauses.append("table_name = :tname")
             params['tname'] = desired_table
 
-        sql += " AND ".join(clauses)
+        q += " AND ".join(clauses)
 
-        result = connection.execute(sql, **params)
+        result = connection.execute(sql.text(q), **params)
         if desired_owner:
             row = result.fetchone()
             if row:
-                return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
+                return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name']
             else:
                 return None, None, None, None
         else:
@@ -581,7 +590,7 @@ class OracleDialect(default.DefaultDialect):
                 raise AssertionError("There are multiple tables visible to the schema, you must specify owner")
             elif len(rows) == 1:
                 row = rows[0]
-                return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
+                return row['table_name'], row['table_owner'], row['db_link'], row['synonym_name']
             else:
                 return None, None, None, None
 
@@ -615,9 +624,8 @@ class OracleDialect(default.DefaultDialect):
     @reflection.cache
     def get_view_names(self, connection, schema=None, **kw):
         schema = self.denormalize_name(schema or self.get_default_schema_name(connection))
-        s = "select view_name from all_views where OWNER = :owner"
-        cursor = connection.execute(s,
-                {'owner':self.denormalize_name(schema)})
+        s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
+        cursor = connection.execute(s, owner=self.denormalize_name(schema))
         return [self.normalize_name(row[0]) for row in cursor]
 
     @reflection.cache
@@ -641,14 +649,13 @@ class OracleDialect(default.DefaultDialect):
                                           resolve_synonyms, dblink,
                                           info_cache=info_cache)
         columns = []
-        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, "
-                                "DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s "
-                                "where TABLE_NAME = :table_name and OWNER = :owner" % 
-                                {'dblink':dblink}, {'table_name':table_name, 'owner':schema}
-                                )
+        c = connection.execute(sql.text(
+                "SELECT column_name, data_type, data_length, data_precision, data_scale, "
+                "nullable, data_default FROM ALL_TAB_COLUMNS%(dblink)s "
+                "WHERE table_name = :table_name AND owner = :owner" % {'dblink': dblink}),
+                               table_name=table_name, owner=schema)
 
         for row in c:
-
             (colname, coltype, length, precision, scale, nullable, default) = \
                 (self.normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
 
@@ -695,20 +702,18 @@ class OracleDialect(default.DefaultDialect):
                                           resolve_synonyms, dblink,
                                           info_cache=info_cache)
         indexes = []
-        q = """
-        SELECT a.INDEX_NAME, a.COLUMN_NAME, b.UNIQUENESS
+        q = sql.text("""
+        SELECT a.index_name, a.column_name, b.uniqueness
         FROM ALL_IND_COLUMNS%(dblink)s a
         INNER JOIN ALL_INDEXES%(dblink)s b
-            ON a.INDEX_NAME = b.INDEX_NAME
-            AND a.TABLE_OWNER = b.TABLE_OWNER
-            AND a.TABLE_NAME = b.TABLE_NAME
-        WHERE a.TABLE_NAME = :table_name
-        AND a.TABLE_OWNER = :schema
-        ORDER BY a.INDEX_NAME, a.COLUMN_POSITION
-        """ % dict(dblink=dblink)
-        rp = connection.execute(q,
-            dict(table_name=self.denormalize_name(table_name),
-                 schema=self.denormalize_name(schema)))
+            ON a.index_name = b.index_name
+            AND a.table_owner = b.table_owner
+            AND a.table_name = b.table_name
+        WHERE a.table_name = :table_name
+        AND a.table_owner = :schema
+        ORDER BY a.index_name, a.column_position""" % {'dblink': dblink})
+        rp = connection.execute(q, table_name=self.denormalize_name(table_name),
+                                schema=self.denormalize_name(schema))
         indexes = []
         last_index_name = None
         pkeys = self.get_primary_keys(connection, table_name, schema,
@@ -732,7 +737,8 @@ class OracleDialect(default.DefaultDialect):
     def _get_constraint_data(self, connection, table_name, schema=None,
                             dblink='', **kw):
 
-        rp = connection.execute("""SELECT
+        rp = connection.execute(
+            sql.text("""SELECT
              ac.constraint_name,
              ac.constraint_type,
              loc.column_name AS local_column,
@@ -752,9 +758,8 @@ class OracleDialect(default.DefaultDialect):
            AND ac.r_owner = rem.owner(+)
            AND ac.r_constraint_name = rem.constraint_name(+)
            AND (rem.position IS NULL or loc.position=rem.position)
-           ORDER BY ac.constraint_name, loc.position"""
-           
-         % {'dblink':dblink}, {'table_name' : table_name, 'owner' : schema})
+           ORDER BY ac.constraint_name, loc.position""" % {'dblink': dblink}),
+            table_name=table_name, owner=schema)
         constraint_data = rp.fetchall()
         return constraint_data
 
@@ -875,11 +880,11 @@ class OracleDialect(default.DefaultDialect):
             self._prepare_reflection_args(connection, view_name, schema,
                                           resolve_synonyms, dblink,
                                           info_cache=info_cache)
-        s = """
+        s = sql.text("""
         SELECT text FROM all_views
         WHERE owner = :schema
         AND view_name = :view_name
-        """
+        """)
         rp = connection.execute(s,
                                 view_name=view_name, schema=schema).scalar()
         if rp:
diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py
new file mode 100644 (file)
index 0000000..a0ad088
--- /dev/null
@@ -0,0 +1,24 @@
+"""Support for the Oracle database via the zxjdbc JDBC connector."""
+import re
+
+from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
+from sqlalchemy.dialects.oracle.base import OracleDialect
+
+class Oracle_jdbc(ZxJDBCConnector, OracleDialect):
+
+    jdbc_db_name = 'oracle'
+    jdbc_driver_name = 'oracle.jdbc.driver.OracleDriver'
+
+    def create_connect_args(self, url):
+        hostname = url.host
+        port = url.port or '1521'
+        dbname = url.database
+        jdbc_url = 'jdbc:oracle:thin:@%s:%s:%s' % (hostname, port, dbname)
+        return [[jdbc_url, url.username, url.password, self.jdbc_driver_name],
+                self._driver_kwargs()]
+        
+    def _get_server_version_info(self, connection):
+        version = re.search(r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
+        return tuple(int(x) for x in version.split('.'))
+        
+dialect = Oracle_jdbc
index bb7cf116c9acdee3aef31c78276c088a5c62f1df..8e3f3412d653b28c7049a6008f0a1a339cd41992 100644 (file)
@@ -178,6 +178,7 @@ class TransactionTest(TestBase):
         connection.close()
 
     @testing.requires.savepoints
+    @testing.crashes('oracle+zxjdbc', 'Errors out and causes subsequent tests to deadlock')
     def test_nested_subtransaction_commit(self):
         connection = testing.db.connect()
         transaction = connection.begin()