From: Mike Bayer Date: Tue, 31 Jul 2007 17:15:36 +0000 (+0000) Subject: - assurances that context.connection is safe to use by column default functions,... X-Git-Tag: rel_0_4beta1~132 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8f17efd7a3c337045299927f1a30cbbd013dd6b1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - assurances that context.connection is safe to use by column default functions, helps proposal for [ticket:703] --- diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 548494ff2c..a40ed9bdf4 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -185,9 +185,9 @@ class PGExecutionContext(default.DefaultExecutionContext): # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html ident = "c" + hex(random.randint(0, 65535))[2:] - return self.connection.connection.cursor(ident) + return self._connection.connection.cursor(ident) else: - return self.connection.connection.cursor() + return self._connection.connection.cursor() def get_result_proxy(self): if self._is_server_side(): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 642eeac627..ff2da1165c 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -275,8 +275,14 @@ class ExecutionContext(object): ExecutionContext should have these datamembers: connection - Connection object which initiated the call to the - dialect to create this ExecutionContext. + Connection object which can be freely used by default value generators + to execute SQL. This Connection should reference the same underlying + connection/transactional resources of root_connection. + + root_connection + Connection object which is the source of this ExecutionContext. This + Connection may have close_with_result=True set, in which case it can + only be used once. dialect dialect which created this ExecutionContext. @@ -515,12 +521,13 @@ class Connection(Connectable): The Connection object is **not** threadsafe. """ - def __init__(self, engine, connection=None, close_with_result=False): + def __init__(self, engine, connection=None, close_with_result=False, _branch=False): self.__engine = engine self.__connection = connection or engine.raw_connection() self.__transaction = None self.__close_with_result = close_with_result self.__savepoint_seq = 0 + self.__branch = _branch def _get_connection(self): try: @@ -530,9 +537,14 @@ class Connection(Connectable): def _branch(self): """return a new Connection which references this Connection's - engine and connection; but does not have close_with_result enabled.""" + engine and connection; but does not have close_with_result enabled, + and also whose close() method does nothing. + + This is used to execute "sub" statements within a single execution, + usually an INSERT statement. + """ - return Connection(self.__engine, self.__connection) + return Connection(self.__engine, self.__connection, _branch=True) engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated.") dialect = property(lambda s:s.__engine.dialect, doc="Dialect used by this Connection.") @@ -686,7 +698,8 @@ class Connection(Connectable): c = self.__connection except AttributeError: return - self.__connection.close() + if not self.__branch: + self.__connection.close() self.__connection = None del self.__connection @@ -757,7 +770,7 @@ class Connection(Connectable): else: self.__execute(context) self._autocommit(context.statement) - + def __execute(self, context): if context.parameters is None: if context.dialect.positional: @@ -1124,7 +1137,8 @@ class ResultProxy(object): self._rowcount = context.get_rowcount() self.close() - connection = property(lambda self:self.context.connection) + connection = property(lambda self:self.context.root_connection) + def _get_rowcount(self): if self._rowcount is not None: return self._rowcount @@ -1510,9 +1524,7 @@ class DefaultRunner(schema.SchemaVisitor): def __init__(self, context): self.context = context - # branch the connection so it doesnt close after result - self.connection = context.connection._branch() - + self.connection = self.context._connection._branch() dialect = property(lambda self:self.context.dialect) def get_column_default(self, column): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index a2e159639d..185387177d 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -145,7 +145,7 @@ class DefaultDialect(base.Dialect): class DefaultExecutionContext(base.ExecutionContext): def __init__(self, dialect, connection, compiled=None, statement=None, parameters=None): self.dialect = dialect - self.connection = connection + self._connection = connection self.compiled = compiled self._postfetch_cols = util.Set() @@ -172,11 +172,15 @@ class DefaultExecutionContext(base.ExecutionContext): self.statement = self.statement.encode(self.dialect.encoding) self.cursor = self.create_cursor() - + engine = property(lambda s:s.connection.engine) isinsert = property(lambda s:s.compiled and s.compiled.isinsert) isupdate = property(lambda s:s.compiled and s.compiled.isupdate) + connection = property(lambda s:s._connection._branch()) + + root_connection = property(lambda s:s._connection) + def __encode_param_keys(self, params): """apply string encoding to the keys of dictionary-based bind parameters""" if self.dialect.positional or self.dialect.supports_unicode_statements(): @@ -218,7 +222,7 @@ class DefaultExecutionContext(base.ExecutionContext): return re.match(r'SELECT', self.statement.lstrip(), re.I) is not None def create_cursor(self): - return self.connection.connection.cursor() + return self._connection.connection.cursor() def pre_execution(self): self.pre_exec() diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 6c200232f2..a3fe8c07a5 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -18,11 +18,20 @@ class DefaultTest(PersistTest): x['x'] += 1 return x['x'] - def mydefault_with_ctx(ctx): - return ctx.compiled_parameters['col1'] + 10 - def myupdate_with_ctx(ctx): return len(ctx.compiled_parameters['col2']) + + def mydefault_using_connection(ctx): + conn = ctx.connection + try: + if db.engine.name == 'oracle': + return conn.execute("select 12 from dual").scalar() + else: + return conn.execute("select 12").scalar() + finally: + # ensure a "close()" on this connection does nothing, + # since its a "branched" connection + conn.close() use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle' is_oracle = db.engine.name == 'oracle' @@ -76,7 +85,7 @@ class DefaultTest(PersistTest): Column('boolcol2', Boolean, default=False), # python function which uses ExecutionContext - Column('col7', Integer, default=mydefault_with_ctx, onupdate=myupdate_with_ctx), + Column('col7', Integer, default=mydefault_using_connection, onupdate=myupdate_with_ctx), # python builtin Column('col8', Date, default=datetime.date.today, onupdate=datetime.date.today) @@ -119,7 +128,7 @@ class DefaultTest(PersistTest): print "Currenttime "+ repr(ctexec) l = t.select().execute() today = datetime.date.today() - self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 61, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 62, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 63, today)]) + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (52, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today)]) def testinsertvalues(self): t.insert(values={'col3':50}).execute() @@ -181,7 +190,7 @@ class AutoIncrementTest(PersistTest): nonai_table = Table("aitest", meta, Column('id', Integer, autoincrement=False, primary_key=True), Column('data', String(20))) - nonai_table.create() + nonai_table.create(checkfirst=True) try: try: # postgres will fail on first row, mysql fails on second row @@ -201,7 +210,7 @@ class AutoIncrementTest(PersistTest): table = Table("aitest", meta, Column('id', Integer, primary_key=True), Column('data', String(20))) - table.create() + table.create(checkfirst=True) try: table.insert().execute(data='row 1') table.insert().execute(data='row 2') @@ -216,7 +225,7 @@ class AutoIncrementTest(PersistTest): table = Table("aitest", meta, Column('id', Integer, primary_key=True), Column('data', String(20))) - table.create() + table.create(checkfirst=True) try: # simulate working on a table that doesn't already exist