]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 20:16:23 +0000 (20:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 20:16:23 +0000 (20:16 +0000)
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/sql.py
test/tables.py

index bb1f3c6aea407f496fc3ddce03a703c4e44707c8..0ef64d480adc9b1857f9ed61eb1f25143b747e57 100644 (file)
@@ -40,7 +40,7 @@ class OracleDateTime(sqltypes.DateTime):
         return "DATE"
 class OracleText(sqltypes.TEXT):
     def get_col_spec(self):
-        return "TEXT"
+        return "CLOB"
 class OracleString(sqltypes.String):
     def get_col_spec(self):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
index 86116f83fb918d93ab4f1f4af5f0c3dd9f14607c..f3648e69c477fc30ca66f2838a4eae71011b86be 100644 (file)
@@ -117,7 +117,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
             return None
             
     def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
-        # if a sequence was explicitly defined we do it here
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
             if isinstance(parameters, list):
index 93569a05c476be5e81ebf65bb3d2244d3822c131..c910916d1ed06d71d8a8a148d0ad04d536eb8a9d 100644 (file)
@@ -124,7 +124,7 @@ class SQLEngine(schema.SchemaEngine):
         connection.commit()
 
     def proxy(self):
-        return lambda s, p = None: self.execute(s, p, commit=True)
+        return lambda s, p = None: self.execute(s, p)
 
     def connection(self):
         return self._pool.connect()
@@ -172,8 +172,6 @@ class SQLEngine(schema.SchemaEngine):
             self.do_rollback(self.context.transaction)
             self.context.transaction = None
             self.context.tcount = None
-        else:
-            self.do_rollback(self.connection())
             
     def commit(self):
         if self.context.transaction is not None:
@@ -183,8 +181,6 @@ class SQLEngine(schema.SchemaEngine):
                 self.do_commit(self.context.transaction)
                 self.context.transaction = None
                 self.context.tcount = None
-        else:
-            self.do_commit(self.connection())
             
     def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs):
         pass
@@ -202,19 +198,23 @@ class SQLEngine(schema.SchemaEngine):
         else:
             c = connection.cursor()
 
-        self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs)
-
-        if echo is True or self.echo:
-            self.log(statement)
-            self.log(repr(parameters))
-
-        if isinstance(parameters, list):
-            self._executemany(c, statement, parameters)
-        else:
-            self._execute(c, statement, parameters)
-        self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs)
-        if commit:
-            connection.commit()
+        try:
+            self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs)
+
+            if echo is True or self.echo:
+                self.log(statement)
+                self.log(repr(parameters))
+            if isinstance(parameters, list):
+                self._executemany(c, statement, parameters)
+            else:
+                self._execute(c, statement, parameters)
+            self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs)
+            if commit or self.context.transaction is None:
+                self.do_commit(connection)
+        except:
+            self.do_rollback(connection)
+            # TODO: wrap DB exceptions ?
+            raise
         return ResultProxy(c, self, typemap = typemap)
 
     def _execute(self, c, statement, parameters):
@@ -247,7 +247,18 @@ class ResultProxy:
                 i+=1
 
     def _get_col(self, row, key):
-        rec = self.props[key.lower()]
+        if isinstance(key, schema.Column):
+            try:
+                rec = self.props[key.label.lower()]
+            except KeyError:
+                try:
+                    rec = self.props[key.key.lower()]
+                except KeyError:
+                    rec = self.props[key.name.lower()]
+        elif isinstance(key, str):
+            rec = self.props[key.lower()]
+        else:
+            rec = self.props[key]
         return rec[0].convert_result_value(row[rec[1]])
         
     def fetchall(self):
index c90b50769da1f7a93f454b7a756142d983d5bf17..f541182aa876aaa2a0439fe25ca0caaae99b1ad1 100644 (file)
@@ -312,7 +312,7 @@ class Mapper(object):
             objectstore.uow().register_clean(value)
 
         if len(mappers):
-            return result + otherresults
+            return [result] + otherresults
         else:
             return result
 
@@ -375,9 +375,21 @@ class Mapper(object):
         in this case, the developer must insure that an adequate set of columns exists in the 
         rowset with which to build new object instances."""
         if arg is not None and isinstance(arg, sql.Select):
-            return self._select_statement(arg, **params)
+            return self.select_statement(arg, **params)
         else:
-            return self._select_whereclause(arg, **params)
+            return self.select_whereclause(arg, **params)
+
+    def select_whereclause(self, whereclause = None, order_by = None, **params):
+        statement = self._compile(whereclause, order_by = order_by)
+        return self.select_statement(statement, **params)
+
+    def select_statement(self, statement, **params):
+        statement.use_labels = True
+        return self.instances(statement.execute(**params))
+
+    def select_text(self, text, **params):
+        t = sql.text(text, engine=self.primarytable.engine)
+        return self.instances(t.execute(**params))
 
     def _getattrbycolumn(self, obj, column):
         try:
@@ -494,13 +506,6 @@ class Mapper(object):
         statement.use_labels = True
         return statement
 
-    def _select_whereclause(self, whereclause = None, order_by = None, **params):
-        statement = self._compile(whereclause, order_by = order_by)
-        return self._select_statement(statement, **params)
-
-    def _select_statement(self, statement, **params):
-        statement.use_labels = True
-        return self.instances(statement.execute(**params))
 
     def _identity_key(self, row):
         return objectstore.get_row_key(row, self.class_, self.primarytable, self.primary_keys[self.table])
@@ -539,7 +544,7 @@ class Mapper(object):
             # check if primary keys in the result are None - this indicates 
             # an instance of the object is not present in the row
             for col in self.primary_keys[self.table]:
-                if row[col.label] is None:
+                if row[col] is None:
                     return None
             # plugin point
             instance = self.extension.create_instance(self, row, imap, self.class_)
@@ -622,8 +627,7 @@ class ColumnProperty(MapperProperty):
 
     def execute(self, instance, row, identitykey, imap, isnew):
         if isnew:
-            instance.__dict__[self.key] = row[self.columns[0].label]
-            #setattr(instance, self.key, row[self.columns[0].label])
+            instance.__dict__[self.key] = row[self.columns[0]]
         
 
 class PropertyLoader(MapperProperty):
index 9b414f1818d9a934b67b527cc245f2c8033a401e..6081a7150d043fee4687184c7c1136aecfbf7c6f 100644 (file)
@@ -52,7 +52,7 @@ def get_row_key(row, class_, table, primary_keys):
     may be synonymous with the table argument or can be a larger construct containing that table.
     return value: a tuple object which is used as an identity key.
     """
-    return (class_, table, tuple([row[column.label] for column in primary_keys]))
+    return (class_, table, tuple([row[column] for column in primary_keys]))
 
 def begin():
     """begins a new UnitOfWork transaction.  the next commit will affect only
index b3a4ad29305e35dd4586e64a8f5e0e77fadb4751..a5a97a9e844baae658be25cfed7987901dcc7645 100644 (file)
@@ -121,8 +121,8 @@ def bindparam(key, value = None, type=None):
     else:
         return BindParamClause(key, value, type=type)
 
-def text(text):
-    return TextClause(text)
+def text(text, engine=None):
+    return TextClause(text, engine=engine)
 
 def null():
     return Null()
@@ -383,9 +383,10 @@ class BindParamClause(ClauseElement):
 class TextClause(ClauseElement):
     """represents any plain text WHERE clause or full SQL statement"""
     
-    def __init__(self, text = ""):
+    def __init__(self, text = "", engine=None):
         self.text = text
         self.parens = False
+        self.engine = engine
     def accept_visitor(self, visitor): 
         visitor.visit_textclause(self)
     def hash_key(self):
index aceed904b6c5242c335a294a0869e0f6ee67a31c..8bddce1d0bc6a481363b0707b6148665cdb173b0 100644 (file)
@@ -159,6 +159,8 @@ class Address(object):
         return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'user_id', None)) + " " + repr(self.email_address)
 
 class Order(object):
+    def __init__(self):
+        self.isopen=0
     def __repr__(self):
         return "Order: " + repr(self.description) + " " + repr(self.isopen) + " " + repr(getattr(self, 'items', None))