From f56d75c3ad9fbe7ff85b5b65698cd5696d12ee28 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 23 Oct 2005 20:16:23 +0000 Subject: [PATCH] --- lib/sqlalchemy/databases/oracle.py | 2 +- lib/sqlalchemy/databases/postgres.py | 1 - lib/sqlalchemy/engine.py | 49 +++++++++++++++++----------- lib/sqlalchemy/mapper.py | 30 +++++++++-------- lib/sqlalchemy/objectstore.py | 2 +- lib/sqlalchemy/sql.py | 7 ++-- test/tables.py | 2 ++ 7 files changed, 55 insertions(+), 38 deletions(-) diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index bb1f3c6aea..0ef64d480a 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -40,7 +40,7 @@ class OracleDateTime(sqltypes.DateTime): return "DATE" class OracleText(sqltypes.TEXT): def get_col_spec(self): - return "TEXT" + return "CLOB" class OracleString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 86116f83fb..f3648e69c4 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -117,7 +117,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return None def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): - # if a sequence was explicitly defined we do it here if compiled is None: return if getattr(compiled, "isinsert", False): if isinstance(parameters, list): diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 93569a05c4..c910916d1e 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -124,7 +124,7 @@ class SQLEngine(schema.SchemaEngine): connection.commit() def proxy(self): - return lambda s, p = None: self.execute(s, p, commit=True) + return lambda s, p = None: self.execute(s, p) def connection(self): return self._pool.connect() @@ -172,8 +172,6 @@ class SQLEngine(schema.SchemaEngine): self.do_rollback(self.context.transaction) self.context.transaction = None self.context.tcount = None - else: - self.do_rollback(self.connection()) def commit(self): if self.context.transaction is not None: @@ -183,8 +181,6 @@ class SQLEngine(schema.SchemaEngine): self.do_commit(self.context.transaction) self.context.transaction = None self.context.tcount = None - else: - self.do_commit(self.connection()) def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs): pass @@ -202,19 +198,23 @@ class SQLEngine(schema.SchemaEngine): else: c = connection.cursor() - self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs) - - if echo is True or self.echo: - self.log(statement) - self.log(repr(parameters)) - - if isinstance(parameters, list): - self._executemany(c, statement, parameters) - else: - self._execute(c, statement, parameters) - self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs) - if commit: - connection.commit() + try: + self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs) + + if echo is True or self.echo: + self.log(statement) + self.log(repr(parameters)) + if isinstance(parameters, list): + self._executemany(c, statement, parameters) + else: + self._execute(c, statement, parameters) + self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs) + if commit or self.context.transaction is None: + self.do_commit(connection) + except: + self.do_rollback(connection) + # TODO: wrap DB exceptions ? + raise return ResultProxy(c, self, typemap = typemap) def _execute(self, c, statement, parameters): @@ -247,7 +247,18 @@ class ResultProxy: i+=1 def _get_col(self, row, key): - rec = self.props[key.lower()] + if isinstance(key, schema.Column): + try: + rec = self.props[key.label.lower()] + except KeyError: + try: + rec = self.props[key.key.lower()] + except KeyError: + rec = self.props[key.name.lower()] + elif isinstance(key, str): + rec = self.props[key.lower()] + else: + rec = self.props[key] return rec[0].convert_result_value(row[rec[1]]) def fetchall(self): diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index c90b50769d..f541182aa8 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -312,7 +312,7 @@ class Mapper(object): objectstore.uow().register_clean(value) if len(mappers): - return result + otherresults + return [result] + otherresults else: return result @@ -375,9 +375,21 @@ class Mapper(object): in this case, the developer must insure that an adequate set of columns exists in the rowset with which to build new object instances.""" if arg is not None and isinstance(arg, sql.Select): - return self._select_statement(arg, **params) + return self.select_statement(arg, **params) else: - return self._select_whereclause(arg, **params) + return self.select_whereclause(arg, **params) + + def select_whereclause(self, whereclause = None, order_by = None, **params): + statement = self._compile(whereclause, order_by = order_by) + return self.select_statement(statement, **params) + + def select_statement(self, statement, **params): + statement.use_labels = True + return self.instances(statement.execute(**params)) + + def select_text(self, text, **params): + t = sql.text(text, engine=self.primarytable.engine) + return self.instances(t.execute(**params)) def _getattrbycolumn(self, obj, column): try: @@ -494,13 +506,6 @@ class Mapper(object): statement.use_labels = True return statement - def _select_whereclause(self, whereclause = None, order_by = None, **params): - statement = self._compile(whereclause, order_by = order_by) - return self._select_statement(statement, **params) - - def _select_statement(self, statement, **params): - statement.use_labels = True - return self.instances(statement.execute(**params)) def _identity_key(self, row): return objectstore.get_row_key(row, self.class_, self.primarytable, self.primary_keys[self.table]) @@ -539,7 +544,7 @@ class Mapper(object): # check if primary keys in the result are None - this indicates # an instance of the object is not present in the row for col in self.primary_keys[self.table]: - if row[col.label] is None: + if row[col] is None: return None # plugin point instance = self.extension.create_instance(self, row, imap, self.class_) @@ -622,8 +627,7 @@ class ColumnProperty(MapperProperty): def execute(self, instance, row, identitykey, imap, isnew): if isnew: - instance.__dict__[self.key] = row[self.columns[0].label] - #setattr(instance, self.key, row[self.columns[0].label]) + instance.__dict__[self.key] = row[self.columns[0]] class PropertyLoader(MapperProperty): diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index 9b414f1818..6081a7150d 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -52,7 +52,7 @@ def get_row_key(row, class_, table, primary_keys): may be synonymous with the table argument or can be a larger construct containing that table. return value: a tuple object which is used as an identity key. """ - return (class_, table, tuple([row[column.label] for column in primary_keys])) + return (class_, table, tuple([row[column] for column in primary_keys])) def begin(): """begins a new UnitOfWork transaction. the next commit will affect only diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index b3a4ad2930..a5a97a9e84 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -121,8 +121,8 @@ def bindparam(key, value = None, type=None): else: return BindParamClause(key, value, type=type) -def text(text): - return TextClause(text) +def text(text, engine=None): + return TextClause(text, engine=engine) def null(): return Null() @@ -383,9 +383,10 @@ class BindParamClause(ClauseElement): class TextClause(ClauseElement): """represents any plain text WHERE clause or full SQL statement""" - def __init__(self, text = ""): + def __init__(self, text = "", engine=None): self.text = text self.parens = False + self.engine = engine def accept_visitor(self, visitor): visitor.visit_textclause(self) def hash_key(self): diff --git a/test/tables.py b/test/tables.py index aceed904b6..8bddce1d0b 100644 --- a/test/tables.py +++ b/test/tables.py @@ -159,6 +159,8 @@ class Address(object): return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'user_id', None)) + " " + repr(self.email_address) class Order(object): + def __init__(self): + self.isopen=0 def __repr__(self): return "Order: " + repr(self.description) + " " + repr(self.isopen) + " " + repr(getattr(self, 'items', None)) -- 2.47.2