From: Mike Bayer
Date: Sun, 25 Feb 2007 22:44:52 +0000 (+0000)
Subject: migrated (most) docstrings to pep-257 format, docstring generator using straight...
X-Git-Tag: rel_0_3_6~52
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=962c22c9eda7d2ab7dc0b41bd1c7a52cf0c9d008;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git
migrated (most) docstrings to pep-257 format, docstring generator using straight + trim() func
for now. applies most of [ticket:214], compliments of Lele Gaifax
---
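For reference, the conversion pattern applied throughout, sketched on the
preparer() method edited below: a capitalized one-line summary, a blank
line, an indented body wrapped at a readable width, and closing quotes on
their own line.

    # before
    def preparer(self):
        """return an IdentifierPreparer.

        This object is used to format table and column names including proper quoting and case conventions."""

    # after
    def preparer(self):
        """Return an IdentifierPreparer.

        This object is used to format table and column names including
        proper quoting and case conventions.
        """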
diff --git a/doc/build/components/formatting.myt b/doc/build/components/formatting.myt
index 1668f7a60c..6b6cf8032b 100644
--- a/doc/build/components/formatting.myt
+++ b/doc/build/components/formatting.myt
@@ -53,7 +53,8 @@
<%method formatplain>
<%filter>
import re
- f = re.sub(r'\n[\s\t]*\n[\s\t]*', '<p>\n', f)
+# f = re.sub(r'\n[\s\t]*\n[\s\t]*', '<p>\n', f)
+ f = re.sub(r'\n[\s\t]*', '<br/>\n', f)
 f = "<p>" + f + "</p>"
 return f
</%filter>
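A rough sketch of what the new formatplain filter does (assuming the
'<br/>\n' replacement shown above): every newline becomes an explicit
break tag, so the line structure that trim() preserves survives into the
rendered HTML.

    import re
    f = "Return an IdentifierPreparer.\nUsed for quoting."
    f = re.sub(r'\n[\s\t]*', '<br/>\n', f)
    f = "<p>" + f + "</p>"
    # f == '<p>Return an IdentifierPreparer.<br/>\nUsed for quoting.</p>'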
diff --git a/doc/build/components/pydoc.myt b/doc/build/components/pydoc.myt
index 1425c5143f..48d06ff897 100644
--- a/doc/build/components/pydoc.myt
+++ b/doc/build/components/pydoc.myt
@@ -2,6 +2,33 @@
<%global>
import docstring
+import sys
+
+def trim(docstring):
+ if not docstring:
+ return ''
+ # Convert tabs to spaces (following the normal Python rules)
+ # and split into a list of lines:
+ lines = docstring.expandtabs().splitlines()
+ # Determine minimum indentation (first line doesn't count):
+ indent = sys.maxint
+ for line in lines[1:]:
+ stripped = line.lstrip()
+ if stripped:
+ indent = min(indent, len(line) - len(stripped))
+ # Remove indentation (first line is special):
+ trimmed = [lines[0].strip()]
+ if indent < sys.maxint:
+ for line in lines[1:]:
+ trimmed.append(line[indent:].rstrip())
+ # Strip off trailing and leading blank lines:
+ while trimmed and not trimmed[-1]:
+ trimmed.pop()
+ while trimmed and not trimmed[0]:
+ trimmed.pop(0)
+ # Return a single string:
+ return '\n'.join(trimmed)
+
</%global>
<%method obj_doc>
@@ -25,7 +52,7 @@ else:
%init>
<&|formatting.myt:section, toc=toc, path=obj.toc_path, description=htmldescription &>
-<&|formatting.myt:formatplain&><% obj.doc %></&>
+<% trim(obj.doc) |h %>
% if not obj.isclass and obj.functions:
@@ -62,7 +89,7 @@ else:
<% func.name %>(<% ", ".join(map(lambda k: "%s" % k, func.arglist))%>)
- <&|formatting.myt:formatplain&><% func.doc %></&>
+ <% trim(func.doc) |h %>
</%method>
@@ -76,7 +103,7 @@ else:
<% prop.name %>
- <&|formatting.myt:formatplain&><% prop.doc %></&>
+ <% trim(prop.doc) |h%>
</%method>
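What trim() does to a typical indented docstring (a minimal sketch; the
example function is hypothetical): tabs are expanded, the common leading
indentation of the body lines is stripped, and leading/trailing blank
lines are dropped.

    def preparer():
        """Return an IdentifierPreparer.

        This object is used to format table and column names.
        """

    result = trim(preparer.__doc__)
    # result == ('Return an IdentifierPreparer.\n'
    #            '\n'
    #            'This object is used to format table and column names.')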
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 6471407918..7a167f42ed 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -4,26 +4,38 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""defines ANSI SQL operations. Contains default implementations for the abstract objects
-in the sql module."""
+"""Defines ANSI SQL operations.
+
+Contains default implementations for the abstract objects in the sql
+module.
+"""
from sqlalchemy import schema, sql, engine, util, sql_util
from sqlalchemy.engine import default
import string, re, sets, weakref
-ANSI_FUNCS = sets.ImmutableSet([
-'CURRENT_TIME',
-'CURRENT_TIMESTAMP',
-'CURRENT_DATE',
-'LOCALTIME',
-'LOCALTIMESTAMP',
-'CURRENT_USER',
-'SESSION_USER',
-'USER'
-])
-
-
-RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', 'as', 'asc', 'asymmetric', 'authorization', 'between', 'binary', 'both', 'case', 'cast', 'check', 'collate', 'column', 'constraint', 'create', 'cross', 'current_date', 'current_role', 'current_time', 'current_timestamp', 'current_user', 'default', 'deferrable', 'desc', 'distinct', 'do', 'else', 'end', 'except', 'false', 'for', 'foreign', 'freeze', 'from', 'full', 'grant', 'group', 'having', 'ilike', 'in', 'initially', 'inner', 'intersect', 'into', 'is', 'isnull', 'join', 'leading', 'left', 'like', 'limit', 'localtime', 'localtimestamp', 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', 'placing', 'primary', 'references', 'right', 'select', 'session_user', 'similar', 'some', 'symmetric', 'table', 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', 'using', 'verbose', 'when', 'where'])
+ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
+ 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
+ 'SESSION_USER', 'USER'])
+
+
+RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array',
+ 'as', 'asc', 'asymmetric', 'authorization', 'between',
+ 'binary', 'both', 'case', 'cast', 'check', 'collate',
+ 'column', 'constraint', 'create', 'cross', 'current_date',
+ 'current_role', 'current_time', 'current_timestamp',
+ 'current_user', 'default', 'deferrable', 'desc',
+ 'distinct', 'do', 'else', 'end', 'except', 'false',
+ 'for', 'foreign', 'freeze', 'from', 'full', 'grant',
+ 'group', 'having', 'ilike', 'in', 'initially', 'inner',
+ 'intersect', 'into', 'is', 'isnull', 'join', 'leading',
+ 'left', 'like', 'limit', 'localtime', 'localtimestamp',
+ 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
+ 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
+ 'placing', 'primary', 'references', 'right', 'select',
+ 'session_user', 'similar', 'some', 'symmetric', 'table',
+ 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
+ 'using', 'verbose', 'when', 'where'])
LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$')
ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
@@ -33,7 +45,7 @@ class ANSIDialect(default.DefaultDialect):
super(ANSIDialect,self).__init__(**kwargs)
self.identifier_preparer = self.preparer()
self.cache_identifiers = cache_identifiers
-
+
def create_connect_args(self):
return ([],{})
@@ -50,87 +62,99 @@ class ANSIDialect(default.DefaultDialect):
return ANSICompiler(self, statement, parameters, **kwargs)
def preparer(self):
- """return an IdenfifierPreparer.
-
- This object is used to format table and column names including proper quoting and case conventions."""
+ """Return an IdentifierPreparer.
+
+ This object is used to format table and column names including
+ proper quoting and case conventions.
+ """
return ANSIIdentifierPreparer(self)
class ANSICompiler(sql.Compiled):
- """default implementation of Compiled, which compiles ClauseElements into ANSI-compliant SQL strings."""
+ """Default implementation of Compiled.
+
+ Compiles ClauseElements into ANSI-compliant SQL strings.
+ """
+
def __init__(self, dialect, statement, parameters=None, **kwargs):
- """constructs a new ANSICompiler object.
-
- dialect - Dialect to be used
-
- statement - ClauseElement to be compiled
+ """Construct a new ``ANSICompiler`` object.
+
+ dialect
+ Dialect to be used
+
+ statement
+ ClauseElement to be compiled
+
+ parameters
+ optional dictionary indicating a set of bind parameters
+ specified with this Compiled object. These parameters are
+ the *default* key/value pairs when the Compiled is executed,
+ and also may affect the actual compilation, as in the case
+ of an INSERT where the actual columns inserted will
+ correspond to the keys present in the parameters.
+ """
- parameters - optional dictionary indicating a set of bind parameters
- specified with this Compiled object. These parameters are the "default"
- key/value pairs when the Compiled is executed, and also may affect the
- actual compilation, as in the case of an INSERT where the actual columns
- inserted will correspond to the keys present in the parameters."""
sql.Compiled.__init__(self, dialect, statement, parameters, **kwargs)
-
+
# a dictionary of bind parameter keys to _BindParamClause instances.
self.binds = {}
# a dictionary which stores the string representation for every ClauseElement
# processed by this compiler.
self.strings = {}
-
+
# a dictionary which stores the string representation for ClauseElements
# processed by this compiler, which are to be used in the FROM clause
# of a select. items are often placed in "froms" as well as "strings"
# and sometimes with different representations.
self.froms = {}
-
- # slightly hacky. maps FROM clauses to WHERE clauses, and used in select
+
+ # slightly hacky. maps FROM clauses to WHERE clauses, and used in select
# generation to modify the WHERE clause of the select. currently a hack
# used by the oracle module.
self.wheres = {}
-
+
# when the compiler visits a SELECT statement, the clause object is appended
# to this stack. various visit operations will check this stack to determine
# additional choices (TODO: it seems to be all typemap stuff. shouldnt this only
# apply to the topmost-level SELECT statement ?)
self.select_stack = []
-
+
# a dictionary of result-set column names (strings) to TypeEngine instances,
# which will be passed to a ResultProxy and used for resultset-level value conversion
self.typemap = {}
-
+
# a dictionary of select columns mapped to their name or key
self.columns = {}
-
+
# True if this compiled represents an INSERT
self.isinsert = False
-
+
# True if this compiled represents an UPDATE
self.isupdate = False
-
+
# default formatting style for bind parameters
self.bindtemplate = ":%s"
-
+
# paramstyle from the dialect (comes from DBAPI)
self.paramstyle = dialect.paramstyle
-
+
# true if the paramstyle is positional
self.positional = dialect.positional
-
+
# a list of the compiled's bind parameter names, used to help
# formulate a positional argument list
self.positiontup = []
-
+
# an ANSIIdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
-
+
# for UPDATE and INSERT statements, a set of columns whos values are being set
# from a SQL expression (i.e., not one of the bind parameter values). if present,
# default-value logic in the Dialect knows not to fire off column defaults
# and also knows postfetching will be needed to get the values represented by these
# parameters.
self.inline_params = None
-
+
def after_compile(self):
# this re will search for params like :param
# it has a negative lookbehind for an extra ':' so that it doesnt match
@@ -163,19 +187,26 @@ class ANSICompiler(sql.Compiled):
return self.wheres.get(obj, None)
def get_params(self, **params):
- """returns a structure of bind parameters for this compiled object.
- This includes bind parameters that might be compiled in via the "values"
- argument of an Insert or Update statement object, and also the given **params.
- The keys inside of **params can be any key that matches the BindParameterClause
- objects compiled within this object. The output is dependent on the paramstyle
- of the DBAPI being used; if a named style, the return result will be a dictionary
- with keynames matching the compiled statement. If a positional style, the output
- will be a list, with an iterator that will return parameter
- values in an order corresponding to the bind positions in the compiled statement.
-
- for an executemany style of call, this method should be called for each element
- in the list of parameter groups that will ultimately be executed.
+ """Return a structure of bind parameters for this compiled object.
+
+ This includes bind parameters that might be compiled in via
+ the `values` argument of an ``Insert`` or ``Update`` statement
+ object, and also the given `**params`. The keys inside of
+ `**params` can be any key that matches the
+ ``BindParameterClause`` objects compiled within this object.
+
+ The output is dependent on the paramstyle of the DBAPI being
+ used; if a named style, the return result will be a dictionary
+ with keynames matching the compiled statement. If a
+ positional style, the output will be a list, with an iterator
+ that will return parameter values in an order corresponding to
+ the bind positions in the compiled statement.
+
+ For an executemany style of call, this method should be called
+ for each element in the list of parameter groups that will
+ ultimately be executed.
"""
+
if self.parameters is not None:
bindparams = self.parameters.copy()
else:
@@ -196,15 +227,18 @@ class ANSICompiler(sql.Compiled):
return d
def default_from(self):
- """called when a SELECT statement has no froms, and no FROM clause is to be appended.
- gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+ """Called when a SELECT statement has no froms, and no FROM clause is to be appended.
+
+ Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
+ """
+
return ""
def visit_label(self, label):
if len(self.select_stack):
self.typemap.setdefault(label.name.lower(), label.obj.type)
self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label)
-
+
def visit_column(self, column):
if len(self.select_stack):
# if we are within a visit to a Select, set up the "typemap"
@@ -230,10 +264,10 @@ class ANSICompiler(sql.Compiled):
def visit_index(self, index):
self.strings[index] = index.name
-
+
def visit_typeclause(self, typeclause):
self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
-
+
def visit_textclause(self, textclause):
if textclause.parens and len(textclause.text):
self.strings[textclause] = "(" + textclause.text + ")"
@@ -242,22 +276,22 @@ class ANSICompiler(sql.Compiled):
self.froms[textclause] = textclause.text
if textclause.typemap is not None:
self.typemap.update(textclause.typemap)
-
+
def visit_null(self, null):
self.strings[null] = 'NULL'
-
+
def visit_compound(self, compound):
if compound.operator is None:
sep = " "
else:
sep = " " + compound.operator + " "
-
+
s = string.join([self.get_str(c) for c in compound.clauses], sep)
if compound.parens:
self.strings[compound] = "(" + s + ")"
else:
self.strings[compound] = s
-
+
def visit_clauselist(self, list):
if list.parens:
self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + ")"
@@ -272,13 +306,13 @@ class ANSICompiler(sql.Compiled):
self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")"
else:
self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ')
-
+
def visit_cast(self, cast):
if len(self.select_stack):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause])
-
+
def visit_function(self, func):
if len(self.select_stack):
self.typemap.setdefault(func.name, func.type)
@@ -288,7 +322,7 @@ class ANSICompiler(sql.Compiled):
else:
self.strings[func] = ".".join(func.packagenames + [func.name]) + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
self.froms[func] = self.strings[func]
-
+
def visit_compound_select(self, cs):
text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
group_by = self.get_str(cs.group_by_clause)
@@ -335,13 +369,12 @@ class ANSICompiler(sql.Compiled):
def bindparam_string(self, name):
return self.bindtemplate % name
-
+
def visit_alias(self, alias):
self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
self.strings[alias] = self.get_str(alias.original)
def visit_select(self, select):
-
# the actual list of columns to print in the SELECT column list.
inner_columns = util.OrderedDict()
@@ -375,15 +408,15 @@ class ANSICompiler(sql.Compiled):
co.accept_visitor(self)
inner_columns[self.get_str(co)] = co
self.select_stack.pop(-1)
-
+
collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
text = "SELECT "
text += self.visit_select_precolumns(select)
text += collist
-
+
whereclause = select.whereclause
-
+
froms = []
for f in select.froms:
@@ -408,17 +441,17 @@ class ANSICompiler(sql.Compiled):
# TODO: move this more into the oracle module
whereclause = sql.and_(w, whereclause)
self.visit_compound(whereclause)
-
+
t = self.get_from_text(f)
if t is not None:
froms.append(t)
-
+
if len(froms):
text += " \nFROM "
text += string.join(froms, ', ')
else:
text += self.default_from()
-
+
if whereclause is not None:
t = self.get_str(whereclause)
if t:
@@ -448,11 +481,16 @@ class ANSICompiler(sql.Compiled):
self.froms[select] = "(" + text + ")"
def visit_select_precolumns(self, select):
- """ called when building a SELECT statment, position is just before column list """
+ """Called when building a ``SELECT`` statement, position is just before column list."""
+
return select.distinct and "DISTINCT " or ""
def visit_select_postclauses(self, select):
- """ called when building a SELECT statement, position is after all other SELECT clauses. Most DB syntaxes put LIMIT/OFFSET here """
+ """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
+
+ Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
+ """
+
return (select.limit or select.offset) and self.limit_clause(select) or ""
def for_update_clause(self, select):
@@ -480,7 +518,7 @@ class ANSICompiler(sql.Compiled):
if join.right._group_parenthesized():
righttext = "(" + righttext + ")"
if join.isouter:
- self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext +
+ self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext +
" ON " + self.get_str(join.onclause))
else:
self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
@@ -488,31 +526,50 @@ class ANSICompiler(sql.Compiled):
self.strings[join] = self.froms[join]
def visit_insert_column_default(self, column, default, parameters):
- """called when visiting an Insert statement, for each column in the table that
- contains a ColumnDefault object. adds a blank 'placeholder' parameter so the
- Insert gets compiled with this column's name in its column and VALUES clauses."""
+ """Called when visiting an ``Insert`` statement.
+
+ For each column in the table that contains a ``ColumnDefault``
+ object, add a blank *placeholder* parameter so the ``Insert``
+ gets compiled with this column's name in its column and
+ ``VALUES`` clauses.
+ """
+
parameters.setdefault(column.key, None)
def visit_update_column_default(self, column, default, parameters):
- """called when visiting an Update statement, for each column in the table that
- contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the
- Update gets compiled with this column's name as one of its SET clauses."""
+ """Called when visiting an ``Update`` statement.
+
+ For each column in the table that contains a ``ColumnDefault``
+ object as an onupdate, add a blank *placeholder* parameter so
+ the ``Update`` gets compiled with this column's name as one of
+ its ``SET`` clauses.
+ """
+
parameters.setdefault(column.key, None)
-
+
def visit_insert_sequence(self, column, sequence, parameters):
- """called when visiting an Insert statement, for each column in the table that
- contains a Sequence object. Overridden by compilers that support sequences to place
- a blank 'placeholder' parameter, so the Insert gets compiled with this column's
- name in its column and VALUES clauses."""
+ """Called when visiting an ``Insert`` statement.
+
+ This may be overridden by compilers that support sequences to
+ place a blank *placeholder* parameter for each column in the
+ table that contains a Sequence object, so the Insert gets
+ compiled with this column's name in its column and ``VALUES``
+ clauses.
+ """
+
pass
-
+
def visit_insert_column(self, column, parameters):
- """called when visiting an Insert statement, for each column in the table
- that is a NULL insert into the table. Overridden by compilers who disallow
- NULL columns being set in an Insert where there is a default value on the column
- (i.e. postgres), to remove the column from the parameter list."""
+ """Called when visiting an ``Insert`` statement.
+
+ This may be overridden by compilers who disallow NULL columns
+ being set in an ``Insert`` where there is a default value on
+ the column (i.e. postgres), to remove the column for which
+ there is a NULL insert from the parameter list.
+ """
+
pass
-
+
def visit_insert(self, insert_stmt):
# scan the table's columns for defaults that have to be pre-set for an INSERT
# add these columns to the parameter list via visit_insert_XXX methods
@@ -528,7 +585,7 @@ class ANSICompiler(sql.Compiled):
for c in insert_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
c.accept_schema_visitor(vis)
-
+
self.isinsert = True
colparams = self._get_colparams(insert_stmt, default_params)
@@ -580,32 +637,36 @@ class ANSICompiler(sql.Compiled):
return "(" + self.get_str(p) + ")"
else:
return self.get_str(p)
-
+
text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
-
+
if update_stmt.whereclause:
text += " WHERE " + self.get_str(update_stmt.whereclause)
-
+
self.strings[update_stmt] = text
def _get_colparams(self, stmt, default_params):
- """organize UPDATE/INSERT SET/VALUES parameters into a list of tuples,
- each tuple containing the Column and a ClauseElement representing the
- value to be set (usually a _BindParamClause, but could also be other
- SQL expressions.)
-
- the list of tuples will determine the columns that are actually rendered
- into the SET/VALUES clause of the rendered UPDATE/INSERT statement. It will
- also determine how to generate the list/dictionary of bind parameters at
- execution time (i.e. get_params()).
-
- this list takes into account the "values" keyword specified to the statement,
- the parameters sent to this Compiled instance, and the default bind parameter
- values corresponding to the dialect's behavior for otherwise unspecified
- primary key columns.
+ """Organize ``UPDATE``/``INSERT`` ``SET``/``VALUES`` parameters into a list of tuples.
+
+ Each tuple will contain the ``Column`` and a ``ClauseElement``
+ representing the value to be set (usually a ``_BindParamClause``,
+ but could also be other SQL expressions.)
+
+ The list of tuples will determine the columns that are
+ actually rendered into the ``SET``/``VALUES`` clause of the
+ rendered ``UPDATE``/``INSERT`` statement. It will also
+ determine how to generate the list/dictionary of bind
+ parameters at execution time (i.e. ``get_params()``).
+
+ This list takes into account the `values` keyword specified
+ to the statement, the parameters sent to this Compiled
+ instance, and the default bind parameter values corresponding
+ to the dialect's behavior for otherwise unspecified primary
+ key columns.
"""
- # no parameters in the statement, no parameters in the
+
+ # no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if self.parameters is None and stmt.parameters is None:
return [(c, sql.bindparam(c.key, type=c.type)) for c in stmt.table.columns]
@@ -615,8 +676,8 @@ class ANSICompiler(sql.Compiled):
return stmt.table.columns.get(str(key), key)
else:
return key
-
- # if we have statement parameters - set defaults in the
+
+ # if we have statement parameters - set defaults in the
# compiled params
if self.parameters is None:
parameters = {}
@@ -642,12 +703,12 @@ class ANSICompiler(sql.Compiled):
def visit_delete(self, delete_stmt):
text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
-
+
if delete_stmt.whereclause:
text += " WHERE " + self.get_str(delete_stmt.whereclause)
-
+
self.strings[delete_stmt] = text
-
+
def __str__(self):
return self.get_str(self.statement)
@@ -663,7 +724,7 @@ class ANSISchemaBase(engine.SchemaIterator):
for c in table.constraints:
c.accept_schema_visitor(findalterables)
return alterables
-
+
class ANSISchemaGenerator(ANSISchemaBase):
def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
@@ -672,10 +733,10 @@ class ANSISchemaGenerator(ANSISchemaBase):
self.connection = connection
self.preparer = self.engine.dialect.preparer()
self.dialect = self.engine.dialect
-
+
def get_column_specification(self, column, first_pk=False):
raise NotImplementedError()
-
+
def visit_metadata(self, metadata):
collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name, schema=t.schema))]
for table in collection:
@@ -683,18 +744,18 @@ class ANSISchemaGenerator(ANSISchemaBase):
if self.supports_alter():
for alterable in self.find_alterables(collection):
self.add_foreignkey(alterable)
-
+
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
column.default.accept_schema_visitor(self, traverse=False)
#if column.onupdate is not None:
# column.onupdate.accept_schema_visitor(visitor, traverse=False)
-
+
self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
-
+
separator = "\n"
-
+
# if only one primary key, specify it along with the column
first_pk = False
for column in table.columns:
@@ -718,7 +779,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
if hasattr(table, 'indexes'):
for index in table.indexes:
index.accept_schema_visitor(self, traverse=False)
-
+
def post_create_table(self, table):
return ''
@@ -746,7 +807,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
def visit_column_check_constraint(self, constraint):
self.append(" ")
self.append(" CHECK (%s)" % constraint.sqltext)
-
+
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
return
@@ -755,21 +816,21 @@ class ANSISchemaGenerator(ANSISchemaBase):
self.append("CONSTRAINT %s " % constraint.name)
self.append("PRIMARY KEY ")
self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
-
+
def supports_alter(self):
return True
-
+
def visit_foreign_key_constraint(self, constraint):
if constraint.use_alter and self.supports_alter():
return
self.append(", \n\t ")
self.define_foreign_key(constraint)
-
+
def add_foreignkey(self, constraint):
self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
self.define_foreign_key(constraint)
self.execute()
-
+
def define_foreign_key(self, constraint):
if constraint.name is not None:
self.append("CONSTRAINT %s " % constraint.name)
@@ -801,7 +862,7 @@ class ANSISchemaGenerator(ANSISchemaBase):
% (index.name, self.preparer.format_table(index.table),
string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
self.execute()
-
+
class ANSISchemaDropper(ANSISchemaBase):
def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
@@ -829,7 +890,7 @@ class ANSISchemaDropper(ANSISchemaBase):
def drop_foreignkey(self, constraint):
self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name))
self.execute()
-
+
def visit_table(self, table):
for column in table.columns:
if column.default is not None:
@@ -842,58 +903,76 @@ class ANSIDefaultRunner(engine.DefaultRunner):
pass
class ANSIIdentifierPreparer(object):
- """handles quoting and case-folding of identifiers based on options"""
+ """Handle quoting and case-folding of identifiers based on options."""
+
def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
- """Constructs a new ANSIIdentifierPreparer object.
-
- initial_quote - Character that begins a delimited identifier
- final_quote - Caracter that ends a delimited identifier. defaults to initial_quote.
-
- omit_schema - prevent prepending schema name. useful for databases that do not support schemae
+ """Construct a new ``ANSIIdentifierPreparer`` object.
+
+ initial_quote
+ Character that begins a delimited identifier.
+
+ final_quote
+ Character that ends a delimited identifier. Defaults to `initial_quote`.
+
+ omit_schema
+ Prevent prepending schema name. Useful for databases that do
+ not support schemae.
"""
+
self.dialect = dialect
self.initial_quote = initial_quote
self.final_quote = final_quote or self.initial_quote
self.omit_schema = omit_schema
self.__strings = {}
+
def _escape_identifier(self, value):
- """escape an identifier.
-
- subclasses should override this to provide database-dependent escaping behavior."""
+ """Escape an identifier.
+
+ Subclasses should override this to provide database-dependent
+ escaping behavior.
+ """
+
return value.replace('"', '""')
-
+
def _quote_identifier(self, value):
- """quote an identifier.
-
- subclasses should override this to provide database-dependent quoting behavior."""
+ """Quote an identifier.
+
+ Subclasses should override this to provide database-dependent
+ quoting behavior.
+ """
+
return self.initial_quote + self._escape_identifier(value) + self.final_quote
-
+
def _fold_identifier_case(self, value):
- """fold the case of an identifier.
-
- subclassses should override this to provide database-dependent case folding behavior."""
+ """Fold the case of an identifier.
+
+ Subclasses should override this to provide database-dependent
+ case folding behavior.
+ """
+
return value
# ANSI SQL calls for the case of all unquoted identifiers to be folded to UPPER.
# some tests would need to be rewritten if this is done.
#return value.upper()
-
+
def _reserved_words(self):
return RESERVED_WORDS
def _legal_characters(self):
return LEGAL_CHARACTERS
-
+
def _illegal_initial_characters(self):
return ILLEGAL_INITIAL_CHARACTERS
-
+
def _requires_quotes(self, value, case_sensitive):
- """return true if the given identifier requires quoting."""
+ """Return True if the given identifier requires quoting."""
+
return \
value in self._reserved_words() \
or (value[0] in self._illegal_initial_characters()) \
or bool(len([x for x in str(value) if x not in self._legal_characters()])) \
or (case_sensitive and value.lower() != value)
-
+
def __generic_obj_format(self, obj, ident):
if getattr(obj, 'quote', False):
return self._quote_identifier(ident)
@@ -912,31 +991,33 @@ class ANSIIdentifierPreparer(object):
return self._quote_identifier(ident)
else:
return ident
-
+
def should_quote(self, object):
- return object.quote or self._requires_quotes(object.name, object.case_sensitive)
-
+ return object.quote or self._requires_quotes(object.name, object.case_sensitive)
+
def is_natural_case(self, object):
return object.quote or self._requires_quotes(object.name, object.case_sensitive)
-
+
def format_sequence(self, sequence):
return self.__generic_obj_format(sequence, sequence.name)
-
+
def format_label(self, label):
return self.__generic_obj_format(label, label.name)
def format_alias(self, alias):
return self.__generic_obj_format(alias, alias.name)
-
+
def format_table(self, table, use_schema=True):
- """Prepare a quoted table and schema name"""
+ """Prepare a quoted table and schema name."""
+
result = self.__generic_obj_format(table, table.name)
if use_schema and getattr(table, "schema", None):
result = self.__generic_obj_format(table, table.schema) + "." + result
return result
-
+
def format_column(self, column, use_table=False):
- """Prepare a quoted column name """
+ """Prepare a quoted column name."""
+
if not getattr(column, 'is_literal', False):
if use_table:
return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name)
@@ -948,8 +1029,8 @@ class ANSIIdentifierPreparer(object):
return self.format_table(column.table, use_schema=False) + "." + column.name
else:
return column.name
-
+
def format_column_with_table(self, column):
- """Prepare a quoted column name with table name"""
- return self.format_column(column, use_table=True)
+ """Prepare a quoted column name with table name."""
+ return self.format_column(column, use_table=True)
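A rough illustration of the quoting rules documented above, using only
behavior visible in this patch (the dialect object is assumed to exist):

    preparer = ANSIIdentifierPreparer(dialect)
    preparer._requires_quotes('user', case_sensitive=True)      # True: reserved word
    preparer._requires_quotes('order_id', case_sensitive=True)  # False: legal, lowercase
    preparer._requires_quotes('OrderId', case_sensitive=True)   # True: mixed case
    preparer._escape_identifier('a"b')                          # 'a""b'
    preparer._quote_identifier('select')                        # '"select"'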
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index 5a25b12db3..91a0869c61 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -330,17 +330,23 @@ class FBCompiler(ansisql.ANSICompiler):
def visit_insert(self, insert):
"""Inserts are required to have the primary keys be explicitly present.
- mapper will by default not put them in the insert statement to comply
- with autoincrement fields that require they not be present. So,
- put them all in for all primary key columns."""
+
+ mapper will by default not put them in the insert statement
+ to comply with autoincrement fields that require they not be
+ present. So, put them all in for all primary key columns.
+ """
+
for c in insert.table.primary_key:
if not self.parameters.has_key(c.key):
self.parameters[c.key] = None
return ansisql.ANSICompiler.visit_insert(self, insert)
def visit_select_precolumns(self, select):
- """Called when building a SELECT statement, position is just before column list
- Firebird puts the limit and offset right after the select..."""
+ """Called when building a ``SELECT`` statement, position is just
+ before the column list. Firebird puts the limit and offset right
+ after the ``SELECT``...
+ """
+
result = ""
if select.limit:
result += " FIRST %d " % select.limit
@@ -351,7 +357,7 @@ class FBCompiler(ansisql.ANSICompiler):
return result
def limit_clause(self, select):
- """Already taken care of in the visit_select_precolumns method."""
+ """Already taken care of in the `visit_select_precolumns` method."""
return ""
diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py
index 5a7369ccda..54c47b6f42 100644
--- a/lib/sqlalchemy/databases/information_schema.py
+++ b/lib/sqlalchemy/databases/information_schema.py
@@ -85,6 +85,7 @@ class ISchema(object):
def __init__(self, engine):
self.engine = engine
self.cache = {}
+
def __getattr__(self, name):
if name not in self.cache:
# This is a bit of a hack.
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 8cde7179fe..254ea60131 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -1,36 +1,45 @@
# mssql.py
-"""
-notes:
- supports the pymssq, adodbapi and pyodbc interfaces
+"""MSSQL backend, thru either pymssq, adodbapi or pyodbc interfaces.
+
+* ``IDENTITY`` columns are supported by using SA ``schema.Sequence()``
+ objects. In other words::
+
+ Table('test', mss_engine,
+ Column('id', Integer, Sequence('blah',100,10), primary_key=True),
+ Column('name', String(20))
+ ).create()
- IDENTITY columns are supported by using SA schema.Sequence() objects. In other words:
- Table('test', mss_engine,
- Column('id', Integer, Sequence('blah',100,10), primary_key=True),
- Column('name', String(20))
- ).create()
+ would yield::
- would yield:
- CREATE TABLE test (
- id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
- name VARCHAR(20)
- )
- note that the start & increment values for sequences are optional and will default to 1,1
+ CREATE TABLE test (
+ id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
+ name VARCHAR(20)
+ )
- support for SET IDENTITY_INSERT ON mode (automagic on / off for INSERTs)
+ Note that the start & increment values for sequences are optional
+ and will default to 1,1.
- support for auto-fetching of @@IDENTITY on insert
+* Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
+ ``INSERT``s)
- select.limit implemented as SELECT TOP n
+* Support for auto-fetching of ``@@IDENTITY`` on ``INSERT``
+
+* ``select.limit`` implemented as ``SELECT TOP n``
Known issues / TODO:
- no support for more than one IDENTITY column per table
- no support for table reflection of IDENTITY columns with (seed,increment) values other than (1,1)
- no support for GUID type columns (yet)
- pymssql has problems with binary and unicode data that this module does NOT work around
- adodbapi fails testtypes.py unit test on unicode data too -- issue with the test?
+* No support for more than one ``IDENTITY`` column per table
+
+* No support for table reflection of ``IDENTITY`` columns with
+ (seed,increment) values other than (1,1)
+
+* No support for ``GUID`` type columns (yet)
+
+* pymssql has problems with binary and unicode data that this module
+  does **not** work around
+
+* adodbapi fails the testtypes.py unit test on unicode data too --
+  issue with the test?
"""
import sys, StringIO, string, types, re, datetime
@@ -138,6 +147,7 @@ class MSNumeric(sqltypes.Numeric):
class MSFloat(sqltypes.Float):
def get_col_spec(self):
return "FLOAT(%(precision)s)" % {'precision': self.precision}
+
def convert_bind_param(self, value, dialect):
"""By converting to string, we can use Decimal types round-trip."""
return str(value)
@@ -197,14 +207,17 @@ class MSDate(sqltypes.Date):
class MSText(sqltypes.TEXT):
def get_col_spec(self):
return "TEXT"
+
class MSString(sqltypes.String):
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
-
class MSNVarchar(MSString):
- """NVARCHAR string, does unicode conversion if dialect.convert_encoding is true"""
+ """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True.
+ """
+
impl = sqltypes.Unicode
+
def get_col_spec(self):
if self.length:
return "NVARCHAR(%(length)s)" % {'length' : self.length}
@@ -214,36 +227,45 @@ class MSNVarchar(MSString):
class AdoMSNVarchar(MSNVarchar):
def convert_bind_param(self, value, dialect):
return value
+
def convert_result_value(self, value, dialect):
return value
class MSUnicode(sqltypes.Unicode):
- """Unicode subclass, does unicode conversion in all cases, uses NVARCHAR impl"""
+ """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl."""
+
impl = MSNVarchar
class AdoMSUnicode(MSUnicode):
impl = AdoMSNVarchar
+
def convert_bind_param(self, value, dialect):
return value
+
def convert_result_value(self, value, dialect):
return value
class MSChar(sqltypes.CHAR):
def get_col_spec(self):
return "CHAR(%(length)s)" % {'length' : self.length}
+
class MSNChar(sqltypes.NCHAR):
def get_col_spec(self):
return "NCHAR(%(length)s)" % {'length' : self.length}
+
class MSBinary(sqltypes.Binary):
def get_col_spec(self):
return "IMAGE"
+
class MSBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BIT"
+
def convert_result_value(self, value, dialect):
if value is None:
return None
return value and True or False
+
def convert_bind_param(self, value, dialect):
if value is True:
return 1
@@ -307,8 +329,12 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
super(MSSQLExecutionContext, self).__init__(dialect)
def pre_exec(self, engine, proxy, compiled, parameters, **kwargs):
- """ MS-SQL has a special mode for inserting non-NULL values into IDENTITY columns.
- Activate it if the feature is turned on and needed. """
+ """MS-SQL has a special mode for inserting non-NULL values
+ into IDENTITY columns.
+
+ Activate it if the feature is turned on and needed.
+ """
+
if getattr(compiled, "isinsert", False):
tbl = compiled.statement.table
if not hasattr(tbl, 'has_sequence'):
@@ -337,7 +363,11 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs)
def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
- """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """
+ """Turn off the INDENTITY_INSERT mode if it's been activated,
+ and fetch recently inserted IDENTIFY values (works only for
+ one column).
+ """
+
if getattr(compiled, "isinsert", False):
if self.IINSERT:
proxy("SET IDENTITY_INSERT %s OFF" % compiled.statement.table.name)
@@ -429,8 +459,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
c = self._pool.connect()
c.supportsTransactions = 0
return c
-
-
+
def dbapi(self):
return self.module
@@ -535,7 +564,6 @@ class MSSQLDialect(ansisql.ANSIDialect):
if 'PRIMARY' in row[TC.c.constraint_type.name]:
table.primary_key.add(table.c[row[0]])
-
# Foreign key constraints
s = sql.select([C.c.column_name,
R.c.table_schema, R.c.table_name, R.c.column_name,
@@ -562,8 +590,6 @@ class MSSQLDialect(ansisql.ANSIDialect):
if fknm and scols:
table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
-
-
class PyMSSQLDialect(MSSQLDialect):
def do_rollback(self, connection):
@@ -578,7 +604,6 @@ class PyMSSQLDialect(MSSQLDialect):
if hasattr(self, 'query_timeout'):
dbmodule._mssql.set_query_timeout(self.query_timeout)
return r
-
## This code is leftover from the initial implementation, for reference
## def do_begin(self, connection):
@@ -611,7 +636,6 @@ class PyMSSQLDialect(MSSQLDialect):
## r.query("begin tran")
## r.fetch_array()
-
class MSSQLCompiler(ansisql.ANSICompiler):
def __init__(self, dialect, statement, parameters, **kwargs):
super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
@@ -627,7 +651,6 @@ class MSSQLCompiler(ansisql.ANSICompiler):
def limit_clause(self, select):
# Limit in mssql is after the select keyword; MSsql has no support for offset
return ""
-
def visit_table(self, table):
# alias schema-qualified tables
@@ -699,7 +722,6 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
return colspec
-
class MSSQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX " + index.table.name + "." + index.name)
@@ -711,9 +733,11 @@ class MSSQLDefaultRunner(ansisql.ANSIDefaultRunner):
class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
+
def _escape_identifier(self, value):
#TODO: determin MSSQL's escapeing rules
return value
+
def _fold_identifier_case(self, value):
#TODO: determin MSSQL's case folding rules
return value
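The IDENTITY_INSERT round trip described in pre_exec/post_exec above, shown
as the statements the proxy would see (a sketch; the table and values are
hypothetical):

    SET IDENTITY_INSERT test ON
    INSERT INTO test (id, name) VALUES (7, 'x')
    SET IDENTITY_INSERT test OFF
    SELECT @@IDENTITY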
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index c6bf2695fd..1cb41cf768 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -25,21 +25,24 @@ def kw_colspec(self, spec):
if self.zerofill:
spec += ' ZEROFILL'
return spec
-
+
class MSNumeric(sqltypes.Numeric):
def __init__(self, precision = 10, length = 2, **kw):
self.unsigned = 'unsigned' in kw
self.zerofill = 'zerofill' in kw
super(MSNumeric, self).__init__(precision, length)
+
def get_col_spec(self):
if self.precision is None:
return kw_colspec(self, "NUMERIC")
else:
return kw_colspec(self, "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+
class MSDecimal(MSNumeric):
def get_col_spec(self):
if self.precision is not None and self.length is not None:
return kw_colspec(self, "DECIMAL(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length})
+
class MSDouble(MSNumeric):
def __init__(self, precision=10, length=2, **kw):
if (precision is None and length is not None) or (precision is not None and length is None):
@@ -47,11 +50,13 @@ class MSDouble(MSNumeric):
self.unsigned = 'unsigned' in kw
self.zerofill = 'zerofill' in kw
super(MSDouble, self).__init__(precision, length)
+
def get_col_spec(self):
if self.precision is not None and self.length is not None:
return "DOUBLE(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
else:
return kw_colspec(self, "DOUBLE")
+
class MSFloat(sqltypes.Float):
def __init__(self, precision=10, length=None, **kw):
if length is not None:
@@ -59,6 +64,7 @@ class MSFloat(sqltypes.Float):
self.unsigned = 'unsigned' in kw
self.zerofill = 'zerofill' in kw
super(MSFloat, self).__init__(precision)
+
def get_col_spec(self):
if hasattr(self, 'length') and self.length is not None:
return kw_colspec(self, "FLOAT(%(precision)s,%(length)s)" % {'precision': self.precision, 'length' : self.length})
@@ -66,23 +72,27 @@ class MSFloat(sqltypes.Float):
return kw_colspec(self, "FLOAT(%(precision)s)" % {'precision': self.precision})
else:
return kw_colspec(self, "FLOAT")
+
class MSInteger(sqltypes.Integer):
def __init__(self, length=None, **kw):
self.length = length
self.unsigned = 'unsigned' in kw
self.zerofill = 'zerofill' in kw
super(MSInteger, self).__init__()
+
def get_col_spec(self):
if self.length is not None:
return kw_colspec(self, "INTEGER(%(length)s)" % {'length': self.length})
else:
return kw_colspec(self, "INTEGER")
+
class MSBigInteger(MSInteger):
def get_col_spec(self):
if self.length is not None:
return kw_colspec(self, "BIGINT(%(length)s)" % {'length': self.length})
else:
return kw_colspec(self, "BIGINT")
+
class MSSmallInteger(sqltypes.Smallinteger):
def __init__(self, length=None, **kw):
self.length = length
@@ -94,54 +104,65 @@ class MSSmallInteger(sqltypes.Smallinteger):
return kw_colspec(self, "SMALLINT(%(length)s)" % {'length': self.length})
else:
return kw_colspec(self, "SMALLINT")
+
class MSDateTime(sqltypes.DateTime):
def get_col_spec(self):
return "DATETIME"
+
class MSDate(sqltypes.Date):
def get_col_spec(self):
return "DATE"
+
class MSTime(sqltypes.Time):
def get_col_spec(self):
return "TIME"
+
def convert_result_value(self, value, dialect):
# convert from a timedelta value
if value is not None:
return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60))
else:
return None
-
+
class MSText(sqltypes.TEXT):
def __init__(self, **kw):
self.binary = 'binary' in kw
super(MSText, self).__init__()
+
def get_col_spec(self):
return "TEXT"
+
class MSTinyText(MSText):
def get_col_spec(self):
if self.binary:
return "TEXT BINARY"
else:
return "TEXT"
+
class MSMediumText(MSText):
def get_col_spec(self):
if self.binary:
return "MEDIUMTEXT BINARY"
else:
return "MEDIUMTEXT"
+
class MSLongText(MSText):
def get_col_spec(self):
if self.binary:
return "LONGTEXT BINARY"
else:
return "LONGTEXT"
+
class MSString(sqltypes.String):
def __init__(self, length=None, *extra):
sqltypes.String.__init__(self, length=length)
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
+
class MSChar(sqltypes.CHAR):
def get_col_spec(self):
return "CHAR(%(length)s)" % {'length' : self.length}
+
class MSBinary(sqltypes.Binary):
def get_col_spec(self):
if self.length is not None and self.length <=255:
@@ -149,6 +170,7 @@ class MSBinary(sqltypes.Binary):
return "BINARY(%d)" % self.length
else:
return "BLOB"
+
def convert_result_value(self, value, dialect):
if value is None:
return None
@@ -158,7 +180,7 @@ class MSBinary(sqltypes.Binary):
class MSMediumBlob(MSBinary):
def get_col_spec(self):
return "MEDIUMBLOB"
-
+
class MSEnum(MSString):
def __init__(self, *enums):
self.__enums_hidden = enums
@@ -172,17 +194,20 @@ class MSEnum(MSString):
strip_enums.append(a)
self.enums = strip_enums
super(MSEnum, self).__init__(length)
+
def get_col_spec(self):
return "ENUM(%s)" % ",".join(self.__enums_hidden)
-
+
class MSBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOLEAN"
+
def convert_result_value(self, value, dialect):
if value is None:
return None
return value and True or False
+
def convert_bind_param(self, value, dialect):
if value is True:
return 1
@@ -192,7 +217,7 @@ class MSBoolean(sqltypes.Boolean):
return None
else:
return value and True or False
-
+
colspecs = {
# sqltypes.BIGinteger : MSInteger,
sqltypes.Integer : MSInteger,
@@ -215,7 +240,7 @@ ischema_names = {
'int' : MSInteger,
'mediumint' : MSInteger,
'smallint' : MSSmallInteger,
- 'tinyint' : MSSmallInteger,
+ 'tinyint' : MSSmallInteger,
'varchar' : MSString,
'char' : MSChar,
'text' : MSText,
@@ -245,7 +270,6 @@ def descriptor():
('host',"Hostname", None),
]}
-
class MySQLExecutionContext(default.DefaultExecutionContext):
def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
if getattr(compiled, "isinsert", False):
@@ -318,7 +342,6 @@ class MySQLDialect(ansisql.ANSIDialect):
if o.args[0] == 2006 or o.args[0] == 2014:
cursor.invalidate()
raise o
-
def do_rollback(self, connection):
# MySQL without InnoDB doesnt support rollback()
@@ -331,7 +354,7 @@ class MySQLDialect(ansisql.ANSIDialect):
if not hasattr(self, '_default_schema_name'):
self._default_schema_name = text("select database()", self).scalar()
return self._default_schema_name
-
+
def dbapi(self):
return self.module
@@ -345,7 +368,7 @@ class MySQLDialect(ansisql.ANSIDialect):
if isinstance(cs, array):
cs = cs.tostring()
case_sensitive = int(cs) == 0
-
+
if not case_sensitive:
table.name = table.name.lower()
table.metadata.tables[table.name]= table
@@ -364,7 +387,7 @@ class MySQLDialect(ansisql.ANSIDialect):
# these can come back as unicode if use_unicode=1 in the mysql connection
(name, type, nullable, primary_key, default) = (str(row[0]), str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4])
-
+
match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
col_type = match.group(1)
args = match.group(2)
@@ -391,7 +414,7 @@ class MySQLDialect(ansisql.ANSIDialect):
colargs= []
if default:
colargs.append(schema.PassiveDefault(sql.text(default)))
- table.append_column(schema.Column(name, coltype, *colargs,
+ table.append_column(schema.Column(name, coltype, *colargs,
**dict(primary_key=primary_key,
nullable=nullable,
)))
@@ -401,7 +424,7 @@ class MySQLDialect(ansisql.ANSIDialect):
if not found_table:
raise exceptions.NoSuchTableError(table.name)
-
+
def moretableinfo(self, connection, table):
"""Return (tabletype, {colname:foreignkey,...})
execute(SHOW CREATE TABLE child) =>
@@ -438,10 +461,8 @@ class MySQLDialect(ansisql.ANSIDialect):
table.append_constraint(constraint)
return tabletype
-
class MySQLCompiler(ansisql.ANSICompiler):
-
def visit_cast(self, cast):
"""hey ho MySQL supports almost no types at all for CAST"""
if (isinstance(cast.type, sqltypes.Date) or isinstance(cast.type, sqltypes.Time) or isinstance(cast.type, sqltypes.DateTime)):
@@ -467,7 +488,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
text += " \n LIMIT 18446744073709551615"
text += " OFFSET " + str(select.offset)
return text
-
+
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
t = column.type.engine_impl(self.engine)
@@ -495,6 +516,7 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
def visit_index(self, index):
self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
self.execute()
+
def drop_foreignkey(self, constraint):
self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name))
self.execute()
@@ -502,9 +524,11 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def __init__(self, dialect):
super(MySQLIdentifierPreparer, self).__init__(dialect, initial_quote='`')
+
def _escape_identifier(self, value):
#TODO: determin MySQL's escaping rules
return value
+
def _fold_identifier_case(self, value):
#TODO: determin MySQL's case folding rules
return value
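The 18446744073709551615 in limit_clause above is 2**64 - 1, the
conventional MySQL stand-in for "no limit": OFFSET is only legal after a
LIMIT, so an offset-only select compiles roughly to:

    SELECT id, name FROM test
     LIMIT 18446744073709551615 OFFSET 20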
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index d7b78d3dd9..d53de06548 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -18,22 +18,25 @@ except:
ORACLE_BINARY_TYPES = [getattr(cx_Oracle, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(cx_Oracle, k)]
-
class OracleNumeric(sqltypes.Numeric):
def get_col_spec(self):
if self.precision is None:
return "NUMERIC"
else:
return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+
class OracleInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
+
class OracleSmallInteger(sqltypes.Smallinteger):
def get_col_spec(self):
return "SMALLINT"
+
class OracleDateTime(sqltypes.DateTime):
def get_col_spec(self):
return "DATE"
+
# Note:
# Oracle DATE == DATETIME
# Oracle does not allow milliseconds in DATE
@@ -43,32 +46,40 @@ class OracleDateTime(sqltypes.DateTime):
class OracleTimestamp(sqltypes.DateTime):
def get_col_spec(self):
return "TIMESTAMP"
+
def get_dbapi_type(self, dialect):
return dialect.TIMESTAMP
-
+
class OracleText(sqltypes.TEXT):
def get_col_spec(self):
return "CLOB"
+
class OracleString(sqltypes.String):
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
+
class OracleRaw(sqltypes.Binary):
def get_col_spec(self):
return "RAW(%(length)s)" % {'length' : self.length}
+
class OracleChar(sqltypes.CHAR):
def get_col_spec(self):
return "CHAR(%(length)s)" % {'length' : self.length}
+
class OracleBinary(sqltypes.Binary):
def get_dbapi_type(self, dbapi):
return dbapi.BINARY
+
def get_col_spec(self):
return "BLOB"
+
def convert_bind_param(self, value, dialect):
if value is None:
return None
else:
# this is RAWTOHEX
return ''.join(["%.2X" % ord(c) for c in value])
+
def convert_result_value(self, value, dialect):
if value is None:
return None
@@ -78,10 +89,12 @@ class OracleBinary(sqltypes.Binary):
class OracleBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "SMALLINT"
+
def convert_result_value(self, value, dialect):
if value is None:
return None
return value and True or False
+
def convert_bind_param(self, value, dialect):
if value is True:
return 1
@@ -90,9 +103,8 @@ class OracleBoolean(sqltypes.Boolean):
elif value is None:
return None
else:
- return value and True or False
+ return value and True or False
-
colspecs = {
sqltypes.Integer : OracleInteger,
sqltypes.Smallinteger : OracleSmallInteger,
@@ -121,8 +133,6 @@ ischema_names = {
'DOUBLE PRECISION' : OracleNumeric,
}
-
-
def descriptor():
return {'name':'oracle',
'description':'Oracle',
@@ -137,7 +147,7 @@ class OracleExecutionContext(default.DefaultExecutionContext):
super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters)
if self.dialect.auto_setinputsizes:
self.set_input_sizes(proxy(), parameters)
-
+
class OracleDialect(ansisql.ANSIDialect):
def __init__(self, use_ansi=True, auto_setinputsizes=False, module=None, threaded=True, **kwargs):
self.use_ansi = use_ansi
@@ -173,7 +183,7 @@ class OracleDialect(ansisql.ANSIDialect):
)
opts.update(url.query)
return ([], opts)
-
+
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
@@ -188,14 +198,16 @@ class OracleDialect(ansisql.ANSIDialect):
def compiler(self, statement, bindparams, **kwargs):
return OracleCompiler(self, statement, bindparams, **kwargs)
+
def schemagenerator(self, *args, **kwargs):
return OracleSchemaGenerator(*args, **kwargs)
+
def schemadropper(self, *args, **kwargs):
return OracleSchemaDropper(*args, **kwargs)
+
def defaultrunner(self, engine, proxy):
return OracleDefaultRunner(engine, proxy)
-
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()})
return bool( cursor.fetchone() is not None )
@@ -229,8 +241,12 @@ class OracleDialect(ansisql.ANSIDialect):
raise exceptions.AssertionError("There are multiple tables with name '%s' visible to the schema, you must specifiy owner" % name)
else:
return None
+
def _resolve_table_owner(self, connection, name, table, dblink=''):
- """locate the given table in the ALL_TAB_COLUMNS view, including searching for equivalent synonyms and dblinks"""
+ """Locate the given table in the ``ALL_TAB_COLUMNS`` view,
+ including searching for equivalent synonyms and dblinks.
+ """
+
c = connection.execute ("select distinct OWNER from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name" % {'dblink':dblink}, {'table_name':name})
rows = c.fetchall()
try:
@@ -239,10 +255,10 @@ class OracleDialect(ansisql.ANSIDialect):
except exceptions.SQLAlchemyError:
# locate synonyms
c = connection.execute ("""select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK
- from ALL_SYNONYMS%(dblink)s
+ from ALL_SYNONYMS%(dblink)s
where SYNONYM_NAME = :synonym_name
- and (DB_LINK IS NOT NULL
- or ((TABLE_NAME, TABLE_OWNER) in
+ and (DB_LINK IS NOT NULL
+ or ((TABLE_NAME, TABLE_OWNER) in
(select TABLE_NAME, OWNER from ALL_TAB_COLUMNS%(dblink)s)))""" % {'dblink':dblink},
{'synonym_name':name})
rows = c.fetchall()
@@ -262,20 +278,19 @@ class OracleDialect(ansisql.ANSIDialect):
return name, owner, dblink
raise
-
def reflecttable(self, connection, table):
preparer = self.identifier_preparer
if not preparer.should_quote(table):
name = table.name.upper()
else:
name = table.name
-
+
# search for table, including across synonyms and dblinks.
# locate the actual name of the table, the real owner, and any dblink clause needed.
actual_name, owner, dblink = self._resolve_table_owner(connection, name, table)
-
+
c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner})
-
+
while True:
row = c.fetchone()
if row is None:
@@ -305,20 +320,20 @@ class OracleDialect(ansisql.ANSIDialect):
coltype = ischema_names[coltype]
except KeyError:
raise exceptions.AssertionError("Cant get coltype for type '%s' on colname '%s'" % (coltype, colname))
-
+
colargs = []
if default is not None:
colargs.append(schema.PassiveDefault(sql.text(default)))
-
- # if name comes back as all upper, assume its case folded
- if (colname.upper() == colname):
+
+ # if name comes back as all upper, assume it's case folded
+ if (colname.upper() == colname):
colname = colname.lower()
-
+
table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
if not len(table.columns):
raise exceptions.AssertionError("Couldn't find any column information for table %s" % actual_name)
-
+
c = connection.execute("""SELECT
ac.constraint_name,
ac.constraint_type,
@@ -339,13 +354,13 @@ class OracleDialect(ansisql.ANSIDialect):
-- order multiple primary keys correctly
ORDER BY ac.constraint_name, loc.position, rem.position"""
% {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner})
-
+
fks = {}
while True:
row = c.fetchone()
if row is None:
break
- #print "ROW:" , row
+ #print "ROW:" , row
(cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row
if cons_type == 'P':
table.primary_key.add(table.c[local_column])
@@ -389,12 +404,17 @@ class OracleDialect(ansisql.ANSIDialect):
OracleDialect.logger = logging.class_logger(OracleDialect)
class OracleCompiler(ansisql.ANSICompiler):
- """oracle compiler modifies the lexical structure of Select statements to work under
- non-ANSI configured Oracle databases, if the use_ansi flag is False."""
-
+ """Oracle compiler modifies the lexical structure of Select
+ statements to work under non-ANSI configured Oracle databases, if
+ the use_ansi flag is False.
+ """
+
def default_from(self):
- """called when a SELECT statement has no froms, and no FROM clause is to be appended.
- gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
+ """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
+
+ The Oracle compiler tacks a ``FROM DUAL`` onto the statement.
+ """
+
return " FROM DUAL"
def apply_function_parens(self, func):
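For illustration, a minimal sketch of the effect (assuming the 0.3-era
select() and compile(dialect=...) APIs; the output formatting is approximate):

    from sqlalchemy import select
    from sqlalchemy.databases import oracle

    dialect = oracle.OracleDialect()
    s = select(["sysdate"])           # a SELECT with no FROM objects
    print s.compile(dialect=dialect)  # roughly: SELECT sysdate FROM DUAL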
@@ -403,7 +423,7 @@ class OracleCompiler(ansisql.ANSICompiler):
def visit_join(self, join):
if self.dialect.use_ansi:
return ansisql.ANSICompiler.visit_join(self, join)
-
+
self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
self.strings[join] = self.froms[join]
@@ -421,42 +441,50 @@ class OracleCompiler(ansisql.ANSICompiler):
self.visit_compound(self.wheres[join])
def visit_insert_sequence(self, column, sequence, parameters):
- """this is the 'sequence' equivalent to ANSICompiler's 'visit_insert_column_default' which ensures
- that the column is present in the generated column list"""
+ """This is the `sequence` equivalent to ``ANSICompiler``'s
+ `visit_insert_column_default`, which ensures that the column is
+ present in the generated column list.
+ """
+
parameters.setdefault(column.key, None)
-
+
def visit_alias(self, alias):
- """oracle doesnt like 'FROM table AS alias'. is the AS standard SQL??"""
+ """Oracle doesn't like ``FROM table AS alias``. Is the AS standard SQL??"""
+
self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
self.strings[alias] = self.get_str(alias.original)
-
+
def visit_column(self, column):
ansisql.ANSICompiler.visit_column(self, column)
if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
self.strings[column] = self.strings[column] + "(+)"
-
+
def visit_insert(self, insert):
- """inserts are required to have the primary keys be explicitly present.
- mapper will by default not put them in the insert statement to comply
- with autoincrement fields that require they not be present. so,
- put them all in for all primary key columns."""
+ """``INSERT``s are required to have the primary keys be explicitly present.
+
+ The mapper will by default not put them in the insert statement,
+ to comply with autoincrement fields that require they not be
+ present. So, put them all in for all primary key columns.
+ """
+
for c in insert.table.primary_key:
if not self.parameters.has_key(c.key):
self.parameters[c.key] = None
return ansisql.ANSICompiler.visit_insert(self, insert)
def _TODO_visit_compound_select(self, select):
- """need to determine how to get LIMIT/OFFSET into a UNION for oracle"""
+ """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
+
if getattr(select, '_oracle_visit', False):
# cancel out the compiled order_by on the select
if hasattr(select, "order_by_clause"):
self.strings[select.order_by_clause] = ""
ansisql.ANSICompiler.visit_compound_select(self, select)
return
-
+
if select.limit is not None or select.offset is not None:
select._oracle_visit = True
- # to use ROW_NUMBER(), an ORDER BY is required.
+ # to use ROW_NUMBER(), an ORDER BY is required.
orderby = self.strings[select.order_by_clause]
if not orderby:
orderby = select.oid_column
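As a hedged sketch of the non-ANSI path described above (the table and
column names here are hypothetical; assuming the 0.3-era schema and
select() APIs):

    from sqlalchemy import *
    from sqlalchemy.databases import oracle

    meta = MetaData()
    users = Table('users', meta, Column('id', Integer, primary_key=True))
    addresses = Table('addresses', meta,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer, ForeignKey('users.id')))

    dialect = oracle.OracleDialect(use_ansi=False)
    s = select([users], from_obj=[users.outerjoin(addresses)])
    print s.compile(dialect=dialect)
    # expected shape: SELECT users.id FROM users, addresses
    #                 WHERE users.id = addresses.user_id(+)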
@@ -478,10 +506,12 @@ class OracleCompiler(ansisql.ANSICompiler):
self.froms[select] = self.froms[limitselect]
else:
ansisql.ANSICompiler.visit_compound_select(self, select)
-
+
def visit_select(self, select):
- """looks for LIMIT and OFFSET in a select statement, and if so tries to wrap it in a
- subquery with row_number() criterion."""
+ """Look for ``LIMIT`` and OFFSET in a select statement, and if
+ so tries to wrap it in a subquery with ``row_number()`` criterion.
+ """
+
# TODO: put a real copy-container on Select and copy, or somehow make this
# not modify the Select statement
if getattr(select, '_oracle_visit', False):
@@ -493,7 +523,7 @@ class OracleCompiler(ansisql.ANSICompiler):
if select.limit is not None or select.offset is not None:
select._oracle_visit = True
- # to use ROW_NUMBER(), an ORDER BY is required.
+ # to use ROW_NUMBER(), an ORDER BY is required.
orderby = self.strings[select.order_by_clause]
if not orderby:
orderby = select.oid_column
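A rough sketch of the wrapping this produces (the alias and column names
below are assumptions, not verbatim compiler output):

    from sqlalchemy import *
    from sqlalchemy.databases import oracle

    users = Table('users', MetaData(), Column('id', Integer, primary_key=True))
    s = select([users], limit=10, offset=20)
    print s.compile(dialect=oracle.OracleDialect())
    # approximately:
    #   SELECT id FROM (SELECT users.id AS id, ROW_NUMBER() OVER
    #   (ORDER BY users.rowid) AS ora_rn FROM users)
    #   WHERE ora_rn > 20 AND ora_rn <= 30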
@@ -512,7 +542,7 @@ class OracleCompiler(ansisql.ANSICompiler):
self.froms[select] = self.froms[limitselect]
else:
ansisql.ANSICompiler.visit_select(self, select)
-
+
def limit_clause(self, select):
return ""
@@ -539,7 +569,6 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-
class OracleSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
if self.engine.dialect.has_sequence(self.connection, sequence.name):
@@ -550,7 +579,7 @@ class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
def exec_default_sql(self, default):
c = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile()
return self.proxy(str(c), c.get_params()).fetchone()[0]
-
+
def visit_sequence(self, seq):
return self.proxy("SELECT " + seq.name + ".nextval FROM DUAL").fetchone()[0]
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index b76aafc222..83dac516a9 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -15,7 +15,7 @@ import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions
from sqlalchemy.databases import information_schema as ischema
-from sqlalchemy import *
+from sqlalchemy import *
import re
try:
@@ -42,24 +42,30 @@ class PGNumeric(sqltypes.Numeric):
return "NUMERIC"
else:
return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+
class PGFloat(sqltypes.Float):
def get_col_spec(self):
if not self.precision:
return "FLOAT"
else:
return "FLOAT(%(precision)s)" % {'precision': self.precision}
+
class PGInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
+
class PGSmallInteger(sqltypes.Smallinteger):
def get_col_spec(self):
return "SMALLINT"
+
class PGBigInteger(PGInteger):
def get_col_spec(self):
return "BIGINT"
+
class PG2DateTime(sqltypes.DateTime):
def get_col_spec(self):
return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
class PG1DateTime(sqltypes.DateTime):
def convert_bind_param(self, value, dialect):
if value is not None:
@@ -73,6 +79,7 @@ class PG1DateTime(sqltypes.DateTime):
return psycopg.TimestampFromMx(value)
else:
return None
+
def convert_result_value(self, value, dialect):
if value is None:
return None
@@ -82,11 +89,14 @@ class PG1DateTime(sqltypes.DateTime):
return datetime.datetime(value.year, value.month, value.day,
value.hour, value.minute, seconds,
microseconds)
+
def get_col_spec(self):
return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
class PG2Date(sqltypes.Date):
def get_col_spec(self):
return "DATE"
+
class PG1Date(sqltypes.Date):
def convert_bind_param(self, value, dialect):
# TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
@@ -95,14 +105,18 @@ class PG1Date(sqltypes.Date):
return psycopg.DateFromMx(value)
else:
return None
+
def convert_result_value(self, value, dialect):
# TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
return value
+
def get_col_spec(self):
return "DATE"
+
class PG2Time(sqltypes.Time):
def get_col_spec(self):
return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
class PG1Time(sqltypes.Time):
def convert_bind_param(self, value, dialect):
# TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
@@ -111,32 +125,38 @@ class PG1Time(sqltypes.Time):
return psycopg.TimeFromMx(value)
else:
return None
+
def convert_result_value(self, value, dialect):
# TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime
return value
+
def get_col_spec(self):
return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE"
+
class PGInterval(sqltypes.TypeEngine):
def get_col_spec(self):
return "INTERVAL"
-
+
class PGText(sqltypes.TEXT):
def get_col_spec(self):
return "TEXT"
+
class PGString(sqltypes.String):
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
+
class PGChar(sqltypes.CHAR):
def get_col_spec(self):
return "CHAR(%(length)s)" % {'length' : self.length}
+
class PGBinary(sqltypes.Binary):
def get_col_spec(self):
return "BYTEA"
+
class PGBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOLEAN"
-
pg2_colspecs = {
sqltypes.Integer : PGInteger,
sqltypes.Smallinteger : PGSmallInteger,
@@ -214,7 +234,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
cursor = proxy(str(c), c.get_params())
row = cursor.fetchone()
self._last_inserted_ids = [v for v in row]
-
+
class PGDialect(ansisql.ANSIDialect):
def __init__(self, module=None, use_oids=False, use_information_schema=False, server_side_cursors=False, **params):
self.use_oids = use_oids
@@ -225,7 +245,7 @@ class PGDialect(ansisql.ANSIDialect):
self.module = psycopg
else:
self.module = module
- # figure psycopg version 1 or 2
+ # figure psycopg version 1 or 2
try:
if self.module.__version__.startswith('2'):
self.version = 2
@@ -238,7 +258,7 @@ class PGDialect(ansisql.ANSIDialect):
# produce consistent paramstyle even if psycopg2 module not present
if self.module is None:
self.paramstyle = 'pyformat'
-
+
def create_connect_args(self, url):
opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
if opts.has_key('port'):
@@ -265,23 +285,27 @@ class PGDialect(ansisql.ANSIDialect):
return sqltypes.adapt_type(typeobj, pg2_colspecs)
else:
return sqltypes.adapt_type(typeobj, pg1_colspecs)
-
+
def compiler(self, statement, bindparams, **kwargs):
return PGCompiler(self, statement, bindparams, **kwargs)
+
def schemagenerator(self, *args, **kwargs):
return PGSchemaGenerator(*args, **kwargs)
+
def schemadropper(self, *args, **kwargs):
return PGSchemaDropper(*args, **kwargs)
+
def defaultrunner(self, engine, proxy):
return PGDefaultRunner(engine, proxy)
+
def preparer(self):
return PGIdentifierPreparer(self)
-
+
def get_default_schema_name(self, connection):
if not hasattr(self, '_default_schema_name'):
self._default_schema_name = connection.scalar("select current_schema()", None)
return self._default_schema_name
-
+
def last_inserted_ids(self):
if self.context.last_inserted_ids is None:
raise exceptions.InvalidRequestError("no INSERT executed, or cant use cursor.lastrowid without Postgres OIDs enabled")
@@ -295,8 +319,12 @@ class PGDialect(ansisql.ANSIDialect):
return None
def do_executemany(self, c, statement, parameters, context=None):
- """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough
- to produce this correctly for an executemany, so we do our own executemany here."""
+ """We need accurate rowcounts for updates, inserts and deletes.
+
+ ``psycopg2`` is not nice enough to produce this correctly for
+ an executemany, so we do our own executemany here.
+ """
+
rowcount = 0
for param in parameters:
c.execute(statement, param)
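A minimal standalone sketch of the workaround, against a bare DBAPI
cursor (the function name is hypothetical):

    def manual_executemany(cursor, statement, parameters):
        # loop and sum the per-statement rowcounts, since psycopg2's
        # executemany does not report an accurate aggregate rowcount
        rowcount = 0
        for param in parameters:
            cursor.execute(statement, param)
            rowcount += cursor.rowcount
        return rowcount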
@@ -318,7 +346,7 @@ class PGDialect(ansisql.ANSIDialect):
def has_sequence(self, connection, sequence_name):
cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name})
return bool(not not cursor.rowcount)
-
+
def reflecttable(self, connection, table):
if self.version == 2:
ischema_names = pg2_ischema_names
@@ -333,10 +361,10 @@ class PGDialect(ansisql.ANSIDialect):
schema_where_clause = "n.nspname = :schema"
else:
schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
-
+
## information schema in pg suffers from too many permissions' restrictions
## let us find out at the pg way what is needed...
-
+
SQL_COLS = """
SELECT a.attname,
pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -354,25 +382,25 @@ class PGDialect(ansisql.ANSIDialect):
) AND a.attnum > 0 AND NOT a.attisdropped
ORDER BY a.attnum
""" % schema_where_clause
-
+
s = text(SQL_COLS)
- c = connection.execute(s, table_name=table.name,
+ c = connection.execute(s, table_name=table.name,
schema=table.schema)
rows = c.fetchall()
-
- if not rows:
+
+ if not rows:
raise exceptions.NoSuchTableError(table.name)
-
+
for name, format_type, default, notnull, attnum, table_oid in rows:
- ## strip (30) from character varying(30)
+ ## strip (30) from character varying(30)
attype = re.search('([^\(]+)', format_type).group(1)
nullable = not notnull
-
+
try:
charlen = re.search('\(([\d,]+)\)', format_type).group(1)
except:
charlen = False
-
+
numericprec = False
numericscale = False
if attype == 'numeric':
@@ -400,7 +428,7 @@ class PGDialect(ansisql.ANSIDialect):
kwargs['timezone'] = True
elif attype == 'timestamp without time zone':
kwargs['timezone'] = False
-
+
coltype = ischema_names[attype]
coltype = coltype(*args, **kwargs)
colargs= []
@@ -413,31 +441,31 @@ class PGDialect(ansisql.ANSIDialect):
default = match.group(1) + sch + '.' + match.group(2) + match.group(3)
colargs.append(PassiveDefault(sql.text(default)))
table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
-
-
+
+
# Primary keys
PK_SQL = """
- SELECT attname FROM pg_attribute
+ SELECT attname FROM pg_attribute
WHERE attrelid = (
SELECT indexrelid FROM pg_index i
WHERE i.indrelid = :table
AND i.indisprimary = 't')
ORDER BY attnum
- """
+ """
t = text(PK_SQL)
c = connection.execute(t, table=table_oid)
- for row in c.fetchall():
+ for row in c.fetchall():
pk = row[0]
table.primary_key.add(table.c[pk])
-
+
# Foreign keys
FK_SQL = """
- SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
- FROM pg_catalog.pg_constraint r
- WHERE r.conrelid = :table AND r.contype = 'f'
+ SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
+ FROM pg_catalog.pg_constraint r
+ WHERE r.conrelid = :table AND r.contype = 'f'
ORDER BY 1
"""
-
+
t = text(FK_SQL)
c = connection.execute(t, table=table_oid)
for conname, condef in c.fetchall():
@@ -448,10 +476,10 @@ class PGDialect(ansisql.ANSIDialect):
referred_schema = preparer._unquote_identifier(referred_schema)
referred_table = preparer._unquote_identifier(referred_table)
referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)]
-
+
refspec = []
if referred_schema is not None:
- schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
+ schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
autoload_with=connection)
for column in referred_columns:
refspec.append(".".join([referred_schema, referred_table, column]))
@@ -459,11 +487,10 @@ class PGDialect(ansisql.ANSIDialect):
schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
for column in referred_columns:
refspec.append(".".join([referred_table, column]))
-
+
table.append_constraint(ForeignKeyConstraint(constrained_columns, refspec, conname))
class PGCompiler(ansisql.ANSICompiler):
-
def visit_insert_column(self, column, parameters):
# all column primary key inserts must be explicitly present
if column.primary_key:
@@ -502,10 +529,9 @@ class PGCompiler(ansisql.ANSICompiler):
if isinstance(binary.type, sqltypes.String) and binary.operator == '+':
return '||'
else:
- return ansisql.ANSICompiler.binary_operator_string(self, binary)
-
+ return ansisql.ANSICompiler.binary_operator_string(self, binary)
+
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
-
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
@@ -527,7 +553,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)):
self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
self.execute()
-
+
class PGSchemaDropper(ansisql.ANSISchemaDropper):
def visit_sequence(self, sequence):
if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)):
@@ -543,7 +569,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
return c.fetchone()[0]
elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
sch = column.table.schema
- # TODO: this has to build into the Sequence object so we can get the quoting
+ # TODO: this has to build into the Sequence object so we can get the quoting
# logic from it
if sch is not None:
exc = "select nextval('\"%s\".\"%s_%s_seq\"')" % (sch, column.table.name, column.name)
@@ -555,7 +581,7 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
return ansisql.ANSIDefaultRunner.get_column_default(self, column)
else:
return ansisql.ANSIDefaultRunner.get_column_default(self, column)
-
+
def visit_sequence(self, seq):
if not seq.optional:
c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq))
@@ -566,9 +592,10 @@ class PGDefaultRunner(ansisql.ANSIDefaultRunner):
class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
def _fold_identifier_case(self, value):
return value.lower()
+
def _unquote_identifier(self, value):
if value[0] == self.initial_quote:
value = value[1:-1].replace('""','"')
return value
-
+
dialect = PGDialect
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 2ab3c0d5ad..b29be9eedd 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -31,18 +31,22 @@ class SLNumeric(sqltypes.Numeric):
return "NUMERIC"
else:
return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+
class SLInteger(sqltypes.Integer):
def get_col_spec(self):
return "INTEGER"
+
class SLSmallInteger(sqltypes.Smallinteger):
def get_col_spec(self):
return "SMALLINT"
+
class DateTimeMixin(object):
def convert_bind_param(self, value, dialect):
if value is not None:
return str(value)
else:
return None
+
def _cvt(self, value, dialect, fmt):
if value is None:
return None
@@ -52,49 +56,61 @@ class DateTimeMixin(object):
except ValueError:
(value, microsecond) = (value, 0)
return time.strptime(value, fmt)[0:6] + (microsecond,)
-
+
class SLDateTime(DateTimeMixin,sqltypes.DateTime):
def get_col_spec(self):
return "TIMESTAMP"
+
def convert_result_value(self, value, dialect):
tup = self._cvt(value, dialect, "%Y-%m-%d %H:%M:%S")
return tup and datetime.datetime(*tup)
+
class SLDate(DateTimeMixin, sqltypes.Date):
def get_col_spec(self):
return "DATE"
+
def convert_result_value(self, value, dialect):
tup = self._cvt(value, dialect, "%Y-%m-%d")
return tup and datetime.date(*tup[0:3])
+
class SLTime(DateTimeMixin, sqltypes.Time):
def get_col_spec(self):
return "TIME"
+
def convert_result_value(self, value, dialect):
tup = self._cvt(value, dialect, "%H:%M:%S")
return tup and datetime.time(*tup[3:7])
+
class SLText(sqltypes.TEXT):
def get_col_spec(self):
return "TEXT"
+
class SLString(sqltypes.String):
def get_col_spec(self):
return "VARCHAR(%(length)s)" % {'length' : self.length}
+
class SLChar(sqltypes.CHAR):
def get_col_spec(self):
return "CHAR(%(length)s)" % {'length' : self.length}
+
class SLBinary(sqltypes.Binary):
def get_col_spec(self):
return "BLOB"
+
class SLBoolean(sqltypes.Boolean):
def get_col_spec(self):
return "BOOLEAN"
+
def convert_bind_param(self, value, dialect):
if value is None:
return None
return value and 1 or 0
+
def convert_result_value(self, value, dialect):
if value is None:
return None
return value and True or False
-
+
colspecs = {
sqltypes.Integer : SLInteger,
sqltypes.Smallinteger : SLSmallInteger,
@@ -135,49 +151,56 @@ def descriptor():
('database', "Database Filename",None)
]}
-
class SQLiteExecutionContext(default.DefaultExecutionContext):
def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
if getattr(compiled, "isinsert", False):
self._last_inserted_ids = [proxy().lastrowid]
-
+
class SQLiteDialect(ansisql.ANSIDialect):
def __init__(self, **kwargs):
def vers(num):
return tuple([int(x) for x in num.split('.')])
self.supports_cast = (sqlite is not None and vers(sqlite.sqlite_version) >= vers("3.2.3"))
ansisql.ANSIDialect.__init__(self, **kwargs)
+
def compiler(self, statement, bindparams, **kwargs):
return SQLiteCompiler(self, statement, bindparams, **kwargs)
+
def schemagenerator(self, *args, **kwargs):
return SQLiteSchemaGenerator(*args, **kwargs)
+
def schemadropper(self, *args, **kwargs):
return SQLiteSchemaDropper(*args, **kwargs)
+
def preparer(self):
return SQLiteIdentifierPreparer(self)
+
def create_connect_args(self, url):
filename = url.database or ':memory:'
return ([filename], url.query)
+
def type_descriptor(self, typeobj):
return sqltypes.adapt_type(typeobj, colspecs)
+
def create_execution_context(self):
return SQLiteExecutionContext(self)
+
def last_inserted_ids(self):
return self.context.last_inserted_ids
-
+
def oid_column_name(self, column):
return "oid"
def dbapi(self):
return sqlite
-
+
def has_table(self, connection, table_name, schema=None):
cursor = connection.execute("PRAGMA table_info(" + table_name + ")", {})
row = cursor.fetchone()
-
+
# consume remaining rows, to work around: http://www.sqlite.org/cvstrac/tktview?tn=1884
while cursor.fetchone() is not None:pass
-
+
return (row is not None)
def reflecttable(self, connection, table):
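A standalone sketch of the PRAGMA probe (assuming the stdlib sqlite3
binding of Python 2.5+; the table name is arbitrary):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute("CREATE TABLE foo (id INTEGER PRIMARY KEY)")
    cursor = conn.execute("PRAGMA table_info(foo)")
    row = cursor.fetchone()
    # consume remaining rows before reusing the cursor (sqlite ticket 1884)
    while cursor.fetchone() is not None:
        pass
    print row is not None   # True -> the table exists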
@@ -198,7 +221,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
else:
coltype = "VARCHAR"
args = ''
-
+
#print "coltype: " + repr(coltype) + " args: " + repr(args)
coltype = pragma_names.get(coltype, SLString)
if args is not None:
@@ -210,10 +233,10 @@ class SQLiteDialect(ansisql.ANSIDialect):
if has_default:
colargs.append(PassiveDefault('?'))
table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
-
+
if not found_table:
raise exceptions.NoSuchTableError(table.name)
-
+
c = connection.execute("PRAGMA foreign_key_list(" + table.name + ")", {})
fks = {}
while True:
@@ -229,7 +252,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
except KeyError:
fk = ([],[])
fks[constraint_name] = fk
-
+
#print "row! " + repr([key for key in row.keys()]), repr(row)
# look up the table based on the given table's engine, not 'self',
# since it could be a ProxyEngine
@@ -241,7 +264,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
if refspec not in fk[1]:
fk[1].append(refspec)
for name, value in fks.iteritems():
- table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1]))
+ table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1]))
# check for UNIQUE indexes
c = connection.execute("PRAGMA index_list(" + table.name + ")", {})
unique_indexes = []
@@ -264,7 +287,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
# unique index that includes the pk is considered a multiple primary key
for col in cols:
table.primary_key.add(table.columns[col])
-
+
class SQLiteCompiler(ansisql.ANSICompiler):
def visit_cast(self, cast):
if self.dialect.supports_cast:
@@ -274,6 +297,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
# not sure if we want to set the typemap here...
self.typemap.setdefault("CAST", cast.type)
self.strings[cast] = self.strings[cast.clause]
+
def limit_clause(self, select):
text = ""
if select.limit is not None:
@@ -285,6 +309,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
else:
text += " OFFSET 0"
return text
+
def for_update_clause(self, select):
# sqlite has no "FOR UPDATE" AFAICT
return ''
@@ -298,7 +323,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def supports_alter(self):
return False
-
+
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
@@ -328,4 +353,4 @@ class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True)
dialect = SQLiteDialect
-poolclass = pool.SingletonThreadPool
+poolclass = pool.SingletonThreadPool
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index 220707eb34..c3651c88b9 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -10,82 +10,115 @@ from sqlalchemy.engine import strategies
import re
def engine_descriptors():
- """provides a listing of all the database implementations supported. this data
- is provided as a list of dictionaries, where each dictionary contains the following
- key/value pairs:
-
- name : the name of the engine, suitable for use in the create_engine function
-
- description: a plain description of the engine.
-
- arguments : a dictionary describing the name and description of each parameter
- used to connect to this engine's underlying DBAPI.
-
- This function is meant for usage in automated configuration tools that wish to
- query the user for database and connection information.
+ """Provide a listing of all the database implementations supported.
+
+ This data is provided as a list of dictionaries, where each
+ dictionary contains the following key/value pairs:
+
+ name
+ the name of the engine, suitable for use in the create_engine function.
+
+ description
+ a plain description of the engine.
+
+ arguments
+ a dictionary describing the name and description of each
+ parameter used to connect to this engine's underlying DBAPI.
+
+ This function is meant for use in automated configuration tools
+ that wish to query the user for database and connection
+ information.
"""
+
result = []
#for module in sqlalchemy.databases.__all__:
for module in ['sqlite', 'postgres', 'mysql']:
module = getattr(__import__('sqlalchemy.databases.%s' % module).databases, module)
result.append(module.descriptor())
return result
-
+
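A short usage sketch (treating `arguments` as the (name, description,
default) tuples visible in the sqlite descriptor elsewhere in this patch):

    from sqlalchemy.engine import engine_descriptors

    for desc in engine_descriptors():
        print desc['name'], '-', desc['description']
        for argname, argdesc, default in desc['arguments']:
            print '   ', argname, ':', argdesc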
default_strategy = 'plain'
def create_engine(*args, **kwargs):
- """creates a new Engine instance. Using the given strategy name,
- locates that strategy and invokes its create() method to produce the Engine.
- The strategies themselves are instances of EngineStrategy, and the built in
- ones are present in the sqlalchemy.engine.strategies module. Current implementations
- include "plain" and "threadlocal". The default used by this function is "plain".
-
- "plain" provides support for a Connection object which can be used to execute SQL queries
- with a specific underlying DBAPI connection.
-
- "threadlocal" is similar to "plain" except that it adds support for a thread-local connection and
- transaction context, which allows a group of engine operations to participate using the same
- connection and transaction without the need for explicit passing of a Connection object.
-
- The standard method of specifying the engine is via URL as the first positional
- argument, to indicate the appropriate database dialect and connection arguments, with additional
- keyword arguments sent as options to the dialect and resulting Engine.
-
- The URL is in the form ://opt1=val1&opt2=val2.
- Where is a name such as "mysql", "oracle", "postgres", and the options indicate
- username, password, database, etc. Supported keynames include "username", "user", "password",
- "pw", "db", "database", "host", "filename".
-
- **kwargs represents options to be sent to the Engine itself as well as the components of the Engine,
- including the Dialect, the ConnectionProvider, and the Pool. A list of common options is as follows:
-
- pool=None : an instance of sqlalchemy.pool.DBProxy or sqlalchemy.pool.Pool to be used as the
- underlying source for connections (DBProxy/Pool is described in the previous section). If None,
- a default DBProxy will be created using the engine's own database module with the given
- arguments.
-
- echo=False : if True, the Engine will log all statements as well as a repr() of their
- parameter lists to the engines logger, which defaults to sys.stdout. A Engine instances'
- "echo" data member can be modified at any time to turn logging on and off. If set to the string
- 'debug', result rows will be printed to the standard output as well.
-
- logger=None : a file-like object where logging output can be sent, if echo is set to True.
- This defaults to sys.stdout.
-
- encoding='utf-8' : the encoding to be used when encoding/decoding Unicode strings
-
- convert_unicode=False : True if unicode conversion should be applied to all str types
-
- module=None : used by Oracle and Postgres, this is a reference to a DBAPI2 module to be used
- instead of the engine's default module. For Postgres, the default is psycopg2, or psycopg1 if
- 2 cannot be found. For Oracle, its cx_Oracle. For mysql, MySQLdb.
-
- use_ansi=True : used only by Oracle; when False, the Oracle driver attempts to support a
- particular "quirk" of some Oracle databases, that the LEFT OUTER JOIN SQL syntax is not
- supported, and the "Oracle join" syntax of using (+)= must be used
- in order to achieve a LEFT OUTER JOIN. Its advised that the Oracle database be configured to
- have full ANSI support instead of using this feature.
+ """Create a new Engine instance.
+
+ Using the given strategy name, locates that strategy and invokes
+ its create() method to produce the Engine. The strategies
+ themselves are instances of EngineStrategy, and the built-in ones
+ are present in the sqlalchemy.engine.strategies module. Current
+ implementations include *plain* and *threadlocal*. The default
+ used by this function is *plain*.
+
+ *plain* provides support for a Connection object which can be used
+ to execute SQL queries with a specific underlying DBAPI connection.
+
+ *threadlocal* is similar to *plain* except that it adds support
+ for a thread-local connection and transaction context, which
+ allows a group of engine operations to participate using the same
+ connection and transaction without the need for explicit passing
+ of a Connection object.
+
+ The standard method of specifying the engine is via URL as the
+ first positional argument, to indicate the appropriate database
+ dialect and connection arguments, with additional keyword
+ arguments sent as options to the dialect and resulting Engine.
+
+ The URL is in the form ``<dialect>://opt1=val1&opt2=val2``, where
+ ``<dialect>`` is a name such as *mysql*, *oracle*, *postgres*, and the
+ options indicate username, password, database, etc. Supported
+ keynames include `username`, `user`, `password`, `pw`, `db`,
+ `database`, `host`, `filename`.
+
+ `**kwargs` represents options to be sent to the Engine itself as
+ well as the components of the Engine, including the Dialect, the
+ ConnectionProvider, and the Pool. A list of common options is as
+ follows:
+
+ pool
+ Defaults to None: an instance of ``sqlalchemy.pool.DBProxy`` or
+ ``sqlalchemy.pool.Pool`` to be used as the underlying source for
+ connections (DBProxy/Pool is described in the previous
+ section). If None, a default DBProxy will be created using the
+ engine's own database module with the given arguments.
+
+ echo
+ Defaults to False: if True, the Engine will log all statements
+ as well as a repr() of their parameter lists to the engines
+ logger, which defaults to ``sys.stdout``. A Engine instances'
+ `echo` data member can be modified at any time to turn logging
+ on and off. If set to the string 'debug', result rows will be
+ printed to the standard output as well.
+
+ logger
+ Defaults to None: a file-like object where logging output can be
+ sent, if `echo` is set to True. This defaults to
+ ``sys.stdout``.
+
+ encoding
+ Defaults to 'utf-8': the encoding to be used when
+ encoding/decoding Unicode strings.
+
+ convert_unicode
+ Defaults to False: True if unicode conversion should be applied
+ to all str types.
+
+ module
+ Defaults to None: used by Oracle and Postgres, this is a
+ reference to a DBAPI2 module to be used instead of the engine's
+ default module. For Postgres, the default is psycopg2, or
+ psycopg1 if 2 cannot be found. For Oracle, it's cx_Oracle. For
+ MySQL, MySQLdb.
+
+ use_ansi
+ Defaults to True: used only by Oracle; when False, the Oracle
+ driver attempts to support a particular *quirk* of some Oracle
+ databases, that the ``LEFT OUTER JOIN`` SQL syntax is not
+ supported, and the *Oracle join* syntax of using
+ ``<column1>(+)=<column2>`` must be used in order to achieve a
+ ``LEFT OUTER JOIN``. It's advised that the Oracle database be
+ configured to have full ANSI support instead of using this
+ feature.
"""
+
strategy = kwargs.pop('strategy', default_strategy)
strategy = strategies.strategies[strategy]
return strategy.create(*args, **kwargs)
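For example (URLs and option values are illustrative only):

    from sqlalchemy import create_engine

    # a *plain*-strategy engine with statement echoing turned on
    engine = create_engine('postgres://scott:tiger@localhost/test',
                           echo=True, convert_unicode=False)

    # a *threadlocal*-strategy engine against an in-memory sqlite database
    tl_engine = create_engine('sqlite://', strategy='threadlocal')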
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 1985bcec1f..10001e8a34 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1,257 +1,431 @@
from sqlalchemy import exceptions, sql, schema, util, types, logging
import StringIO, sys, re
+
class ConnectionProvider(object):
- """defines an interface that returns raw Connection objects (or compatible)."""
+ """Define an interface that returns raw Connection objects (or compatible)."""
+
def get_connection(self):
- """this method should return a Connection or compatible object from a DBAPI which
- also contains a close() method.
- It is not defined what context this connection belongs to. It may be newly connected,
- returned from a pool, part of some other kind of context such as thread-local,
- or can be a fixed member of this object."""
+ """Return a Connection or compatible object from a DBAPI which also contains a close() method.
+
+ It is not defined what context this connection belongs to. It
+ may be newly connected, returned from a pool, part of some
+ other kind of context such as thread-local, or can be a fixed
+ member of this object.
+ """
+
raise NotImplementedError()
+
def dispose(self):
- """releases all resources corresponding to this ConnectionProvider, such
- as any underlying connection pools."""
+ """Release all resources corresponding to this ConnectionProvider.
+
+ This includes any underlying connection pools.
+ """
+
raise NotImplementedError()
+
class Dialect(sql.AbstractDialect):
- """Defines the behavior of a specific database/DBAPI.
+ """Define the behavior of a specific database/DBAPI.
+
+ Any aspect of metadata definition, SQL query generation, execution,
+ result-set handling, or anything else which varies between
+ databases is defined under the general category of the Dialect.
+ The Dialect acts as a factory for other database-specific object
+ implementations including ExecutionContext, Compiled,
+ DefaultGenerator, and TypeEngine.
- Any aspect of metadata defintion, SQL query generation, execution, result-set handling,
- or anything else which varies between databases is defined under the general category of
- the Dialect. The Dialect acts as a factory for other database-specific object implementations
- including ExecutionContext, Compiled, DefaultGenerator, and TypeEngine.
-
All Dialects implement the following attributes:
- positional - True if the paramstyle for this Dialect is positional
+ positional
+ True if the paramstyle for this Dialect is positional
- paramstyle - the paramstyle to be used (some DBAPIs support multiple paramstyles)
+ paramstyle
+ The paramstyle to be used (some DBAPIs support multiple paramstyles)
- supports_autoclose_results - usually True; if False, indicates that rows returned by fetchone()
- might not be just plain tuples, and may be "live" proxy objects which still require the cursor
- to be open in order to be read (such as pyPgSQL which has active filehandles for BLOBs). in that
- case, an auto-closing ResultProxy cannot automatically close itself after results are consumed.
+ supports_autoclose_results
+ Usually True; if False, indicates that rows returned by
+ fetchone() might not be just plain tuples, and may be
+ "live" proxy objects which still require the cursor to be open
+ in order to be read (such as pyPgSQL which has active
+ filehandles for BLOBs). In that case, an auto-closing
+ ResultProxy cannot automatically close itself after results are
+ consumed.
- convert_unicode - True if unicode conversion should be applied to all str types
+ convert_unicode
+ True if unicode conversion should be applied to all str types
- encoding - type of encoding to use for unicode, usually defaults to 'utf-8'
+ encoding
+ Type of encoding to use for unicode, usually defaults to 'utf-8'
"""
+
def create_connect_args(self, opts):
- """given a dictionary of key-valued connect parameters, returns a tuple
- consisting of a *args/**kwargs suitable to send directly to the dbapi's connect function.
- The connect args will have any number of the following keynames: host, hostname, database, dbanme,
- user,username, password, pw, passwd, filename."""
+ """Build DBAPI compatible connection arguments.
+
+ Given a dictionary of key-valued connect parameters, returns a
+ tuple consisting of a `*args`/`**kwargs` suitable to send directly
+ to the dbapi's connect function. The connect args will have
+ any number of the following keynames: host, hostname,
+ database, dbname, user, username, password, pw, passwd,
+ filename.
+ """
+
raise NotImplementedError()
+
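For a concrete shape of the return value, compare the sqlite dialect
later in this patch, which splits the URL into a filename plus query
options; roughly (the URL is hypothetical):

    from sqlalchemy.engine import url

    u = url.make_url('sqlite:////tmp/app.db')
    # the sqlite dialect's create_connect_args(u) returns
    #   (['/tmp/app.db'], {})
    # i.e. the *args and **kwargs handed to dbapi.connect()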
def convert_compiled_params(self, parameters):
- """given a sql.ClauseParameters object, returns an array or dictionary suitable to pass
- directly to this Dialect's DBAPI's execute method."""
+ """Build DBAPI execute arguments from a ClauseParameters.
+
+ Given a sql.ClauseParameters object, returns an array or
+ dictionary suitable to pass directly to this Dialect's DBAPI's
+ execute method.
+ """
+
+ raise NotImplementedError()
+
def type_descriptor(self, typeobj):
- """provides a database-specific TypeEngine object, given the generic object
- which comes from the types module. Subclasses will usually use the adapt_type()
- method in the types module to make this job easy."""
+ """Trasform the type from generic to database-specific.
+
+ Provides a database-specific TypeEngine object, given the
+ generic object which comes from the types module. Subclasses
+ will usually use the adapt_type() method in the types module
+ to make this job easy.
+ """
+
raise NotImplementedError()
+
def oid_column_name(self, column):
- """return the oid column name for this dialect, or None if the dialect cant/wont support OID/ROWID.
-
- the Column instance which represents OID for the query being compiled is passed, so that the dialect
- can inspect the column and its parent selectable to determine if OID/ROWID is not selected for a particular
- selectable (i.e. oracle doesnt support ROWID for UNION, GROUP BY, DISTINCT, etc.)
+ """Return the oid column name for this dialect, or None if the dialect cant/wont support OID/ROWID.
+
+ The Column instance which represents OID for the query being
+ compiled is passed, so that the dialect can inspect the column
+ and its parent selectable to determine if OID/ROWID is not
+ selected for a particular selectable (i.e. Oracle doesn't
+ support ROWID for UNION, GROUP BY, DISTINCT, etc.)
"""
+
raise NotImplementedError()
+
def supports_sane_rowcount(self):
- """Provided to indicate when MySQL is being used, which does not have standard behavior
- for the "rowcount" function on a statement handle. """
+ """Indicate whether the dialect properly implements statements rowcount.
+
+ Provided to indicate when MySQL is being used, which does not
+ have standard behavior for the "rowcount" function on a statement handle.
+ """
+
raise NotImplementedError()
+
def schemagenerator(self, engine, proxy, **params):
- """returns a schema.SchemaVisitor instance that can generate schemas, when it is
- invoked to traverse a set of schema objects.
+ """Return a ``schema.SchemaVisitor`` instance that can generate schemas.
- schemagenerator is called via the create() method on Table, Index, and others.
+ `schemagenerator()` is called via the `create()` method on Table,
+ Index, and others.
"""
+
raise NotImplementedError()
+
def schemadropper(self, engine, proxy, **params):
- """returns a schema.SchemaVisitor instance that can drop schemas, when it is
- invoked to traverse a set of schema objects.
+ """Return a ``schema.SchemaVisitor`` instance that can drop schemas.
- schemagenerator is called via the drop() method on Table, Index, and others.
+ `schemadropper()` is called via the `drop()` method on Table,
+ Index, and others.
"""
+
raise NotImplementedError()
+
def defaultrunner(self, engine, proxy, **params):
- """returns a schema.SchemaVisitor instances that can execute defaults."""
+ """Return a ``schema.SchemaVisitor`` instance that can execute defaults."""
+
raise NotImplementedError()
+
def compiler(self, statement, parameters):
- """returns a sql.ClauseVisitor which will produce a string representation of the given
- ClauseElement and parameter dictionary. This object is usually a subclass of
- ansisql.ANSICompiler.
+ """Return a ``sql.ClauseVisitor`` able to transform a ``ClauseElement`` into a string.
+
+ The returned object is usually a subclass of
+ ansisql.ANSICompiler, and will produce a string representation
+ of the given ClauseElement and `parameters` dictionary.
+
+ `compiler()` is called within the context of the compile() method.
+ """
- compiler is called within the context of the compile() method."""
raise NotImplementedError()
+
def reflecttable(self, connection, table):
- """given an Connection and a Table object, reflects its columns and properties from the database."""
+ """Load table description from the database.
+
+ Given a ``Connection`` and a ``Table`` object, reflect its
+ columns and properties from the database.
+ """
+
raise NotImplementedError()
+
def has_table(self, connection, table_name, schema=None):
+ """Check the existence of a particular table in the database.
+
+ Given a ``Connection`` object and a `table_name`, return True
+ if the given table (possibly within the specified `schema`)
+ exists in the database, False otherwise.
+ """
+
raise NotImplementedError()
+
def has_sequence(self, connection, sequence_name):
+ """Check the existence of a particular sequence in the database.
+
+ Given a ``Connection`` object and a `sequence_name`, return
+ True if the given sequence exists in the database, False
+ otherwise.
+ """
+
raise NotImplementedError()
+
def dbapi(self):
- """subclasses override this method to provide the DBAPI module used to establish
- connections."""
+ """Establish a connection to the database.
+
+ Subclasses override this method to provide the DBAPI module
+ used to establish connections.
+ """
+
raise NotImplementedError()
+
def get_default_schema_name(self, connection):
- """returns the currently selected schema given an connection"""
+ """Return the currently selected schema given a connection"""
+
raise NotImplementedError()
+
def execution_context(self):
- """returns a new ExecutionContext object."""
+ """Return a new ExecutionContext object."""
+
raise NotImplementedError()
+
def do_begin(self, connection):
- """provides an implementation of connection.begin()"""
+ """Provide an implementation of connection.begin()."""
+
raise NotImplementedError()
+
def do_rollback(self, connection):
- """provides an implementation of connection.rollback()"""
+ """Provide an implementation of connection.rollback()."""
+
raise NotImplementedError()
+
def do_commit(self, connection):
- """provides an implementation of connection.commit()"""
+ """Provide an implementation of connection.commit()"""
+
raise NotImplementedError()
+
def do_executemany(self, cursor, statement, parameters):
+ """Execute a single SQL statement looping over a sequence of parameters."""
+
raise NotImplementedError()
+
def do_execute(self, cursor, statement, parameters):
+ """Execute a single SQL statement with given parameters."""
+
raise NotImplementedError()
+
def create_cursor(self, connection):
- """return a new cursor generated from the given connection"""
+ """Return a new cursor generated from the given connection."""
+
raise NotImplementedError()
+
def create_result_proxy_args(self, connection, cursor):
- """returns a dictionary of arguments that should be passed to ResultProxy()."""
+ """Return a dictionary of arguments that should be passed to ResultProxy()."""
+
raise NotImplementedError()
+
def compile(self, clauseelement, parameters=None):
- """compile the given ClauseElement using this Dialect.
-
- a convenience method which simply flips around the compile() call
- on ClauseElement."""
+ """Compile the given ClauseElement using this Dialect.
+
+ A convenience method which simply flips around the compile()
+ call on ClauseElement.
+ """
+
return clauseelement.compile(dialect=self, parameters=parameters)
-
+
+
class ExecutionContext(object):
- """a messenger object for a Dialect that corresponds to a single execution. The Dialect
- should provide an ExecutionContext via the create_execution_context() method.
- The pre_exec and post_exec methods will be called for compiled statements, afterwhich
- it is expected that the various methods last_inserted_ids, last_inserted_params, etc.
- will contain appropriate values, if applicable."""
+ """A messenger object for a Dialect that corresponds to a single execution.
+
+ The Dialect should provide an ExecutionContext via the
+ create_execution_context() method. The `pre_exec` and `post_exec`
+ methods will be called for compiled statements, after which it is
+ expected that the various methods `last_inserted_ids`,
+ `last_inserted_params`, etc. will contain appropriate values, if
+ applicable.
+ """
+
def pre_exec(self, engine, proxy, compiled, parameters):
- """called before an execution of a compiled statement. proxy is a callable that
- takes a string statement and a bind parameter list/dictionary."""
+ """Called before an execution of a compiled statement.
+
+ `proxy` is a callable that takes a string statement and a bind
+ parameter list/dictionary.
+ """
+
raise NotImplementedError()
+
def post_exec(self, engine, proxy, compiled, parameters):
- """called after the execution of a compiled statement. proxy is a callable that
- takes a string statement and a bind parameter list/dictionary."""
+ """Called after the execution of a compiled statement.
+
+ `proxy` is a callable that takes a string statement and a bind
+ parameter list/dictionary.
+ """
+
raise NotImplementedError()
+
def get_rowcount(self, cursor):
- """returns the count of rows updated/deleted for an UPDATE/DELETE statement"""
+ """Return the count of rows updated/deleted for an UPDATE/DELETE statement."""
+
raise NotImplementedError()
+
def supports_sane_rowcount(self):
- """Indicates if the "rowcount" DBAPI cursor function works properly.
-
- Currently, MySQLDB does not properly implement this function."""
+ """Indicate if the "rowcount" DBAPI cursor function works properly.
+
+ Currently, MySQLdb does not properly implement this function.
+ """
+
raise NotImplementedError()
+
def last_inserted_ids(self):
- """return the list of the primary key values for the last insert statement executed.
-
- This does not apply to straight textual clauses; only to sql.Insert objects compiled against
- a schema.Table object, which are executed via statement.execute(). The order of items in the
- list is the same as that of the Table's 'primary_key' attribute.
-
- In some cases, this method may invoke a query back to the database to retrieve the data, based on
- the "lastrowid" value in the cursor."""
+ """Return the list of the primary key values for the last insert statement executed.
+
+ This does not apply to straight textual clauses; only to
+ ``sql.Insert`` objects compiled against a ``schema.Table`` object,
+ which are executed via `statement.execute()`. The order of
+ items in the list is the same as that of the Table's
+ 'primary_key' attribute.
+
+ In some cases, this method may invoke a query back to the
+ database to retrieve the data, based on the "lastrowid" value
+ in the cursor.
+ """
+
raise NotImplementedError()
+
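A hedged usage sketch (assuming a BoundMetaData-style setup and that the
0.3 ResultProxy exposes the same accessor; the names are hypothetical):

    from sqlalchemy import *

    meta = BoundMetaData('sqlite://')
    users = Table('users', meta,
        Column('id', Integer, primary_key=True),
        Column('name', String(40)))
    users.create()

    result = users.insert().execute(name='ed')
    print result.last_inserted_ids()   # e.g. [1], ordered per primary_key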
def last_inserted_params(self):
- """return a dictionary of the full parameter dictionary for the last compiled INSERT statement.
-
- Includes any ColumnDefaults or Sequences that were pre-executed."""
+ """Return a dictionary of the full parameter dictionary for the last compiled INSERT statement.
+
+ Includes any ColumnDefaults or Sequences that were pre-executed.
+ """
+
raise NotImplementedError()
+
def last_updated_params(self):
- """return a dictionary of the full parameter dictionary for the last compiled UPDATE statement.
-
- Includes any ColumnDefaults that were pre-executed."""
+ """Return a dictionary of the full parameter dictionary for the last compiled UPDATE statement.
+
+ Includes any ColumnDefaults that were pre-executed.
+ """
+
raise NotImplementedError()
+
def lastrow_has_defaults(self):
- """return True if the last row INSERTED via a compiled insert statement contained PassiveDefaults.
-
- The presence of PassiveDefaults indicates that the database inserted data beyond that which we
- passed to the query programmatically."""
+ """Return True if the last row INSERTED via a compiled insert statement contained PassiveDefaults.
+
+ The presence of PassiveDefaults indicates that the database
+ inserted data beyond that which we passed to the query
+ programmatically.
+ """
+
raise NotImplementedError()
+
class Connectable(object):
- """interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
+ """Interface for an object that can provide an Engine and a Connection object which correponds to that Engine."""
+
def contextual_connect(self):
- """returns a Connection object which may be part of an ongoing context."""
+ """Return a Connection object which may be part of an ongoing context."""
+
raise NotImplementedError()
+
def create(self, entity, **kwargs):
- """creates a table or index given an appropriate schema object."""
+ """Create a table or index given an appropriate schema object."""
+
raise NotImplementedError()
+
def drop(self, entity, **kwargs):
+ """Drop a table or index given an appropriate schema object."""
+
raise NotImplementedError()
+
def execute(self, object, *multiparams, **params):
raise NotImplementedError()
+
def _not_impl(self):
raise NotImplementedError()
- engine = property(_not_impl, doc="returns the Engine which this Connectable is associated with.")
+
+ engine = property(_not_impl, doc="The Engine which this Connectable is associated with.")
class Connection(Connectable):
- """represents a single DBAPI connection returned from the underlying connection pool. Provides
- execution support for string-based SQL statements as well as ClauseElement, Compiled and DefaultGenerator objects.
- provides a begin method to return Transaction objects.
-
- The Connection object is **not** threadsafe."""
+ """Represent a single DBAPI connection returned from the underlying connection pool.
+
+ Provides execution support for string-based SQL statements as well
+ as ClauseElement, Compiled and DefaultGenerator objects. Provides
+ a begin method to return Transaction objects.
+
+ The Connection object is **not** threadsafe.
+ """
+
def __init__(self, engine, connection=None, close_with_result=False):
self.__engine = engine
self.__connection = connection or engine.raw_connection()
self.__transaction = None
self.__close_with_result = close_with_result
+
def _get_connection(self):
try:
return self.__connection
except AttributeError:
raise exceptions.InvalidRequestError("This Connection is closed")
+
engine = property(lambda s:s.__engine, doc="The Engine with which this Connection is associated (read only)")
connection = property(_get_connection, doc="The underlying DBAPI connection managed by this Connection.")
should_close_with_result = property(lambda s:s.__close_with_result, doc="Indicates if this Connection should be closed when a corresponding ResultProxy is closed; this is essentially an auto-release mode.")
+
def _create_transaction(self, parent):
return Transaction(self, parent)
+
def connect(self):
"""connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly."""
return self
+
def contextual_connect(self, **kwargs):
"""contextual_connect() is implemented to return self so that an incoming Engine or Connection object can be treated similarly."""
return self
+
def begin(self):
if self.__transaction is None:
self.__transaction = self._create_transaction(None)
return self.__transaction
else:
return self._create_transaction(self.__transaction)
+
def in_transaction(self):
return self.__transaction is not None
+
def _begin_impl(self):
self.__engine.logger.info("BEGIN")
self.__engine.dialect.do_begin(self.connection)
+
def _rollback_impl(self):
self.__engine.logger.info("ROLLBACK")
self.__engine.dialect.do_rollback(self.connection)
self.__connection.close_open_cursors()
self.__transaction = None
+
def _commit_impl(self):
self.__engine.logger.info("COMMIT")
self.__engine.dialect.do_commit(self.connection)
self.__transaction = None
+
def _autocommit(self, statement):
- """when no Transaction is present, this is called after executions to provide "autocommit" behavior."""
- # TODO: have the dialect determine if autocommit can be set on the connection directly without this
+ """When no Transaction is present, this is called after executions to provide "autocommit" behavior."""
+ # TODO: have the dialect determine if autocommit can be set on the connection directly without this
# extra step
if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()):
self._commit_impl()
+
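The regex gate can be exercised on its own (the statements below are
arbitrary examples):

    import re

    AUTOCOMMIT_RE = re.compile(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER')
    for stmt in ("SELECT * FROM users",
                 "  insert into users values (1, 'ed')"):
        print stmt.strip(), '->', bool(AUTOCOMMIT_RE.match(stmt.lstrip().upper()))
    # SELECT ...  -> False (no autocommit)
    # insert ...  -> True  (autocommits)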
def _autorollback(self):
if not self.in_transaction():
self._rollback_impl()
+
def close(self):
try:
c = self.__connection
@@ -260,12 +434,16 @@ class Connection(Connectable):
self.__connection.close()
self.__connection = None
del self.__connection
+
def scalar(self, object, *multiparams, **params):
return self.execute(object, *multiparams, **params).scalar()
+
def execute(self, object, *multiparams, **params):
return Connection.executors[type(object).__mro__[-2]](self, object, *multiparams, **params)
+
def execute_default(self, default, **kwargs):
return default.accept_schema_visitor(self.__engine.dialect.defaultrunner(self.__engine, self.proxy, **kwargs))
+
def execute_text(self, statement, *multiparams, **params):
if len(multiparams) == 0:
parameters = params
@@ -276,6 +454,7 @@ class Connection(Connectable):
cursor = self._execute_raw(statement, parameters)
rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
return ResultProxy(self.__engine, self, cursor, **rpargs)
+
def _params_to_listofdicts(self, *multiparams, **params):
if len(multiparams) == 0:
return [params]
@@ -288,6 +467,7 @@ class Connection(Connectable):
return [multiparams[0]]
else:
return multiparams
+
def execute_clauseelement(self, elem, *multiparams, **params):
executemany = len(multiparams) > 0
if executemany:
@@ -295,8 +475,10 @@ class Connection(Connectable):
else:
param = params
return self.execute_compiled(elem.compile(engine=self.__engine, parameters=param), *multiparams, **params)
+
def execute_compiled(self, compiled, *multiparams, **params):
- """executes a sql.Compiled object."""
+ """Execute a sql.Compiled object."""
+
if not compiled.can_execute:
raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled)))
cursor = self.__engine.dialect.create_cursor(self.connection)
@@ -316,7 +498,7 @@ class Connection(Connectable):
context.post_exec(self.__engine, proxy, compiled, parameters)
rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, columns=compiled.columns, **rpargs)
-
+
# poor man's multimethod/generic function thingy
executors = {
sql.ClauseElement : execute_clauseelement,
@@ -324,20 +506,28 @@ class Connection(Connectable):
schema.SchemaItem:execute_default,
str.__mro__[-2] : execute_text
}
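The ``executors`` table above dispatches on ``type(object).__mro__[-2]``, i.e. the most basic class below ``object`` in the method resolution order, so that subclasses of ``sql.ClauseElement`` and plain strings each find their handler. A rough illustration (``Select`` is a hypothetical subclass, for the example only)::

    class Select(sql.ClauseElement): pass

    Select.__mro__[-2]   # sql.ClauseElement -- routes to execute_clauseelement
    str.__mro__[-2]      # basestring on Python 2 -- routes to execute_text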
-
+
def create(self, entity, **kwargs):
- """creates a table or index given an appropriate schema object."""
+ """Create a table or index given an appropriate schema object."""
+
return self.__engine.create(entity, connection=self, **kwargs)
+
def drop(self, entity, **kwargs):
- """drops a table or index given an appropriate schema object."""
+ """Drop a table or index given an appropriate schema object."""
+
return self.__engine.drop(entity, connection=self, **kwargs)
+
def reflecttable(self, table, **kwargs):
- """reflects the columns in the given table from the database."""
+ """Reflect the columns in the given table from the database."""
+
return self.__engine.reflecttable(table, connection=self, **kwargs)
+
def default_schema_name(self):
return self.__engine.dialect.get_default_schema_name(self)
+
def run_callable(self, callable_):
return callable_(self)
+
def _execute_raw(self, statement, parameters=None, cursor=None, context=None, **kwargs):
if cursor is None:
cursor = self.__engine.dialect.create_cursor(self.connection)
@@ -367,6 +557,7 @@ class Connection(Connectable):
if self.__close_with_result:
self.close()
raise exceptions.SQLError(statement, parameters, e)
+
def _executemany(self, c, statement, parameters, context=None):
try:
self.__engine.dialect.do_executemany(c, statement, parameters, context=context)
@@ -376,27 +567,36 @@ class Connection(Connectable):
if self.__close_with_result:
self.close()
raise exceptions.SQLError(statement, parameters, e)
+
def proxy(self, statement=None, parameters=None):
- """executes the given statement string and parameter object.
- the parameter object is expected to be the result of a call to compiled.get_params().
- This callable is a generic version of a connection/cursor-specific callable that
- is produced within the execute_compiled method, and is used for objects that require
- this style of proxy when outside of an execute_compiled method, primarily the DefaultRunner."""
+ """Execute the given statement string and parameter object.
+
+ The parameter object is expected to be the result of a call to
+ ``compiled.get_params()``. This callable is a generic version
+ of a connection/cursor-specific callable that is produced
+ within the execute_compiled method, and is used for objects
+ that require this style of proxy when outside of an
+ execute_compiled method, primarily the DefaultRunner.
+ """
parameters = self.__engine.dialect.convert_compiled_params(parameters)
return self._execute_raw(statement, parameters)
class Transaction(object):
- """represents a Transaction in progress.
-
- the Transaction object is **not** threadsafe."""
+ """Represent a Transaction in progress.
+
+ The Transaction object is **not** threadsafe.
+ """
+
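A sketch of the nesting behavior implied by ``Connection.begin()`` above (``conn`` is an illustrative Connection)::

    outer = conn.begin()
    inner = conn.begin()   # child of `outer`; no new BEGIN is issued
    inner.commit()         # defers to the outermost transaction
    outer.commit()         # only here is the actual COMMIT emitted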
def __init__(self, connection, parent):
self.__connection = connection
self.__parent = parent or self
self.__is_active = True
if self.__parent is self:
self.__connection._begin_impl()
+
connection = property(lambda s:s.__connection, doc="The Connection object referenced by this Transaction")
is_active = property(lambda s:s.__is_active)
+
def rollback(self):
if not self.__parent.__is_active:
return
@@ -405,6 +605,7 @@ class Transaction(object):
self.__is_active = False
else:
self.__parent.rollback()
+
def commit(self):
if not self.__parent.__is_active:
raise exceptions.InvalidRequestError("This transaction is inactive")
@@ -414,9 +615,10 @@ class Transaction(object):
class Engine(sql.Executor, Connectable):
"""
- Connects a ConnectionProvider, a Dialect and a CompilerFactory together to
+ Connects a ConnectionProvider, a Dialect and a CompilerFactory together to
provide a default implementation of SchemaEngine.
"""
+
def __init__(self, connection_provider, dialect, echo=None):
self.connection_provider = connection_provider
self.dialect=dialect
@@ -426,27 +628,35 @@ class Engine(sql.Executor, Connectable):
name = property(lambda s:sys.modules[s.dialect.__module__].descriptor()['name'])
engine = property(lambda s:s)
echo = logging.echo_property()
-
+
def dispose(self):
self.connection_provider.dispose()
+
def create(self, entity, connection=None, **kwargs):
- """creates a table or index within this engine's database connection given a schema.Table object."""
+ """Create a table or index within this engine's database connection given a schema.Table object."""
+
self._run_visitor(self.dialect.schemagenerator, entity, connection=connection, **kwargs)
+
def drop(self, entity, connection=None, **kwargs):
- """drops a table or index within this engine's database connection given a schema.Table object."""
+ """Drop a table or index within this engine's database connection given a schema.Table object."""
+
self._run_visitor(self.dialect.schemadropper, entity, connection=connection, **kwargs)
+
def execute_default(self, default, **kwargs):
connection = self.contextual_connect()
try:
return connection.execute_default(default, **kwargs)
finally:
connection.close()
-
+
def _func(self):
return sql._FunctionGenerator(self)
+
func = property(_func)
+
def text(self, text, *args, **kwargs):
- """returns a sql.text() object for performing literal queries."""
+ """Return a sql.text() object for performing literal queries."""
+
return sql.text(text, engine=self, *args, **kwargs)
def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
@@ -459,12 +669,16 @@ class Engine(sql.Executor, Connectable):
finally:
if connection is None:
conn.close()
-
+
def transaction(self, callable_, connection=None, *args, **kwargs):
- """executes the given function within a transaction boundary. this is a shortcut for
- explicitly calling begin() and commit() and optionally rollback() when execptions are raised.
- The given *args and **kwargs will be passed to the function, as well as the Connection used
- in the transaction."""
+ """Execute the given function within a transaction boundary.
+
+ This is a shortcut for explicitly calling `begin()` and `commit()`
+ and optionally `rollback()` when exceptions are raised. The
+ given `*args` and `**kwargs` will be passed to the function, as
+ well as the Connection used in the transaction.
+ """
+
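A usage sketch, assuming an illustrative ``users`` table; the callable receives the Connection as its first argument::

    def do_insert(conn, name):
        conn.execute(users.insert(), user_name=name)

    engine.transaction(do_insert, name='ed')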
if connection is None:
conn = self.contextual_connect()
else:
@@ -481,7 +695,7 @@ class Engine(sql.Executor, Connectable):
finally:
if connection is None:
conn.close()
-
+
def run_callable(self, callable_, connection=None, *args, **kwargs):
if connection is None:
conn = self.contextual_connect()
@@ -492,32 +706,37 @@ class Engine(sql.Executor, Connectable):
finally:
if connection is None:
conn.close()
-
+
def execute(self, statement, *multiparams, **params):
connection = self.contextual_connect(close_with_result=True)
return connection.execute(statement, *multiparams, **params)
def scalar(self, statement, *multiparams, **params):
return self.execute(statement, *multiparams, **params).scalar()
-
+
def execute_compiled(self, compiled, *multiparams, **params):
connection = self.contextual_connect(close_with_result=True)
return connection.execute_compiled(compiled, *multiparams, **params)
-
+
def compiler(self, statement, parameters, **kwargs):
return self.dialect.compiler(statement, parameters, engine=self, **kwargs)
def connect(self, **kwargs):
- """returns a newly allocated Connection object."""
+ """Return a newly allocated Connection object."""
+
return Connection(self, **kwargs)
-
+
def contextual_connect(self, close_with_result=False, **kwargs):
- """returns a Connection object which may be newly allocated, or may be part of some
- ongoing context. This Connection is meant to be used by the various "auto-connecting" operations."""
+ """Return a Connection object which may be newly allocated, or may be part of some ongoing context.
+
+ This Connection is meant to be used by the various "auto-connecting" operations.
+ """
+
return Connection(self, close_with_result=close_with_result, **kwargs)
-
+
def reflecttable(self, table, connection=None):
- """given a Table object, reflects its columns and properties from the database."""
+ """Given a Table object, reflects its columns and properties from the database."""
+
if connection is None:
conn = self.contextual_connect()
else:
@@ -527,34 +746,42 @@ class Engine(sql.Executor, Connectable):
finally:
if connection is None:
conn.close()
+
def has_table(self, table_name, schema=None):
return self.run_callable(lambda c: self.dialect.has_table(c, table_name, schema=schema))
-
+
def raw_connection(self):
- """returns a DBAPI connection."""
+ """Return a DBAPI connection."""
+
return self.connection_provider.get_connection()
def log(self, msg):
- """logs a message using this SQLEngine's logger stream."""
+ """Log a message using this SQLEngine's logger stream."""
+
self.logger.info(msg)
class ResultProxy(object):
- """wraps a DBAPI cursor object to provide access to row columns based on integer
- position, case-insensitive column name, or by schema.Column object. e.g.:
-
- row = fetchone()
+ """Wraps a DBAPI cursor object to provide easier access to row columns.
+
+ Individual columns may be accessed by their integer position,
+ case-insensitive column name, or by ``schema.Column``
+ object. e.g.::
+
+ row = fetchone()
- col1 = row[0] # access via integer position
+ col1 = row[0] # access via integer position
- col2 = row['col2'] # access via name
+ col2 = row['col2'] # access via name
- col3 = row[mytable.c.mycol] # access via Column object.
-
- ResultProxy also contains a map of TypeEngine objects and will invoke the appropriate
- convert_result_value() method before returning columns, as well as the ExecutionContext
- corresponding to the statement execution. It provides several methods for which
+ col3 = row[mytable.c.mycol] # access via Column object.
+
+ ResultProxy also contains a map of TypeEngine objects and will
+ invoke the appropriate ``convert_result_value()` method before
+ returning columns, as well as the ExecutionContext corresponding
+ to the statement execution. It provides several methods for which
to obtain information from the underlying ExecutionContext.
"""
+
class AmbiguousColumn(object):
def __init__(self, key):
self.key = key
@@ -562,15 +789,16 @@ class ResultProxy(object):
return self
def convert_result_value(self, arg, engine):
raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
-
+
def __new__(cls, *args, **kwargs):
if cls is ResultProxy and kwargs.has_key('should_prefetch') and kwargs['should_prefetch']:
return PrefetchingResultProxy(*args, **kwargs)
else:
return object.__new__(cls, *args, **kwargs)
-
+
def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, columns=None, should_prefetch=None):
"""ResultProxy objects are constructed via the execute() method on SQLEngine."""
+
self.connection = connection
self.dialect = engine.dialect
self.cursor = cursor
@@ -603,20 +831,25 @@ class ResultProxy(object):
self.keys.append(colname)
self.props[i] = rec
i+=1
+
def _executioncontext(self):
try:
return self.__executioncontext
except AttributeError:
raise exceptions.InvalidRequestError("This ResultProxy does not have an execution context with which to complete this operation. Execution contexts are not generated for literal SQL execution.")
executioncontext = property(_executioncontext)
-
+
def close(self):
- """close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
-
- If this ResultProxy was generated from an implicit execution, the underlying Connection will
- also be closed (returns the underlying DBAPI connection to the connection pool.)
-
- This method is also called automatically when all result rows are exhausted."""
+ """Close this ResultProxy, and the underlying DBAPI cursor corresponding to the execution.
+
+ If this ResultProxy was generated from an implicit execution,
+ the underlying Connection will also be closed (returns the
+ underlying DBAPI connection to the connection pool.)
+
+ This method is also called automatically when all result rows
+ are exhausted.
+ """
+
if not self.closed:
self.closed = True
self.cursor.close()
@@ -624,8 +857,13 @@ class ResultProxy(object):
self.connection.close()
def _convert_key(self, key):
- """given a key, which could be a ColumnElement, string, etc., matches it to the
- appropriate key we got from the result set's metadata; then cache it locally for quick re-access."""
+ """Convert and cache a key.
+
+ Given a key, which could be a ColumnElement, string, etc.,
+ matches it to the appropriate key we got from the result set's
+ metadata; then cache it locally for quick re-access.
+ """
+
try:
return self.__key_cache[key]
except KeyError:
@@ -659,18 +897,18 @@ class ResultProxy(object):
raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % str(key))
self.__key_cache[key] = rec
return rec
-
+
def _has_key(self, row, key):
try:
self._convert_key(key)
return True
except KeyError:
return False
-
+
def _get_col(self, row, key):
rec = self._convert_key(key)
return rec[0].dialect_impl(self.dialect).convert_result_value(row[rec[1]], self.dialect)
-
+
def __iter__(self):
while True:
row = self.fetchone()
@@ -678,43 +916,59 @@ class ResultProxy(object):
raise StopIteration
else:
yield row
-
+
def last_inserted_ids(self):
- """return last_inserted_ids() from the underlying ExecutionContext.
-
- See ExecutionContext for details."""
+ """Return ``last_inserted_ids()`` from the underlying ExecutionContext.
+
+ See ExecutionContext for details.
+ """
+
return self.executioncontext.last_inserted_ids()
+
def last_updated_params(self):
- """return last_updated_params() from the underlying ExecutionContext.
-
- See ExecutionContext for details."""
+ """Return ``last_updated_params()`` from the underlying ExecutionContext.
+
+ See ExecutionContext for details.
+ """
+
return self.executioncontext.last_updated_params()
+
def last_inserted_params(self):
- """return last_inserted_params() from the underlying ExecutionContext.
-
- See ExecutionContext for details."""
+ """Return ``last_inserted_params()`` from the underlying ExecutionContext.
+
+ See ExecutionContext for details.
+ """
+
return self.executioncontext.last_inserted_params()
+
def lastrow_has_defaults(self):
- """return lastrow_has_defaults() from the underlying ExecutionContext.
-
- See ExecutionContext for details."""
+ """Return ``lastrow_has_defaults()`` from the underlying ExecutionContext.
+
+ See ExecutionContext for details.
+ """
+
return self.executioncontext.lastrow_has_defaults()
+
def supports_sane_rowcount(self):
- """return supports_sane_rowcount() from the underlying ExecutionContext.
-
- See ExecutionContext for details."""
+ """Return ``supports_sane_rowcount()`` from the underlying ExecutionContext.
+
+ See ExecutionContext for details.
+ """
+
return self.executioncontext.supports_sane_rowcount()
-
+
def fetchall(self):
- """fetch all rows, just like DBAPI cursor.fetchall()."""
+ """Fetch all rows, just like DBAPI ``cursor.fetchall()``."""
+
l = []
for row in self.cursor.fetchall():
l.append(RowProxy(self, row))
self.close()
return l
-
+
def fetchmany(self, size=None):
- """fetch many rows, juts like DBAPI cursor.fetchmany(size=cursor.arraysize)"""
+ """Fetch many rows, just like DBAPI ``cursor.fetchmany(size=cursor.arraysize)``."""
+
if size is None:
rows = self.cursor.fetchmany()
else:
@@ -725,9 +979,10 @@ class ResultProxy(object):
if len(l) == 0:
self.close()
return l
-
+
def fetchone(self):
- """fetch one row, just like DBAPI cursor.fetchone()."""
+ """Fetch one row, just like DBAPI ``cursor.fetchone()``."""
+
row = self.cursor.fetchone()
if row is not None:
return RowProxy(self, row)
@@ -736,7 +991,8 @@ class ResultProxy(object):
return None
def scalar(self):
- """fetch the first column of the first row, and close the result set."""
+ """Fetch the first column of the first row, and close the result set."""
+
row = self.cursor.fetchone()
try:
if row is not None:
@@ -745,16 +1001,17 @@ class ResultProxy(object):
return None
finally:
self.close()
-
+
class PrefetchingResultProxy(ResultProxy):
"""ResultProxy that loads all columns into memory each time fetchone() is
called. If fetchmany() or fetchall() are called, the full grid of results
is fetched.
"""
+
def _get_col(self, row, key):
rec = self._convert_key(key)
return row[rec[1]]
-
+
def fetchall(self):
l = []
while True:
@@ -764,7 +1021,7 @@ class PrefetchingResultProxy(ResultProxy):
else:
break
return l
-
+
def fetchmany(self, size=None):
if size is None:
return self.fetchall()
@@ -776,7 +1033,7 @@ class PrefetchingResultProxy(ResultProxy):
else:
break
return l
-
+
def fetchone(self):
sup = super(PrefetchingResultProxy, self)
row = self.cursor.fetchone()
@@ -786,81 +1043,114 @@ class PrefetchingResultProxy(ResultProxy):
else:
self.close()
return None
-
+
class RowProxy(object):
- """proxies a single cursor row for a parent ResultProxy. Mostly follows
- "ordered dictionary" behavior, mapping result values to the string-based column name,
- the integer position of the result in the row, as well as Column instances which
- can be mapped to the original Columns that produced this result set (for results
- that correspond to constructed SQL expressions)."""
+ """Proxie a single cursor row for a parent ResultProxy.
+
+ Mostly follows "ordered dictionary" behavior, mapping result
+ values to the string-based column name, the integer position of
+ the result in the row, as well as Column instances which can be
+ mapped to the original Columns that produced this result set (for
+ results that correspond to constructed SQL expressions).
+ """
+
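A sketch of the access styles described above (``users`` is an illustrative table)::

    row = engine.execute(users.select()).fetchone()
    row[0]                  # by integer position
    row['user_name']        # by case-insensitive string name
    row[users.c.user_name]  # by Column object
    row.user_name           # attribute access, via __getattr__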
def __init__(self, parent, row):
"""RowProxy objects are constructed by ResultProxy objects."""
+
self.__parent = parent
self.__row = row
if self.__parent._ResultProxy__echo:
self.__parent.engine.logger.debug("Row " + repr(row))
+
def close(self):
- """close the parent ResultProxy."""
+ """Close the parent ResultProxy."""
+
self.__parent.close()
+
def __iter__(self):
for i in range(0, len(self.__row)):
yield self.__parent._get_col(self.__row, i)
+
def __eq__(self, other):
return (other is self) or (other == tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))]))
+
def __repr__(self):
return repr(tuple([self.__parent._get_col(self.__row, key) for key in range(0, len(self.__row))]))
+
def has_key(self, key):
- """return True if this RowProxy contains the given key."""
+ """Return True if this RowProxy contains the given key."""
+
return self.__parent._has_key(self.__row, key)
+
def __getitem__(self, key):
return self.__parent._get_col(self.__row, key)
+
def __getattr__(self, name):
try:
return self.__parent._get_col(self.__row, name)
except KeyError, e:
raise AttributeError(e.args[0])
+
def items(self):
- """return a list of tuples, each tuple containing a key/value pair."""
+ """Return a list of tuples, each tuple containing a key/value pair."""
+
return [(key, getattr(self, key)) for key in self.keys()]
+
def keys(self):
- """return the list of keys as strings represented by this RowProxy."""
+ """Return the list of keys as strings represented by this RowProxy."""
+
return self.__parent.keys
+
def values(self):
- """return the values represented by this RowProxy as a list."""
+ """Return the values represented by this RowProxy as a list."""
+
return list(self)
- def __len__(self):
+
+ def __len__(self):
return len(self.__row)
class SchemaIterator(schema.SchemaVisitor):
- """a visitor that can gather text into a buffer and execute the contents of the buffer."""
+ """A visitor that can gather text into a buffer and execute the contents of the buffer."""
+
def __init__(self, engine, proxy, **params):
- """construct a new SchemaIterator.
-
- engine - the Engine used by this SchemaIterator
-
- proxy - a callable which takes a statement and bind parameters and executes it, returning
- the cursor (the actual DBAPI cursor). The callable should use the same cursor repeatedly."""
+ """Construct a new SchemaIterator.
+
+ engine
+ the Engine used by this SchemaIterator
+
+ proxy
+ a callable which takes a statement and bind parameters and
+ executes it, returning the cursor (the actual DBAPI cursor).
+ The callable should use the same cursor repeatedly.
+ """
+
self.proxy = proxy
self.engine = engine
self.buffer = StringIO.StringIO()
def append(self, s):
- """append content to the SchemaIterator's query buffer."""
+ """Append content to the SchemaIterator's query buffer."""
+
self.buffer.write(s)
def execute(self):
- """execute the contents of the SchemaIterator's buffer."""
+ """Execute the contents of the SchemaIterator's buffer."""
+
try:
return self.proxy(self.buffer.getvalue(), None)
finally:
self.buffer.truncate(0)
class DefaultRunner(schema.SchemaVisitor):
- """a visitor which accepts ColumnDefault objects, produces the dialect-specific SQL corresponding
- to their execution, and executes the SQL, returning the result value.
-
- DefaultRunners are used internally by Engines and Dialects. Specific database modules should provide
- their own subclasses of DefaultRunner to allow database-specific behavior."""
+ """A visitor which accepts ColumnDefault objects, produces the
+ dialect-specific SQL corresponding to their execution, and
+ executes the SQL, returning the result value.
+
+ DefaultRunners are used internally by Engines and Dialects.
+ Specific database modules should provide their own subclasses of
+ DefaultRunner to allow database-specific behavior.
+ """
+
def __init__(self, engine, proxy):
self.proxy = proxy
self.engine = engine
@@ -878,12 +1168,20 @@ class DefaultRunner(schema.SchemaVisitor):
return None
def visit_passive_default(self, default):
- """passive defaults by definition return None on the app side,
- and are post-fetched to get the DB-side value"""
+ """Do nothing.
+
+ Passive defaults by definition return None on the app side,
+ and are post-fetched to get the DB-side value.
+ """
+
return None
def visit_sequence(self, seq):
- """sequences are not supported by default"""
+ """Do nothing.
+
+ Sequences are not supported by default.
+ """
+
return None
def exec_default_sql(self, default):
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 40210e88f6..ef0a6cc57b 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -9,21 +9,23 @@ from sqlalchemy import schema, exceptions, util, sql, types
import StringIO, sys, re
from sqlalchemy.engine import base
-"""provides default implementations of the engine interfaces"""
-
+"""Provide default implementations of the engine interfaces"""
class PoolConnectionProvider(base.ConnectionProvider):
def __init__(self, pool):
self._pool = pool
+
def get_connection(self):
return self._pool.connect()
+
def dispose(self):
self._pool.dispose()
if hasattr(self, '_dbproxy'):
self._dbproxy.dispose()
-
+
class DefaultDialect(base.Dialect):
- """default implementation of Dialect"""
+ """Default implementation of Dialect"""
+
def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', **kwargs):
self.convert_unicode = convert_unicode
self.supports_autoclose_results = True
@@ -31,52 +33,75 @@ class DefaultDialect(base.Dialect):
self.positional = False
self._ischema = None
self._figure_paramstyle(default=default_paramstyle)
+
def create_execution_context(self):
return DefaultExecutionContext(self)
+
def type_descriptor(self, typeobj):
- """provides a database-specific TypeEngine object, given the generic object
- which comes from the types module. Subclasses will usually use the adapt_type()
- method in the types module to make this job easy."""
+ """Provide a database-specific ``TypeEngine`` object, given
+ the generic object which comes from the types module.
+
+ Subclasses will usually use the ``adapt_type()`` method in the
+ types module to make this job easy.
+ """
+
if type(typeobj) is type:
typeobj = typeobj()
return typeobj
+
def oid_column_name(self, column):
return None
+
def supports_sane_rowcount(self):
return True
+
def do_begin(self, connection):
- """implementations might want to put logic here for turning autocommit on/off,
- etc."""
+ """Implementations might want to put logic here for turning
+ autocommit on/off, etc.
+ """
+
pass
+
def do_rollback(self, connection):
- """implementations might want to put logic here for turning autocommit on/off,
- etc."""
+ """Implementations might want to put logic here for turning
+ autocommit on/off, etc.
+ """
+
#print "ENGINE ROLLBACK ON ", connection.connection
connection.rollback()
+
def do_commit(self, connection):
- """implementations might want to put logic here for turning autocommit on/off, etc."""
+ """Implementations might want to put logic here for turning
+ autocommit on/off, etc.
+ """
+
#print "ENGINE COMMIT ON ", connection.connection
connection.commit()
+
def do_executemany(self, cursor, statement, parameters, **kwargs):
cursor.executemany(statement, parameters)
+
def do_execute(self, cursor, statement, parameters, **kwargs):
cursor.execute(statement, parameters)
+
def defaultrunner(self, engine, proxy):
return base.DefaultRunner(engine, proxy)
+
def create_cursor(self, connection):
return connection.cursor()
+
def create_result_proxy_args(self, connection, cursor):
return dict(should_prefetch=False)
-
+
def _set_paramstyle(self, style):
self._paramstyle = style
self._figure_paramstyle(style)
+
paramstyle = property(lambda s:s._paramstyle, _set_paramstyle)
def convert_compiled_params(self, parameters):
executemany = parameters is not None and isinstance(parameters, list)
# the bind params are a CompiledParams object. but all the DBAPI's hate
- # that object (or similar). so convert it to a clean
+ # that object (or similar). so convert it to a clean
# dictionary/list/tuple of dictionary/tuple of list
if parameters is not None:
if self.positional:
@@ -125,29 +150,40 @@ class DefaultDialect(base.Dialect):
class DefaultExecutionContext(base.ExecutionContext):
def __init__(self, dialect):
self.dialect = dialect
+
def pre_exec(self, engine, proxy, compiled, parameters):
self._process_defaults(engine, proxy, compiled, parameters)
+
def post_exec(self, engine, proxy, compiled, parameters):
pass
+
def get_rowcount(self, cursor):
if hasattr(self, '_rowcount'):
return self._rowcount
else:
return cursor.rowcount
+
def supports_sane_rowcount(self):
return self.dialect.supports_sane_rowcount()
+
def last_inserted_ids(self):
return self._last_inserted_ids
+
def last_inserted_params(self):
return self._last_inserted_params
+
def last_updated_params(self):
- return self._last_updated_params
+ return self._last_updated_params
+
def lastrow_has_defaults(self):
return self._lastrow_has_defaults
+
def set_input_sizes(self, cursor, parameters):
- """given a cursor and ClauseParameters, call the appropriate style of
- setinputsizes() on the cursor, using DBAPI types from the bind parameter's
- TypeEngine objects."""
+ """Given a cursor and ClauseParameters, call the appropriate
+ style of ``setinputsizes()`` on the cursor, using DBAPI types
+ from the bind parameter's ``TypeEngine`` objects.
+ """
+
if isinstance(parameters, list):
plist = parameters
else:
@@ -166,19 +202,27 @@ class DefaultExecutionContext(base.ExecutionContext):
typeengine = params.binds[key].type
inputsizes[key] = typeengine.get_dbapi_type(self.dialect.module)
cursor.setinputsizes(**inputsizes)
-
+
def _process_defaults(self, engine, proxy, compiled, parameters):
- """INSERT and UPDATE statements, when compiled, may have additional columns added to their
- VALUES and SET lists corresponding to column defaults/onupdates that are present on the
- Table object (i.e. ColumnDefault, Sequence, PassiveDefault). This method pre-execs those
- DefaultGenerator objects that require pre-execution and sets their values within the
- parameter list, and flags the thread-local state about
- PassiveDefault objects that may require post-fetching the row after it is inserted/updated.
- This method relies upon logic within the ANSISQLCompiler in its visit_insert and
- visit_update methods that add the appropriate column clauses to the statement when its
- being compiled, so that these parameters can be bound to the statement."""
+ """``INSERT`` and ``UPDATE`` statements, when compiled, may
+ have additional columns added to their ``VALUES`` and ``SET``
+ lists corresponding to column defaults/onupdates that are
+ present on the ``Table`` object (i.e. ``ColumnDefault``,
+ ``Sequence``, ``PassiveDefault``). This method pre-execs
+ those ``DefaultGenerator`` objects that require pre-execution
+ and sets their values within the parameter list, and flags the
+ thread-local state about ``PassiveDefault`` objects that may
+ require post-fetching the row after it is inserted/updated.
+
+ This method relies upon logic within the ``ANSISQLCompiler``
+ in its `visit_insert` and `visit_update` methods that add the
+ appropriate column clauses to the statement when its being
+ compiled, so that these parameters can be bound to the
+ statement.
+ """
+
if compiled is None: return
-
+
if getattr(compiled, "isinsert", False):
if isinstance(parameters, list):
plist = parameters
@@ -198,9 +242,9 @@ class DefaultExecutionContext(base.ExecutionContext):
if c.primary_key:
need_lastrowid = True
# check if its not present at all. see if theres a default
- # and fire it off, and add to bind parameters. if
+ # and fire it off, and add to bind parameters. if
# its a pk, add the value to our last_inserted_ids list,
- # or, if its a SQL-side default, dont do any of that, but we'll need
+ # or, if its a SQL-side default, dont do any of that, but we'll need
# the SQL-generated value after execution.
elif not param.has_key(c.key) or param[c.key] is None:
if isinstance(c.default, schema.PassiveDefault):
@@ -212,7 +256,7 @@ class DefaultExecutionContext(base.ExecutionContext):
last_inserted_ids.append(param[c.key])
elif c.primary_key:
need_lastrowid = True
- # its an explicitly passed pk value - add it to
+ # its an explicitly passed pk value - add it to
# our last_inserted_ids list.
elif c.primary_key:
last_inserted_ids.append(param[c.key])
@@ -229,7 +273,7 @@ class DefaultExecutionContext(base.ExecutionContext):
drunner = self.dialect.defaultrunner(engine, proxy)
self._lastrow_has_defaults = False
for param in plist:
- # check the "onupdate" status of each column in the table
+ # check the "onupdate" status of each column in the table
for c in compiled.statement.table.c:
# it will be populated by a SQL clause - we'll need that
# after execution.
@@ -242,5 +286,3 @@ class DefaultExecutionContext(base.ExecutionContext):
if value is not None:
param[c.key] = value
self._last_updated_params = param
-
-
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
index a7f2cc0036..7a7b84aa99 100644
--- a/lib/sqlalchemy/engine/strategies.py
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -1,8 +1,13 @@
-"""defines different strategies for creating new instances of sql.Engine.
-by default there are two, one which is the "thread-local" strategy, one which is the "plain" strategy.
-new strategies can be added via constructing a new EngineStrategy object which will add itself to the
-list of available strategies here, or replace one of the existing name.
-this can be accomplished via a mod; see the sqlalchemy/mods package for details."""
+"""Define different strategies for creating new instances of sql.Engine.
+
+By default there are two: the "thread-local" strategy and the
+"plain" strategy.
+
+New strategies can be added by constructing a new EngineStrategy
+object, which will add itself to the list of available strategies
+here, or replace one of the existing names. This can be accomplished
+via a mod; see the sqlalchemy/mods package for details.
+"""
from sqlalchemy.engine import base, default, threadlocal, url
@@ -12,22 +17,30 @@ from sqlalchemy import pool as poollib
strategies = {}
class EngineStrategy(object):
- """defines a function that receives input arguments and produces an instance of sql.Engine, typically
- an instance sqlalchemy.engine.base.Engine or a subclass."""
+ """Define a function that receives input arguments and produces an
+ instance of sql.Engine, typically an instance of
+ sqlalchemy.engine.base.Engine or a subclass.
+ """
+
def __init__(self, name):
- """construct a new EngineStrategy object and sets it in the list of available strategies
- under this name."""
+ """Construct a new EngineStrategy object.
+
+ Sets it in the list of available strategies under this name.
+ """
+
self.name = name
strategies[self.name] = self
+
def create(self, *args, **kwargs):
- """given arguments, returns a new sql.Engine instance."""
+ """Given arguments, returns a new sql.Engine instance."""
+
raise NotImplementedError()
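The strategy name is what ``create_engine()`` selects on; a sketch (the URL is illustrative)::

    plain = create_engine('sqlite://', strategy='plain')         # base.Engine
    tlocal = create_engine('sqlite://', strategy='threadlocal')  # threadlocal.TLEngine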
class DefaultEngineStrategy(EngineStrategy):
def create(self, name_or_url, **kwargs):
# create url.URL object
u = url.make_url(name_or_url)
-
+
# get module from sqlalchemy.databases
module = u.get_module()
@@ -36,7 +49,7 @@ class DefaultEngineStrategy(EngineStrategy):
for k in util.get_cls_kwargs(module.dialect):
if k in kwargs:
dialect_args[k] = kwargs.pop(k)
-
+
# create dialect
dialect = module.dialect(**dialect_args)
@@ -50,6 +63,7 @@ class DefaultEngineStrategy(EngineStrategy):
dbapi = kwargs.pop('module', dialect.dbapi())
if dbapi is None:
raise exceptions.InvalidRequestError("Cant get DBAPI module for dialect '%s'" % dialect)
+
def connect():
try:
return dbapi.connect(*cargs, **cparams)
@@ -80,41 +94,48 @@ class DefaultEngineStrategy(EngineStrategy):
for k in util.get_cls_kwargs(engineclass):
if k in kwargs:
engine_args[k] = kwargs.pop(k)
-
+
# all kwargs should be consumed
if len(kwargs):
raise TypeError("Invalid argument(s) %s sent to create_engine(), using configuration %s/%s/%s. Please check that the keyword arguments are appropriate for this combination of components." % (','.join(["'%s'" % k for k in kwargs]), dialect.__class__.__name__, pool.__class__.__name__, engineclass.__name__))
-
+
return engineclass(provider, dialect, **engine_args)
def pool_threadlocal(self):
raise NotImplementedError()
+
def get_pool_provider(self, pool):
raise NotImplementedError()
+
def get_engine_cls(self):
raise NotImplementedError()
-
+
class PlainEngineStrategy(DefaultEngineStrategy):
def __init__(self):
DefaultEngineStrategy.__init__(self, 'plain')
+
def pool_threadlocal(self):
return False
+
def get_pool_provider(self, pool):
return default.PoolConnectionProvider(pool)
+
def get_engine_cls(self):
return base.Engine
+
PlainEngineStrategy()
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
def __init__(self):
DefaultEngineStrategy.__init__(self, 'threadlocal')
+
def pool_threadlocal(self):
return True
+
def get_pool_provider(self, pool):
return threadlocal.TLocalConnectionProvider(pool)
+
def get_engine_cls(self):
return threadlocal.TLEngine
-ThreadLocalEngineStrategy()
-
-
+ThreadLocalEngineStrategy()
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py
index beac3ee3f9..2bbb1ed43a 100644
--- a/lib/sqlalchemy/engine/threadlocal.py
+++ b/lib/sqlalchemy/engine/threadlocal.py
@@ -2,19 +2,24 @@ from sqlalchemy import schema, exceptions, util, sql, types
import StringIO, sys, re
from sqlalchemy.engine import base, default
-"""provides a thread-local transactional wrapper around the basic ComposedSQLEngine. multiple calls to engine.connect()
-will return the same connection for the same thread. also provides begin/commit methods on the engine itself
-which correspond to a thread-local transaction."""
+"""Provide a thread-local transactional wrapper around the basic ComposedSQLEngine.
+
+Multiple calls to engine.connect() will return the same connection for
+the same thread. Also provides begin/commit methods on the engine
+itself which correspond to a thread-local transaction.
+"""
class TLSession(object):
def __init__(self, engine):
self.engine = engine
self.__tcount = 0
+
def get_connection(self, close_with_result=False):
try:
return self.__transaction._increment_connect()
except AttributeError:
return TLConnection(self, close_with_result=close_with_result)
+
def reset(self):
try:
self.__transaction._force_close()
@@ -23,20 +28,24 @@ class TLSession(object):
except AttributeError:
pass
self.__tcount = 0
+
def in_transaction(self):
return self.__tcount > 0
+
def begin(self):
if self.__tcount == 0:
self.__transaction = self.get_connection()
self.__trans = self.__transaction._begin()
self.__tcount += 1
return self.__trans
+
def rollback(self):
if self.__tcount > 0:
try:
self.__trans._rollback_impl()
finally:
self.reset()
+
def commit(self):
if self.__tcount == 1:
try:
@@ -45,6 +54,7 @@ class TLSession(object):
self.reset()
elif self.__tcount > 1:
self.__tcount -= 1
+
def is_begun(self):
return self.__tcount > 0
@@ -53,67 +63,96 @@ class TLConnection(base.Connection):
base.Connection.__init__(self, session.engine, close_with_result=close_with_result)
self.__session = session
self.__opencount = 1
+
session = property(lambda s:s.__session)
+
def _increment_connect(self):
self.__opencount += 1
return self
+
def _create_transaction(self, parent):
return TLTransaction(self, parent)
+
def _begin(self):
return base.Connection.begin(self)
+
def in_transaction(self):
return self.session.in_transaction()
+
def begin(self):
return self.session.begin()
+
def close(self):
if self.__opencount == 1:
base.Connection.close(self)
self.__opencount -= 1
+
def _force_close(self):
self.__opencount = 0
base.Connection.close(self)
-
+
class TLTransaction(base.Transaction):
def _commit_impl(self):
base.Transaction.commit(self)
+
def _rollback_impl(self):
base.Transaction.rollback(self)
+
def commit(self):
self.connection.session.commit()
+
def rollback(self):
self.connection.session.rollback()
-
+
class TLEngine(base.Engine):
- """an Engine that includes support for thread-local managed transactions. This engine
- is better suited to be used with threadlocal Pool object."""
+ """An Engine that includes support for thread-local managed transactions.
+
+ This engine is better suited to be used with a threadlocal Pool
+ object.
+ """
+
def __init__(self, *args, **kwargs):
- """the TLEngine relies upon the ConnectionProvider having "threadlocal" behavior,
- so that once a connection is checked out for the current thread, you get that same connection
- repeatedly."""
+ """The TLEngine relies upon the ConnectionProvider having
+ "threadlocal" behavior, so that once a connection is checked out
+ for the current thread, you get that same connection
+ repeatedly.
+ """
+
super(TLEngine, self).__init__(*args, **kwargs)
self.context = util.ThreadLocal()
+
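A sketch of the thread-local transaction pattern (``users`` is an illustrative table)::

    engine = create_engine('sqlite://', strategy='threadlocal')
    engine.begin()
    try:
        engine.execute(users.insert(), user_name='ed')
        engine.commit()
    except:
        engine.rollback()
        raise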
def raw_connection(self):
- """returns a DBAPI connection."""
+ """Return a DBAPI connection."""
+
return self.connection_provider.get_connection()
+
def connect(self, **kwargs):
- """returns a Connection that is not thread-locally scoped. this is the equilvalent to calling
- "connect()" on a ComposedSQLEngine."""
+ """Return a Connection that is not thread-locally scoped.
+
+ This is the equivalent of calling ``connect()`` on a
+ ComposedSQLEngine.
+ """
+
return base.Connection(self, self.connection_provider.unique_connection())
def _session(self):
if not hasattr(self.context, 'session'):
self.context.session = TLSession(self)
return self.context.session
+
session = property(_session, doc="returns the current thread's TLSession")
def contextual_connect(self, **kwargs):
- """returns a TLConnection which is thread-locally scoped."""
+ """Return a TLConnection which is thread-locally scoped."""
+
return self.session.get_connection(**kwargs)
-
+
def begin(self):
return self.session.begin()
+
def commit(self):
self.session.commit()
+
def rollback(self):
self.session.rollback()
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
index 2345a399b9..6a0180d620 100644
--- a/lib/sqlalchemy/engine/url.py
+++ b/lib/sqlalchemy/engine/url.py
@@ -3,30 +3,40 @@ import cgi
import urllib
from sqlalchemy import exceptions
-"""provides the URL object as well as the make_url parsing function."""
+"""Provide the URL object as well as the make_url parsing function."""
class URL(object):
- """represents the components of a URL used to connect to a database.
-
- This object is suitable to be passed directly to a create_engine() call.
- The fields of the URL are parsed from a string by the module-level make_url() function.
- the string format of the URL is an RFC-1738-style string.
-
+ """Represent the components of a URL used to connect to a database.
+
+ This object is suitable to be passed directly to a ``create_engine()``
+ call. The fields of the URL are parsed from a string by the
+ module-level ``make_url()`` function. The string format of the URL
+ is an RFC-1738-style string.
+
Attributes on URL include:
-
+
drivername
-
+ The name of the database backend.
+
username
-
+ The user name for the connection.
+
password
-
+ The password for the connection.
+
host
-
+ The name of the host.
+
port
-
+ The port number.
+
database
-
- query - a dictionary containing key/value pairs representing the URL's query string."""
+ The database name.
+
+ query
+ A dictionary containing key/value pairs representing the URL's query string.
+ """
+
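A parsing sketch using ``make_url()`` (the DSN is illustrative)::

    u = make_url('postgres://scott:tiger@localhost:5432/test?charset=utf8')
    u.drivername    # 'postgres'
    u.username      # 'scott'
    u.password      # 'tiger'
    u.host, u.port  # ('localhost', 5432)
    u.database      # 'test'
    u.query         # {'charset': 'utf8'}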
def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None, query=None):
self.drivername = drivername
self.username = username
@@ -38,6 +48,7 @@ class URL(object):
self.port = None
self.database= database
self.query = query or {}
+
def __str__(self):
s = self.drivername + "://"
if self.username is not None:
@@ -56,14 +67,21 @@ class URL(object):
keys.sort()
s += '?' + "&".join(["%s=%s" % (k, self.query[k]) for k in keys])
return s
+
def get_module(self):
- """return the SQLAlchemy database module corresponding to this URL's driver name."""
+ """Return the SQLAlchemy database module corresponding to this URL's driver name."""
+
return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
+
def translate_connect_args(self, names):
- """translate this URL's attributes into a dictionary of connection arguments.
-
- given a list of argument names corresponding to the URL attributes ('host', 'database', 'username', 'password', 'port'),
- will assemble the attribute values of this URL into the dictionary using the given names."""
+ """Translate this URL's attributes into a dictionary of connection arguments.
+
+ Given a list of argument names corresponding to the URL
+ attributes (`host`, `database`, `username`, `password`,
+ `port`), will assemble the attribute values of this URL into
+ the dictionary using the given names.
+ """
+
a = {}
attribute_names = ['host', 'database', 'username', 'password', 'port']
for n in names:
@@ -73,18 +91,19 @@ class URL(object):
if getattr(self, sname, None):
a[n] = getattr(self, sname)
return a
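Continuing the sketch above, the given names are matched positionally against (`host`, `database`, `username`, `password`, `port`)::

    u.translate_connect_args(['host', 'db', 'user', 'passwd', 'port'])
    # {'host': 'localhost', 'db': 'test', 'user': 'scott', 'passwd': 'tiger', 'port': 5432}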
-
def make_url(name_or_url):
- """given a string or unicode instance, produces a new URL instance.
-
- the given string is parsed according to the rfc1738 spec.
- if an existing URL object is passed, just returns the object."""
+ """Given a string or unicode instance, produce a new URL instance.
+
+ The given string is parsed according to the RFC-1738 spec. If an
+ existing URL object is passed, just returns the object.
+ """
+
if isinstance(name_or_url, basestring):
return _parse_rfc1738_args(name_or_url)
else:
return name_or_url
-
+
def _parse_rfc1738_args(name):
pattern = re.compile(r'''
(\w+)://
@@ -99,7 +118,7 @@ def _parse_rfc1738_args(name):
(?:/(.*))?
'''
, re.X)
-
+
m = pattern.match(name)
if m is not None:
(name, username, password, host, port, database) = m.group(1, 2, 3, 4, 5, 6)
@@ -124,4 +143,3 @@ def _parse_keyvalue_args(name):
return URL(name, *opts)
else:
return None
-
diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exceptions.py
index 7e3883aec8..e9d7d0c442 100644
--- a/lib/sqlalchemy/exceptions.py
+++ b/lib/sqlalchemy/exceptions.py
@@ -6,59 +6,77 @@
class SQLAlchemyError(Exception):
- """generic error class"""
+ """Generic error class."""
+
pass
-
+
class SQLError(SQLAlchemyError):
- """raised when the execution of a SQL statement fails. includes accessors
- for the underlying exception, as well as the SQL and bind parameters"""
+ """Raised when the execution of a SQL statement fails.
+
+ Includes accessors for the underlying exception, as well as the
+ SQL and bind parameters.
+ """
+
def __init__(self, statement, params, orig):
SQLAlchemyError.__init__(self, "(%s) %s"% (orig.__class__.__name__, str(orig)))
self.statement = statement
self.params = params
self.orig = orig
+
def __str__(self):
return SQLAlchemyError.__str__(self) + " " + repr(self.statement) + " " + repr(self.params)
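A catching sketch (the statement is illustrative)::

    try:
        engine.execute("SELECT * FROM no_such_table")
    except exceptions.SQLError, e:
        print e.statement, e.params   # the failing SQL and bind parameters
        print e.orig                  # the underlying DBAPI exception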
class ArgumentError(SQLAlchemyError):
- """raised for all those conditions where invalid arguments are sent to constructed
- objects. This error generally corresponds to construction time state errors."""
+ """Raised for all those conditions where invalid arguments are
+ sent to constructed objects. This error generally corresponds to
+ construction-time state errors.
+ """
+
pass
class TimeoutError(SQLAlchemyError):
- """raised when a connection pool times out on getting a connection"""
+ """Raised when a connection pool times out on getting a connection."""
+
pass
class ConcurrentModificationError(SQLAlchemyError):
- """raised when a concurrent modification condition is detected"""
+ """Raised when a concurrent modification condition is detected."""
+
pass
-
+
class FlushError(SQLAlchemyError):
- """raised when an invalid condition is detected upon a flush()"""
+ """Raised when an invalid condition is detected upon a ``flush()``."""
pass
-
+
class InvalidRequestError(SQLAlchemyError):
- """sqlalchemy was asked to do something it cant do, return nonexistent data, etc.
- This error generally corresponds to runtime state errors."""
+ """SQLAlchemy was asked to do something it can't do, return
+ nonexistent data, etc.
+
+ This error generally corresponds to runtime state errors.
+ """
+
pass
class NoSuchTableError(InvalidRequestError):
- """sqlalchemy was asked to load a table's definition from the database,
- but the table doesn't exist."""
+ """SQLAlchemy was asked to load a table's definition from the
+ database, but the table doesn't exist.
+ """
+
pass
class AssertionError(SQLAlchemyError):
- """corresponds to internal state being detected in an invalid state"""
+ """Corresponds to internal state being detected in an invalid state."""
+
pass
class NoSuchColumnError(KeyError, SQLAlchemyError):
- """raised by RowProxy when a nonexistent column is requested from a row"""
+ """Raised by ``RowProxy`` when a nonexistent column is requested from a row."""
+
pass
-
+
class DBAPIError(SQLAlchemyError):
- """something weird happened with a particular DBAPI version"""
+ """Something weird happened with a particular DBAPI version."""
+
def __init__(self, message, orig):
SQLAlchemyError.__init__(self, "(%s) (%s) %s"% (message, orig.__class__.__name__, str(orig)))
self.orig = orig
-
-
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index c9160ded46..907ef19dcf 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -1,47 +1,60 @@
-"""contains the AssociationProxy class, a Python property object which
-provides transparent proxied access to the endpoint of an association object.
+"""Contain the ``AssociationProxy`` class.
-See the example examples/association/proxied_association.py.
+The ``AssociationProxy`` is a Python property object which provides
+transparent proxied access to the endpoint of an association object.
+
+See the example ``examples/association/proxied_association.py``.
"""
from sqlalchemy.orm import class_mapper
class AssociationProxy(object):
- """a property object that automatically sets up AssociationLists on a parent object."""
+ """A property object that automatically sets up ``AssociationLists`` on a parent object."""
+
def __init__(self, targetcollection, attr, creator=None):
- """create a new association property.
-
- targetcollection - the attribute name which stores the collection of Associations
-
- attr - name of the attribute on the Association in which to get/set target values
-
- creator - optional callable which is used to create a new association object. this
- callable is given a single argument which is an instance of the "proxied" object.
- if creator is not given, the association object is created using the class associated
- with the targetcollection attribute, using its __init__() constructor and setting
- the proxied attribute.
+ """Create a new association property.
+
+ targetcollection
+ The attribute name which stores the collection of Associations.
+
+ attr
+ Name of the attribute on the Association in which to get/set target values.
+
+ creator
+ Optional callable which is used to create a new association
+ object. This callable is given a single argument which is
+ an instance of the *proxied* object. If creator is not
+ given, the association object is created using the class
+ associated with the targetcollection attribute, using its
+ ``__init__()`` constructor and setting the proxied
+ attribute.
"""
self.targetcollection = targetcollection
self.attr = attr
self.creator = creator
+
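A declaration sketch; ``UserKeyword`` is a hypothetical association class, mapped on the parent via a `kw_associations` relation and carrying a `keyword` attribute::

    class User(object):
        keywords = AssociationProxy('kw_associations', 'keyword')

    user.keywords.append(some_keyword)   # a UserKeyword is created behind the scenes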
def __init_deferred(self):
prop = class_mapper(self._owner_class).props[self.targetcollection]
self._cls = prop.mapper.class_
self._uselist = prop.uselist
+
def _get_class(self):
try:
return self._cls
except AttributeError:
self.__init_deferred()
return self._cls
+
def _get_uselist(self):
try:
return self._uselist
except AttributeError:
self.__init_deferred()
return self._uselist
+
cls = property(_get_class)
uselist = property(_get_uselist)
+
def create(self, target, **kw):
if self.creator is not None:
return self.creator(target, **kw)
@@ -49,6 +62,7 @@ class AssociationProxy(object):
assoc = self.cls(**kw)
setattr(assoc, self.attr, target)
return assoc
+
def __get__(self, obj, owner):
self._owner_class = owner
if obj is None:
@@ -63,34 +77,44 @@ class AssociationProxy(object):
return a
else:
return getattr(getattr(obj, self.targetcollection), self.attr)
+
def __set__(self, obj, value):
if self.uselist:
setattr(obj, self.targetcollection, [self.create(x) for x in value])
else:
setattr(obj, self.targetcollection, self.create(value))
+
def __del__(self, obj):
delattr(obj, self.targetcollection)
class _AssociationList(object):
- """generic proxying list which proxies list operations to a different
- list-holding attribute of the parent object, converting Association objects
- to and from a target attribute on each Association object."""
+ """Generic proxying list which proxies list operations to a
+ different list-holding attribute of the parent object, converting
+ Association objects to and from a target attribute on each
+ Association object.
+ """
+
def __init__(self, proxy, parent):
- """create a new AssociationList."""
+ """Create a new ``AssociationList``."""
self.proxy = proxy
self.parent = parent
+
def append(self, item, **kw):
a = self.proxy.create(item, **kw)
getattr(self.parent, self.proxy.targetcollection).append(a)
+
def __iter__(self):
return iter([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
+
def __repr__(self):
return repr([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
+
def __len__(self):
return len(getattr(self.parent, self.proxy.targetcollection))
+
def __getitem__(self, index):
return getattr(getattr(self.parent, self.proxy.targetcollection)[index], self.proxy.attr)
+
def __setitem__(self, index, value):
a = self.proxy.create(value)
getattr(self.parent, self.proxy.targetcollection)[index] = a
-
diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py
index c7e707f8d8..b81702fc45 100644
--- a/lib/sqlalchemy/ext/proxy.py
+++ b/lib/sqlalchemy/ext/proxy.py
@@ -10,26 +10,41 @@ __all__ = ['BaseProxyEngine', 'AutoConnectEngine', 'ProxyEngine']
class BaseProxyEngine(sql.Executor):
"""Basis for all proxy engines."""
-
+
def get_engine(self):
raise NotImplementedError
def set_engine(self, engine):
raise NotImplementedError
-
+
engine = property(lambda s:s.get_engine(), lambda s,e:s.set_engine(e))
-
+
def execute_compiled(self, *args, **kwargs):
- """this method is required to be present as it overrides the execute_compiled present in sql.Engine"""
- return self.get_engine().execute_compiled(*args, **kwargs)
- def compiler(self, *args, **kwargs):
- """this method is required to be present as it overrides the compiler method present in sql.Engine"""
- return self.get_engine().compiler(*args, **kwargs)
+ """Override superclass behaviour.
+
+ This method is required to be present as it overrides the
+ `execute_compiled` present in ``sql.Engine``.
+ """
+
+ return self.get_engine().execute_compiled(*args, **kwargs)
+
+ def compiler(self, *args, **kwargs):
+ """Override superclass behaviour.
+
+ This method is required to be present as it overrides the
+ `compiler` method present in ``sql.Engine``.
+ """
+
+ return self.get_engine().compiler(*args, **kwargs)
def __getattr__(self, attr):
- """provides proxying for methods that are not otherwise present on this BaseProxyEngine. Note
- that methods which are present on the base class sql.Engine will *not* be proxied through this,
- and must be explicit on this class."""
+ """Provide proxying for methods that are not otherwise present on this ``BaseProxyEngine``.
+
+ Note that methods which are present on the base class
+ ``sql.Engine`` will **not** be proxied through this, and must
+ be explicit on this class.
+ """
+
# call get_engine() to give subclasses a chance to change
# connection establishment behavior
e = self.get_engine()
@@ -38,16 +53,15 @@ class BaseProxyEngine(sql.Executor):
raise AttributeError("No connection established in ProxyEngine: "
" no access to %s" % attr)
-
class AutoConnectEngine(BaseProxyEngine):
"""An SQLEngine proxy that automatically connects when necessary."""
-
+
def __init__(self, dburi, **kwargs):
BaseProxyEngine.__init__(self)
self.dburi = dburi
self.kwargs = kwargs
self._engine = None
-
+
def get_engine(self):
if self._engine is None:
if callable(self.dburi):
@@ -60,7 +74,7 @@ class AutoConnectEngine(BaseProxyEngine):
class ProxyEngine(BaseProxyEngine):
"""Engine proxy for lazy and late initialization.
-
+
This engine will delegate access to a real engine set with connect().
"""
@@ -89,7 +103,7 @@ class ProxyEngine(BaseProxyEngine):
except KeyError:
map[key] = create_engine(*args, **kwargs)
self.storage.engine = map[key]
-
+
def get_engine(self):
if not hasattr(self.storage, 'engine') or self.storage.engine is None:
raise AttributeError("No connection established")
diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py
index eab2aa688d..d65a02f01d 100644
--- a/lib/sqlalchemy/ext/selectresults.py
+++ b/lib/sqlalchemy/ext/selectresults.py
@@ -2,7 +2,7 @@ import sqlalchemy.sql as sql
import sqlalchemy.orm as orm
class SelectResultsExt(orm.MapperExtension):
- """a MapperExtension that provides SelectResults functionality for the
+ """a MapperExtension that provides SelectResults functionality for the
results of query.select_by() and query.select()"""
def select_by(self, query, *args, **params):
return SelectResults(query, query.join_by(*args, **params))
@@ -13,14 +13,19 @@ class SelectResultsExt(orm.MapperExtension):
return SelectResults(query, arg, ops=kwargs)
class SelectResults(object):
- """Builds a query one component at a time via separate method calls,
- each call transforming the previous SelectResults instance into a new SelectResults
- instance with further limiting criterion added. When interpreted
- in an iterator context (such as via calling list(selectresults)), executes the query."""
-
+ """Build a query one component at a time via separate method
+ calls, each call transforming the previous ``SelectResults``
+ instance into a new ``SelectResults`` instance with further
+ limiting criterion added. When interpreted in an iterator context
+ (such as via calling ``list(selectresults)``), executes the query.
+ """
+
def __init__(self, query, clause=None, ops={}, joinpoint=None):
- """constructs a new SelectResults using the given Query object and optional WHERE
- clause. ops is an optional dictionary of bind parameter values."""
+ """Construct a new ``SelectResults`` using the given ``Query``
+ object and optional ``WHERE`` clause. `ops` is an optional
+ dictionary of bind parameter values.
+ """
+
self._query = query
self._clause = clause
self._ops = {}
@@ -28,23 +33,24 @@ class SelectResults(object):
self._joinpoint = joinpoint or (self._query.table, self._query.mapper)
def options(self,*args, **kwargs):
- """transform the original mapper query form to an alternate form
-
- See also Query.options
+ """Transform the original mapper query form to an alternate form
+ See also ``Query.options``.
"""
+
self._query = self._query.options(*args, **kwargs)
def count(self):
- """executes the SQL count() function against the SelectResults criterion."""
+ """Execute the SQL ``count()`` function against the ``SelectResults`` criterion."""
+
return self._query.count(self._clause, **self._ops)
def _col_aggregate(self, col, func):
- """executes func() function against the given column
+ """Execute ``func()`` function against the given column.
- For performance, only use subselect if order_by attribute is set.
-
+
+ For performance, a subselect is used only if the `order_by` attribute is set.
"""
+
if self._ops.get('order_by'):
s1 = sql.select([col], self._clause, **self._ops).alias('u')
return sql.select([func(s1.corresponding_column(col))]).scalar()
@@ -52,95 +58,115 @@ class SelectResults(object):
return sql.select([func(col)], self._clause, **self._ops).scalar()
def min(self, col):
- """executes the SQL min() function against the given column"""
+ """Execute the SQL ``min()`` function against the given column."""
+
return self._col_aggregate(col, sql.func.min)
def max(self, col):
- """executes the SQL max() function against the given column"""
+ """Execute the SQL ``max()`` function against the given column."""
+
return self._col_aggregate(col, sql.func.max)
def sum(self, col):
- """executes the SQL sum() function against the given column"""
+ """Execute the SQL ``sum``() function against the given column."""
+
return self._col_aggregate(col, sql.func.sum)
def avg(self, col):
- """executes the SQL avg() function against the given column"""
+ """Execute the SQL ``avg()`` function against the given column."""
+
return self._col_aggregate(col, sql.func.avg)
def clone(self):
- """creates a copy of this SelectResults."""
+ """Create a copy of this ``SelectResults``."""
+
return SelectResults(self._query, self._clause, self._ops.copy(), self._joinpoint)
-
+
def filter(self, clause):
- """applies an additional WHERE clause against the query."""
+ """Apply an additional ``WHERE`` clause against the query."""
+
new = self.clone()
new._clause = sql.and_(self._clause, clause)
return new
def select(self, clause):
return self.filter(clause)
-
+
def select_by(self, *args, **kwargs):
return self.filter(self._query._join_by(args, kwargs, start=self._joinpoint[1]))
-
+
def order_by(self, order_by):
- """apply an ORDER BY to the query."""
+ """Apply an ``ORDER BY`` to the query."""
+
new = self.clone()
new._ops['order_by'] = order_by
return new
def limit(self, limit):
- """apply a LIMIT to the query."""
+ """Apply a ``LIMIT`` to the query."""
+
return self[:limit]
def offset(self, offset):
- """apply an OFFSET to the query."""
+ """Apply an ``OFFSET`` to the query."""
+
return self[offset:]
def distinct(self):
- """applies a DISTINCT to the query"""
+ """Apply a ``DISTINCT`` to the query."""
+
new = self.clone()
new._ops['distinct'] = True
return new
-
+
def list(self):
- """return the results represented by this SelectResults as a list.
-
- this results in an execution of the underlying query."""
+ """Return the results represented by this ``SelectResults`` as a list.
+
+ This results in an execution of the underlying query.
+ """
+
return list(self)
-
+
def select_from(self, from_obj):
- """set the from_obj parameter of the query to a specific table or set of tables.
-
- from_obj is a list."""
+ """Set the `from_obj` parameter of the query.
+
+ `from_obj` is a list of one or more tables.
+ """
+
new = self.clone()
new._ops['from_obj'] = from_obj
return new
-
+
def join_to(self, prop):
- """join the table of this SelectResults to the table located against the given property name.
-
- subsequent calls to join_to or outerjoin_to will join against the rightmost table located from the
- previous join_to or outerjoin_to call, searching for the property starting with the rightmost mapper
- last located."""
+ """Join the table of this ``SelectResults`` to the table located against the given property name.
+
+ Subsequent calls to `join_to` or `outerjoin_to` will join
+ against the rightmost table located from the previous `join_to`
+ or `outerjoin_to` call, searching for the property starting with
+ the rightmost mapper last located.
+ """
+
new = self.clone()
(clause, mapper) = self._join_to(prop, outerjoin=False)
new._ops['from_obj'] = [clause]
new._joinpoint = (clause, mapper)
return new
-
+
def outerjoin_to(self, prop):
- """outer join the table of this SelectResults to the table located against the given property name.
-
- subsequent calls to join_to or outerjoin_to will join against the rightmost table located from the
- previous join_to or outerjoin_to call, searching for the property starting with the rightmost mapper
- last located."""
+ """Outer join the table of this ``SelectResults`` to the table located against the given property name.
+
+ Subsequent calls to `join_to` or `outerjoin_to` will join
+ against the rightmost table located from the previous `join_to`
+ or `outerjoin_to` call, searching for the property starting with
+ the rightmost mapper last located.
+ """
+
new = self.clone()
(clause, mapper) = self._join_to(prop, outerjoin=True)
new._ops['from_obj'] = [clause]
new._joinpoint = (clause, mapper)
return new
-
+
def _join_to(self, prop, outerjoin=False):
[keys,p] = self._query._locate_prop(prop, start=self._joinpoint[1])
clause = self._joinpoint[0]
@@ -153,10 +179,10 @@ class SelectResults(object):
clause = clause.join(prop.select_table, prop.get_join(mapper))
mapper = prop.mapper
return (clause, mapper)
-
+
def compile(self):
return self._query.compile(self._clause, **self._ops)
-
+
def __getitem__(self, item):
if isinstance(item, slice):
start = item.start
@@ -178,6 +204,6 @@ class SelectResults(object):
return res
else:
return list(self[item:item+1])[0]
-
+
def __iter__(self):
return iter(self._query.select_whereclause(self._clause, **self._ops))
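
Taken together, the generative methods documented above compose a
query incrementally, and execution is deferred until the results are
iterated. A hypothetical usage sketch (``session``, ``User`` and
``users_table`` stand in for an existing session, mapped class and
table; they are not part of this changeset)::

    from sqlalchemy.ext.selectresults import SelectResults

    results = SelectResults(session.query(User))
    adults = results.filter(users_table.c.age > 18)
    adults = adults.order_by([users_table.c.name]).limit(10)

    # no SQL has been issued yet; iteration executes the query
    for user in adults:
        print user.name

    # aggregate methods execute immediately
    print adults.count()
    print results.max(users_table.c.age)
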
diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py
index f431f87c7f..2f81e55d2c 100644
--- a/lib/sqlalchemy/ext/sessioncontext.py
+++ b/lib/sqlalchemy/ext/sessioncontext.py
@@ -4,35 +4,41 @@ from sqlalchemy.orm.mapper import MapperExtension
__all__ = ['SessionContext', 'SessionContextExt']
class SessionContext(object):
- """A simple wrapper for ScopedRegistry that provides a "current" property
- which can be used to get, set, or remove the session in the current scope.
-
- By default this object provides thread-local scoping, which is the default
- scope provided by sqlalchemy.util.ScopedRegistry.
-
- Usage:
- engine = create_engine(...)
- def session_factory():
- return Session(bind_to=engine)
- context = SessionContext(session_factory)
-
- s = context.current # get thread-local session
- context.current = Session(bind_to=other_engine) # set current session
- del context.current # discard the thread-local session (a new one will
- # be created on the next call to context.current)
+ """A simple wrapper for ``ScopedRegistry`` that provides a
+ `current` property which can be used to get, set, or remove the
+ session in the current scope.
+
+ By default this object provides thread-local scoping, which is the
+ default scope provided by ``sqlalchemy.util.ScopedRegistry``.
+
+ Usage::
+
+ engine = create_engine(...)
+ def session_factory():
+ return Session(bind_to=engine)
+ context = SessionContext(session_factory)
+
+ s = context.current # get thread-local session
+ context.current = Session(bind_to=other_engine) # set current session
+ del context.current # discard the thread-local session (a new one will
+ # be created on the next call to context.current)
"""
+
def __init__(self, session_factory, scopefunc=None):
self.registry = ScopedRegistry(session_factory, scopefunc)
super(SessionContext, self).__init__()
def get_current(self):
return self.registry()
+
def set_current(self, session):
self.registry.set(session)
+
def del_current(self):
self.registry.clear()
+
current = property(get_current, set_current, del_current,
- """Property used to get/set/del the session in the current scope""")
+ """Property used to get/set/del the session in the current scope.""")
def _get_mapper_extension(self):
try:
@@ -40,16 +46,17 @@ class SessionContext(object):
except AttributeError:
self._extension = ext = SessionContextExt(self)
return ext
+
mapper_extension = property(_get_mapper_extension,
- doc="""get a mapper extension that implements get_session using this context""")
+ doc="""Get a mapper extension that implements `get_session` using this context.""")
class SessionContextExt(MapperExtension):
- """a mapper extionsion that provides sessions to a mapper using SessionContext"""
+ """A mapper extension that provides sessions to a mapper using ``SessionContext``."""
def __init__(self, context):
MapperExtension.__init__(self)
self.context = context
-
+
def get_session(self):
return self.context.current
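
A short sketch of wiring the `mapper_extension` property documented
above into a mapper, so that the mapper obtains its session from the
context's ``get_session()`` (``User`` and ``users_table`` are assumed
to exist; illustrative only)::

    from sqlalchemy.orm import mapper, create_session
    from sqlalchemy.ext.sessioncontext import SessionContext

    context = SessionContext(create_session)
    mapper(User, users_table, extension=context.mapper_extension)

    u = User()               # the mapper consults context.current
    context.current.flush()  # flush the thread-local session
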
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index 5fb42df23d..b899c043d4 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -2,28 +2,30 @@
Introduction
============
-SqlSoup provides a convenient way to access database tables without having
-to declare table or mapper classes ahead of time.
+SqlSoup provides a convenient way to access database tables without
+having to declare table or mapper classes ahead of time.
Suppose we have a database with users, books, and loans tables
-(corresponding to the PyWebOff dataset, if you're curious).
-For testing purposes, we'll create this db as follows:
+(corresponding to the PyWebOff dataset, if you're curious). For
+testing purposes, we'll create this db as follows::
>>> from sqlalchemy import create_engine
>>> e = create_engine('sqlite:///:memory:')
>>> for sql in _testsql: e.execute(sql) #doctest: +ELLIPSIS
<...
-Creating a SqlSoup gateway is just like creating an SqlAlchemy engine:
+Creating a SqlSoup gateway is just like creating an SQLAlchemy
+engine::
>>> from sqlalchemy.ext.sqlsoup import SqlSoup
>>> db = SqlSoup('sqlite:///:memory:')
-or, you can re-use an existing metadata:
+or, you can re-use an existing metadata::
>>> db = SqlSoup(BoundMetaData(e))
-You can optionally specify a schema within the database for your SqlSoup:
+You can optionally specify a schema within the database for your
+SqlSoup::
# >>> db.schema = myschemaname
@@ -31,33 +33,34 @@ You can optionally specify a schema within the database for your SqlSoup:
Loading objects
===============
-Loading objects is as easy as this:
+Loading objects is as easy as this::
>>> users = db.users.select()
>>> users.sort()
>>> users
[MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0), MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1)]
-Of course, letting the database do the sort is better (".c" is short for ".columns"):
+Of course, letting the database do the sort is better (".c" is short for ".columns")::
>>> db.users.select(order_by=[db.users.c.name])
[MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1), MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)]
-Field access is intuitive:
+Field access is intuitive::
>>> users[0].email
u'student@example.edu'
-Of course, you don't want to load all users very often. Let's add a WHERE clause.
-Let's also switch the order_by to DESC while we're at it.
+Of course, you don't want to load all users very often. Let's add a
+WHERE clause. Let's also switch the order_by to DESC while we're at
+it::
>>> from sqlalchemy import or_, and_, desc
>>> where = or_(db.users.c.name=='Bhargan Basepair', db.users.c.email=='student@example.edu')
>>> db.users.select(where, order_by=[desc(db.users.c.name)])
[MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0), MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1)]
-You can also use the select...by methods if you're querying on a single column.
-This allows using keyword arguments as column names:
+You can also use the select...by methods if you're querying on a
+single column. This allows using keyword arguments as column names::
>>> db.users.selectone_by(name='Bhargan Basepair')
MappedUsers(name='Bhargan Basepair',email='basepair@example.edu',password='basepair',classname=None,admin=1)
@@ -66,36 +69,56 @@ This allows using keyword arguments as column names:
Select variants
---------------
-All the SqlAlchemy Query select variants are available.
-Here's a quick summary of these methods:
+All the SQLAlchemy Query select variants are available. Here's a
+quick summary of these methods::
-- get(PK): load a single object identified by its primary key (either a scalar, or a tuple)
-- select(Clause, \*\*kwargs): perform a select restricted by the Clause argument; returns a list of objects. The most common clause argument takes the form "db.tablename.c.columname == value." The most common optional argument is order_by.
-- select_by(\*\*params): select methods ending with _by allow using bare column names. (columname=value) This feels more natural to most Python programmers; the downside is you can't specify order_by or other select options.
-- selectfirst, selectfirst_by: returns only the first object found; equivalent to select(...)[0] or select_by(...)[0], except None is returned if no rows are selected.
-- selectone, selectone_by: like selectfirst or selectfirst_by, but raises if less or more than one object is selected.
-- count, count_by: returns an integer count of the rows selected.
+- ``get(PK)``: load a single object identified by its primary key
+ (either a scalar, or a tuple)
-See the SqlAlchemy documentation for details:
+- ``select(Clause, **kwargs)``: perform a select restricted by the
+ `Clause` argument; returns a list of objects. The most common clause
+ argument takes the form ``db.tablename.c.columnname == value``. The
+ most common optional argument is `order_by`.
-- http://www.sqlalchemy.org/docs/datamapping.myt#datamapping_query for general info and examples,
-- http://www.sqlalchemy.org/docs/sqlconstruction.myt for details on constructing WHERE clauses.
+- ``select_by(**params)``: select methods ending with ``_by`` allow
+ using bare column names (``columnname=value``). This feels more
+ natural to most Python programmers; the downside is you can't
+ specify ``order_by`` or other select options.
+
+- ``selectfirst``, ``selectfirst_by``: returns only the first object
+ found; equivalent to ``select(...)[0]`` or ``select_by(...)[0]``,
+ except None is returned if no rows are selected.
+
+- ``selectone``, ``selectone_by``: like ``selectfirst`` or
+ ``selectfirst_by``, but raises if fewer or more than one object is
+ selected.
+
+- ``count``, ``count_by``: returns an integer count of the rows
+ selected.
+
+See the SQLAlchemy documentation for details: `datamapping query`__
+for general info and examples, and `sql construction`__ for details
+on constructing ``WHERE`` clauses.
+
+__ http://www.sqlalchemy.org/docs/datamapping.myt#datamapping_query
+__ http://www.sqlalchemy.org/docs/sqlconstruction.myt
Modifying objects
=================
-Modifying objects is intuitive:
+Modifying objects is intuitive::
>>> user = _
>>> user.email = 'basepair+nospam@example.edu'
>>> db.flush()
-(SqlSoup leverages the sophisticated SqlAlchemy unit-of-work code, so
-multiple updates to a single object will be turned into a single UPDATE
-statement when you flush.)
+(SqlSoup leverages the sophisticated SQLAlchemy unit-of-work code, so
+multiple updates to a single object will be turned into a single
+``UPDATE`` statement when you flush.)
-To finish covering the basics, let's insert a new loan, then delete it:
+To finish covering the basics, let's insert a new loan, then delete
+it::
>>> book_id = db.books.selectfirst(db.books.c.title=='Regional Variation in Moss').id
>>> db.loans.insert(book_id=book_id, user_name=user.name)
@@ -106,34 +129,39 @@ To finish covering the basics, let's insert a new loan, then delete it:
>>> db.delete(loan)
>>> db.flush()
-You can also delete rows that have not been loaded as objects. Let's do our
-insert/delete cycle once more, this time using the loans table's delete
-method. (For SQLAlchemy experts: note that no flush() call is required since
-this delete acts at the SQL level, not at the Mapper level.) The same
-where-clause construction rules apply here as to the select methods.
+You can also delete rows that have not been loaded as objects. Let's
+do our insert/delete cycle once more, this time using the loans
+table's delete method. (For SQLAlchemy experts: note that no ``flush()``
+call is required since this delete acts at the SQL level, not at the
+Mapper level.) The same where-clause construction rules apply here as
+to the select methods.
+
+::
>>> db.loans.insert(book_id=book_id, user_name=user.name)
MappedLoans(book_id=2,user_name='Bhargan Basepair',loan_date=None)
>>> db.flush()
>>> db.loans.delete(db.loans.c.book_id==2)
-You can similarly update multiple rows at once. This will change the book_id
-to 1 in all loans whose book_id is 2:
+You can similarly update multiple rows at once. This will change the
+book_id to 1 in all loans whose book_id is 2::
>>> db.loans.update(db.loans.c.book_id==2, book_id=1)
>>> db.loans.select_by(db.loans.c.book_id==1)
[MappedLoans(book_id=1,user_name='Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0))]
-
+
Joins
=====
-Occasionally, you will want to pull out a lot of data from related tables all at
-once. In this situation, it is far
-more efficient to have the database perform the necessary join. (Here
-we do not have "a lot of data," but hopefully the concept is still clear.)
-SQLAlchemy is smart enough to recognize that loans has a foreign key
-to users, and uses that as the join condition automatically.
+Occasionally, you will want to pull out a lot of data from related
+tables all at once. In this situation, it is far more efficient to
+have the database perform the necessary join. (Here we do not have *a
+lot of data* but hopefully the concept is still clear.) SQLAlchemy is
+smart enough to recognize that loans has a foreign key to users, and
+uses that as the join condition automatically.
+
+::
>>> join1 = db.join(db.users, db.loans, isouter=True)
>>> join1.select_by(name='Joe Student')
@@ -142,25 +170,26 @@ to users, and uses that as the join condition automatically.
If you're unfortunate enough to be using MySQL with the default MyISAM
storage engine, you'll have to specify the join condition manually,
since MyISAM does not store foreign keys. Here's the same join again,
-with the join condition explicitly specified:
+with the join condition explicitly specified::
>>> db.join(db.users, db.loans, db.users.c.name==db.loans.c.user_name, isouter=True)
-You can compose arbitrarily complex joins by combining Join objects with
-tables or other joins. Here we combine our first join with the books table:
+You can compose arbitrarily complex joins by combining Join objects
+with tables or other joins. Here we combine our first join with the
+books table::
>>> join2 = db.join(join1, db.books)
>>> join2.select()
[MappedJoin(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0,book_id=1,user_name='Joe Student',loan_date=datetime.datetime(2006, 7, 12, 0, 0),id=1,title='Mustards I Have Known',published_year='1989',authors='Jones')]
-If you join tables that have an identical column name, wrap your join with "with_labels",
-to disambiguate columns with their table name:
+If you join tables that have an identical column name, wrap your join
+with ``with_labels`` to disambiguate columns with their table name::
>>> db.with_labels(join1).c.keys()
['users_name', 'users_email', 'users_password', 'users_classname', 'users_admin', 'loans_book_id', 'loans_user_name', 'loans_loan_date']
-You can also join directly to a labeled object:
+You can also join directly to a labeled object::
>>> labeled_loans = db.with_labels(db.loans)
>>> db.join(db.users, labeled_loans, isouter=True).c.keys()
@@ -173,27 +202,30 @@ Advanced Use
Accessing the Session
---------------------
-SqlSoup uses a SessionContext to provide thread-local sessions. You can
-get a reference to the current one like this:
+SqlSoup uses a ``SessionContext`` to provide thread-local sessions. You
+can get a reference to the current one like this::
>>> from sqlalchemy.ext.sqlsoup import objectstore
>>> session = objectstore.current
-Now you have access to all the standard session-based SA features, such
-as transactions. (SqlSoup's flush() is normally transactionalized, but
-you can perform manual transaction management if you need a transaction
-to span multiple flushes.)
+Now you have access to all the standard session-based SA features,
+such as transactions. (SqlSoup's ``flush()`` is normally
+transactionalized, but you can perform manual transaction management
+if you need a transaction to span multiple flushes.)
Mapping arbitrary Selectables
-----------------------------
-SqlSoup can map any SQLAlchemy Selectable with the map method. Let's map a
-Select object that uses an aggregate function; we'll use the SQLAlchemy Table
-that SqlSoup introspected as the basis. (Since we're not mapping to a simple
-table or join, we need to tell SQLAlchemy how to find the "primary key," which
-just needs to be unique within the select, and not necessarily correspond to a
-"real" PK in the database.)
+SqlSoup can map any SQLAlchemy ``Selectable`` with the ``map()``
+method. Let's map a ``Select`` object that uses an aggregate function;
+we'll use the SQLAlchemy ``Table`` that SqlSoup introspected as the
+basis. (Since we're not mapping to a simple table or join, we need to
+tell SQLAlchemy how to find the *primary key* which just needs to be
+unique within the select, and not necessarily correspond to a *real*
+PK in the database.)
+
+::
>>> from sqlalchemy import select, func
>>> b = db.books._table
@@ -202,20 +234,21 @@ just needs to be unique within the select, and not necessarily correspond to a
>>> years_with_count = db.map(s, primary_key=[s.c.published_year])
>>> years_with_count.select_by(published_year='1989')
[MappedBooks(published_year='1989',n=1)]
-
-Obviously if we just wanted to get a list of counts associated with book years
-once, raw SQL is going to be less work. The advantage of mapping a Select is
-reusability, both standalone and in Joins. (And if you go to full SQLAlchemy,
-you can perform mappings like this directly to your object models.)
+
+Obviously if we just wanted to get a list of counts associated with
+book years once, raw SQL is going to be less work. The advantage of
+mapping a Select is reusability, both standalone and in Joins. (And if
+you go to full SQLAlchemy, you can perform mappings like this directly
+to your object models.)
Raw SQL
-------
-You can access the SqlSoup's ``engine`` attribute to compose SQL directly.
-The engine's ``execute`` method corresponds
-to the one of a DBAPI cursor, and returns a ``ResultProxy`` that has ``fetch`` methods
-you would also see on a cursor.
+You can access the SqlSoup's `engine` attribute to compose SQL
+directly. The engine's ``execute`` method corresponds to that of a
+DBAPI cursor, and returns a ``ResultProxy`` that has ``fetch`` methods
+you would also see on a cursor::
>>> rp = db.engine.execute('select name, email from users order by name')
>>> for name, email in rp.fetchall(): print name, email
@@ -230,14 +263,16 @@ Extra tests
Boring tests here. Nothing of real expository value.
+::
+
>>> db.users.select(db.users.c.classname==None, order_by=[db.users.c.name])
[MappedUsers(name='Bhargan Basepair',email='basepair+nospam@example.edu',password='basepair',classname=None,admin=1), MappedUsers(name='Joe Student',email='student@example.edu',password='student',classname=None,admin=0)]
-
+
>>> db.nopk
Traceback (most recent call last):
...
PKNotFoundError: table 'nopk' does not have a primary key defined
-
+
>>> db.nosuchtable
Traceback (most recent call last):
...
@@ -282,7 +317,7 @@ CREATE TABLE users (
CREATE TABLE loans (
book_id int PRIMARY KEY REFERENCES books(id),
- user_name varchar(32) references users(name)
+ user_name varchar(32) references users(name)
ON DELETE SET NULL ON UPDATE CASCADE,
loan_date datetime DEFAULT current_timestamp
);
@@ -299,7 +334,7 @@ values('Regional Variation in Moss', '1971', 'Flim and Flam');
insert into loans(book_id, user_name, loan_date)
values (
- (select min(id) from books),
+ (select min(id) from books),
(select name from users where name like 'Joe%'),
'2006-07-12 0:0:0')
;
@@ -330,30 +365,37 @@ def _ddl_error(cls):
msg = 'SQLSoup can only modify mapped Tables (found: %s)' \
% cls._table.__class__.__name__
raise InvalidRequestError(msg)
+
class SelectableClassType(type):
def insert(cls, **kwargs):
_ddl_error(cls)
+
def delete(cls, *args, **kwargs):
_ddl_error(cls)
+
def update(cls, whereclause=None, values=None, **kwargs):
_ddl_error(cls)
+
def _selectable(cls):
return cls._table
+
def __getattr__(cls, attr):
if attr == '_query':
# called during mapper init
raise AttributeError()
return getattr(cls._query, attr)
+
class TableClassType(SelectableClassType):
def insert(cls, **kwargs):
o = cls()
o.__dict__.update(kwargs)
return o
+
def delete(cls, *args, **kwargs):
cls._table.delete(*args, **kwargs).execute()
+
def update(cls, whereclause=None, values=None, **kwargs):
cls._table.update(whereclause, values).execute(**kwargs)
-
def _is_outer_join(selectable):
if not isinstance(selectable, sql.Join):
@@ -384,6 +426,7 @@ def class_for_table(selectable, **mapper_kwargs):
klass = TableClassType(mapname, (object,), {})
else:
klass = SelectableClassType(mapname, (object,), {})
+
def __cmp__(self, o):
L = self.__class__.c.keys()
L.sort()
@@ -393,6 +436,7 @@ def class_for_table(selectable, **mapper_kwargs):
except AttributeError:
raise TypeError('unable to compare with %s' % o.__class__)
return cmp(t1, t2)
+
def __repr__(self):
import locale
encoding = locale.getdefaultlocale()[1] or 'ascii'
@@ -403,6 +447,7 @@ def class_for_table(selectable, **mapper_kwargs):
value = value.encode(encoding)
L.append("%s=%r" % (k, value))
return '%s(%s)' % (self.__class__.__name__, ','.join(L))
+
for m in ['__cmp__', '__repr__']:
setattr(klass, m, eval(m))
klass._table = selectable
@@ -416,10 +461,12 @@ def class_for_table(selectable, **mapper_kwargs):
class SqlSoup:
def __init__(self, *args, **kwargs):
+ """Initialize a new ``SqlSoup``.
+
+ `args` may either be an ``SQLEngine`` or a set of arguments
+ suitable for passing to ``create_engine``.
"""
- args may either be an SQLEngine or a set of arguments suitable
- for passing to create_engine
- """
+
# meh, sometimes having method overloading instead of kwargs would be easier
if isinstance(args[0], MetaData):
args = list(args)
@@ -431,15 +478,21 @@ class SqlSoup:
self._metadata = metadata
self._cache = {}
self.schema = None
+
def engine(self):
return self._metadata._engine
+
engine = property(engine)
+
def delete(self, *args, **kwargs):
objectstore.delete(*args, **kwargs)
+
def flush(self):
objectstore.get_session().flush()
+
def clear(self):
objectstore.clear()
+
def map(self, selectable, **kwargs):
try:
t = self._cache[selectable]
@@ -447,12 +500,15 @@ class SqlSoup:
t = class_for_table(selectable, **kwargs)
self._cache[selectable] = t
return t
+
def with_labels(self, item):
# TODO give meaningful aliases
return self.map(item._selectable().select(use_labels=True).alias('foo'))
+
def join(self, *args, **kwargs):
j = join(*args, **kwargs)
return self.map(j)
+
def __getattr__(self, attr):
try:
t = self._cache[attr]
diff --git a/lib/sqlalchemy/logging.py b/lib/sqlalchemy/logging.py
index 7e293ab95d..6f43687079 100644
--- a/lib/sqlalchemy/logging.py
+++ b/lib/sqlalchemy/logging.py
@@ -4,24 +4,26 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""provides a few functions used by instances to turn on/off their logging, including support
-for the usual "echo" parameter. Control of logging for SA can be performed from the regular
-python logging module. The regular dotted module namespace is used, starting at 'sqlalchemy'.
-For class-level logging, the class name is appended, and for instance-level logging, the hex
-id of the instance is appended.
+"""Provides a few functions used by instances to turn on/off their
+logging, including support for the usual "echo" parameter.
-The "echo" keyword parameter which is available on some SA objects corresponds to an instance-level
-logger for that instance.
+Control of logging for SA can be performed from the regular Python
+logging module. The regular dotted module namespace is used, starting
+at 'sqlalchemy'. For class-level logging, the class name is appended,
+and for instance-level logging, the hex id of the instance is
+appended.
-E.g.:
+The "echo" keyword parameter which is available on some SA objects
+corresponds to an instance-level logger for that instance.
+
+E.g.::
engine.echo = True
-
-is equivalent to:
+
+is equivalent to::
import logging
logging.getLogger('sqlalchemy.engine.Engine.%s' % hex(id(engine))).setLevel(logging.DEBUG)
-
"""
import sys
@@ -44,12 +46,12 @@ def default_logging(name):
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(name)s %(message)s'))
rootlogger.addHandler(handler)
-def _get_instance_name(instance):
+def _get_instance_name(instance):
# since getLogger() does not have any way of removing logger objects from memory,
# instance logging displays the instance id as a modulus of 16 to prevent endless memory growth
# also speeds performance as logger initialization is apparently slow
return instance.__class__.__module__ + "." + instance.__class__.__name__ + ".0x.." + hex(id(instance))[-2:]
-
+
def instance_logger(instance):
return logging.getLogger(_get_instance_name(instance))
@@ -58,9 +60,10 @@ def class_logger(cls):
def is_debug_enabled(logger):
return logger.isEnabledFor(logging.DEBUG)
+
def is_info_enabled(logger):
return logger.isEnabledFor(logging.INFO)
-
+
class echo_property(object):
level_map={logging.DEBUG : "debug", logging.INFO:True}
def __get__(self, instance, owner):
@@ -72,5 +75,3 @@ class echo_property(object):
logging.getLogger(_get_instance_name(instance)).setLevel(value == 'debug' and logging.DEBUG or logging.INFO)
else:
logging.getLogger(_get_instance_name(instance)).setLevel(logging.NOTSET)
-
-
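
As a concrete illustration of the dotted-namespace control described
in this docstring, logging can be configured per-module with the
standard library alone (the logger names below follow the module
namespace convention stated above)::

    import logging

    logging.basicConfig()
    # SQL statement logging for all engines
    logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
    # verbose unit-of-work debugging
    logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)
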
diff --git a/lib/sqlalchemy/mods/legacy_session.py b/lib/sqlalchemy/mods/legacy_session.py
index a28cd3dac7..e21a5634b3 100644
--- a/lib/sqlalchemy/mods/legacy_session.py
+++ b/lib/sqlalchemy/mods/legacy_session.py
@@ -1,4 +1,4 @@
-"""a plugin that emulates 0.1 Session behavior."""
+"""A plugin that emulates 0.1 Session behavior."""
import sqlalchemy.orm.objectstore as objectstore
import sqlalchemy.orm.unitofwork as unitofwork
@@ -14,6 +14,7 @@ class LegacySession(objectstore.Session):
self.begin_count = 0
self.nest_on = util.to_list(nest_on)
self.__pushed_count = 0
+
def was_pushed(self):
if self.nest_on is None:
return
@@ -21,6 +22,7 @@ class LegacySession(objectstore.Session):
if self.__pushed_count == 1:
for n in self.nest_on:
n.push_session()
+
def was_popped(self):
if self.nest_on is None or self.__pushed_count == 0:
return
@@ -28,46 +30,74 @@ class LegacySession(objectstore.Session):
if self.__pushed_count == 0:
for n in self.nest_on:
n.pop_session()
+
class SessionTrans(object):
- """returned by Session.begin(), denotes a transactionalized UnitOfWork instance.
- call commit() on this to commit the transaction."""
+ """Returned by ``Session.begin()``, denotes a
+ transactionalized UnitOfWork instance. Call ``commit()`
+ on this to commit the transaction.
+ """
+
def __init__(self, parent, uow, isactive):
self.__parent = parent
self.__isactive = isactive
self.__uow = uow
+
isactive = property(lambda s:s.__isactive, doc="True if this SessionTrans is the 'active' transaction marker, else its a no-op.")
- parent = property(lambda s:s.__parent, doc="returns the parent Session of this SessionTrans object.")
- uow = property(lambda s:s.__uow, doc="returns the parent UnitOfWork corresponding to this transaction.")
+ parent = property(lambda s:s.__parent, doc="The parent Session of this SessionTrans object.")
+ uow = property(lambda s:s.__uow, doc="The parent UnitOfWork corresponding to this transaction.")
+
def begin(self):
- """calls begin() on the underlying Session object, returning a new no-op SessionTrans object."""
+ """Call ``begin()`` on the underlying ``Session`` object,
+ returning a new no-op ``SessionTrans`` object.
+ """
+
if self.parent.uow is not self.uow:
raise InvalidRequestError("This SessionTrans is no longer valid")
return self.parent.begin()
+
def commit(self):
- """commits the transaction noted by this SessionTrans object."""
+ """Commit the transaction noted by this ``SessionTrans`` object."""
+
self.__parent._trans_commit(self)
self.__isactive = False
+
def rollback(self):
- """rolls back the current UnitOfWork transaction, in the case that begin()
- has been called. The changes logged since the begin() call are discarded."""
+ """Roll back the current UnitOfWork transaction, in the
+ case that ``begin()`` has been called.
+
+ The changes logged since the ``begin()`` call are discarded.
+ """
+
self.__parent._trans_rollback(self)
self.__isactive = False
+
def begin(self):
- """begins a new UnitOfWork transaction and returns a tranasaction-holding
- object. commit() or rollback() should be called on the returned object.
- commit() on the Session will do nothing while a transaction is pending, and further
- calls to begin() will return no-op transactional objects."""
+ """Begin a new UnitOfWork transaction and return a
+ transaction-holding object.
+
+ ``commit()`` or ``rollback()`` should be called on the returned object.
+
+ ``commit()`` on the ``Session`` will do nothing while a
+ transaction is pending, and further calls to ``begin()`` will
+ return no-op transactional objects.
+ """
+
if self.parent_uow is not None:
return LegacySession.SessionTrans(self, self.uow, False)
self.parent_uow = self.uow
self.uow = unitofwork.UnitOfWork(identity_map = self.uow.identity_map)
return LegacySession.SessionTrans(self, self.uow, True)
+
def commit(self, *objects):
- """commits the current UnitOfWork transaction. called with
- no arguments, this is only used
- for "implicit" transactions when there was no begin().
- if individual objects are submitted, then only those objects are committed, and the
- begin/commit cycle is not affected."""
+ """Commit the current UnitOfWork transaction.
+
+ Called with no arguments, this is only used for *implicit*
+ transactions when there was no ``begin()``.
+
+ If individual objects are submitted, then only those objects
+ are committed, and the begin/commit cycle is not affected.
+ """
+
# if an object list is given, commit just those but dont
# change begin/commit status
if len(objects):
@@ -76,6 +106,7 @@ class LegacySession(objectstore.Session):
return
if self.parent_uow is None:
self._commit_uow()
+
def _trans_commit(self, trans):
if trans.uow is self.uow and trans.isactive:
try:
@@ -83,10 +114,12 @@ class LegacySession(objectstore.Session):
finally:
self.uow = self.parent_uow
self.parent_uow = None
+
def _trans_rollback(self, trans):
if trans.uow is self.uow:
self.uow = self.parent_uow
self.parent_uow = None
+
def _commit_uow(self, *obj):
self.was_pushed()
try:
@@ -95,11 +128,13 @@ class LegacySession(objectstore.Session):
self.was_popped()
def begin():
- """deprecated. use s = Session(new_imap=False)."""
+ """Deprecated. Use ``s = Session(new_imap=False)``."""
+
return objectstore.get_session().begin()
def commit(*obj):
- """deprecated; use flush(*obj)"""
+ """Deprecated. Use ``flush(*obj)``."""
+
objectstore.get_session().flush(*obj)
def uow():
@@ -137,4 +172,5 @@ def install_plugin():
objectstore.push_session = push_session
objectstore.pop_session = pop_session
objectstore.using_session = using_session
+
install_plugin()
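
The begin/commit semantics documented above reduce to a short sketch
(assuming the plugin is installed and ``session`` is a
``LegacySession``; illustrative only)::

    trans = session.begin()    # the active transaction marker
    nested = session.begin()   # a no-op marker while one is pending
    nested.commit()            # no-op: not the active marker
    session.commit()           # also does nothing while begin() is pending
    trans.commit()             # commits the whole unit of work
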
diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py
index 51ed6e4a57..ac8de9b063 100644
--- a/lib/sqlalchemy/mods/selectresults.py
+++ b/lib/sqlalchemy/mods/selectresults.py
@@ -1,7 +1,7 @@
from sqlalchemy.ext.selectresults import *
from sqlalchemy.orm.mapper import global_extensions
-
def install_plugin():
global_extensions.append(SelectResultsExt)
+
install_plugin()
diff --git a/lib/sqlalchemy/mods/threadlocal.py b/lib/sqlalchemy/mods/threadlocal.py
index 6fce859543..c8043bc624 100644
--- a/lib/sqlalchemy/mods/threadlocal.py
+++ b/lib/sqlalchemy/mods/threadlocal.py
@@ -1,15 +1,22 @@
-"""this plugin installs thread-local behavior at the Engine and Session level.
+"""This plugin installs thread-local behavior at the ``Engine`` and ``Session`` level.
-The default Engine strategy will be "threadlocal", producing TLocalEngine instances for create_engine by default.
-With this engine, connect() method will return the same connection on the same thread, if it is already checked out
-from the pool. this greatly helps functions that call multiple statements to be able to easily use just one connection
-without explicit "close" statements on result handles.
+The default ``Engine`` strategy will be *threadlocal*, producing
+``TLocalEngine`` instances from ``create_engine()`` by default.
-on the Session side, module-level methods will be installed within the objectstore module, such as flush(), delete(), etc.
-which call this method on the thread-local session.
+With this engine, the ``connect()`` method will return the same connection
+on the same thread, if it is already checked out from the pool. This
+greatly helps functions that call multiple statements to be able to
+easily use just one connection without explicit ``close`` statements
+on result handles.
-Note: this mod creates a global, thread-local session context named sqlalchemy.objectstore. All mappers created
-while this mod is installed will reference this global context when creating new mapped object instances.
+On the ``Session`` side, module-level methods will be installed within
+the objectstore module, such as ``flush()``, ``delete()``, etc.,
+which call the corresponding method on the thread-local session.
+
+Note: this mod creates a global, thread-local session context named
+``sqlalchemy.objectstore``. All mappers created while this mod is
+installed will reference this global context when creating new mapped
+object instances.
"""
from sqlalchemy import util, engine, mapper
@@ -20,7 +27,6 @@ from sqlalchemy.orm.session import Session
import sqlalchemy
import sys, types
-
__all__ = ['Objectstore', 'assign_mapper']
class Objectstore(object):
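
In practice the behavior described in this docstring looks like the
following sketch (``User`` stands in for an assumed mapped class;
illustrative only)::

    import sqlalchemy.mods.threadlocal   # importing the mod installs it
    import sqlalchemy

    u = User()                       # references the global session context
    sqlalchemy.objectstore.flush()   # flush this thread's session
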
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 30070e8d72..4c87c4bdcd 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -5,9 +5,10 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
-the mapper package provides object-relational functionality, building upon the schema and sql
+The mapper package provides object-relational functionality, building upon the schema and sql
packages and tying operations to class properties and constructors.
"""
+
from sqlalchemy import exceptions
from sqlalchemy import util as sautil
from sqlalchemy.orm.mapper import *
@@ -18,15 +19,17 @@ from sqlalchemy.orm import properties, strategies, interfaces
from sqlalchemy.orm.session import Session as create_session
from sqlalchemy.orm.session import object_session, attribute_manager
-__all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', 'extension',
- 'mapper', 'clear_mappers', 'compile_mappers', 'clear_mapper', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query',
+__all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer', 'extension',
+ 'mapper', 'clear_mappers', 'compile_mappers', 'clear_mapper', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query',
'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'contains_alias', 'contains_eager', 'EXT_PASS', 'object_session'
]
def relation(*args, **kwargs):
- """provide a relationship of a primary Mapper to a secondary Mapper.
-
- This corresponds to a parent-child or associative table relationship."""
+ """Provide a relationship of a primary Mapper to a secondary Mapper.
+
+ This corresponds to a parent-child or associative table relationship.
+ """
+
if len(args) > 1 and isinstance(args[0], type):
raise exceptions.ArgumentError("relation(class, table, **kwargs) is deprecated. Please use relation(class, **kwargs) or relation(mapper, **kwargs).")
return _relation_loader(*args, **kwargs)
@@ -35,96 +38,123 @@ def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=Non
return properties.PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, lazy=lazy, **kwargs)
def backref(name, **kwargs):
- """create a BackRef object with explicit arguments, which are the same arguments one
- can send to relation().
-
- used with the "backref" keyword argument to relation() in place
- of a string argument. """
+ """Create a BackRef object with explicit arguments, which are the same arguments one
+ can send to ``relation()``.
+
+ Used with the `backref` keyword argument to ``relation()`` in
+ place of a string argument.
+ """
+
return properties.BackRef(name, **kwargs)
-
+
def deferred(*columns, **kwargs):
- """return a DeferredColumnProperty, which indicates this object attributes should only be loaded
- from its corresponding table column when first accessed.
-
- used with the 'properties' dictionary sent to mapper()."""
+ """Return a ``DeferredColumnProperty``, which indicates this
+ object attributes should only be loaded from its corresponding
+ table column when first accessed.
+
+ Used with the `properties` dictionary sent to ``mapper()``.
+ """
+
return properties.ColumnProperty(deferred=True, *columns, **kwargs)
-
+
def mapper(class_, table=None, *args, **params):
- """return a new Mapper object.
-
- See the Mapper class for a description of arguments."""
+ """Return a new ``Mapper`` object.
+
+ See the ``Mapper`` class for a description of arguments.
+ """
+
return Mapper(class_, table, *args, **params)
def synonym(name, proxy=False):
- """set up 'name' as a synonym to another MapperProperty.
-
- Used with the 'properties' dictionary sent to mapper()."""
+ """Set up `name` as a synonym to another ``MapperProperty``.
+
+ Used with the `properties` dictionary sent to ``mapper()``.
+ """
+
return interfaces.SynonymProperty(name, proxy=proxy)
def compile_mappers():
- """compile all mappers that have been defined.
-
- this is equivalent to calling compile() on any individual mapper."""
+ """Compile all mappers that have been defined.
+
+ This is equivalent to calling ``compile()`` on any individual mapper.
+ """
+
if not len(mapper_registry):
return
mapper_registry.values()[0].compile()
-
+
def clear_mappers():
- """remove all mappers that have been created thus far.
-
- when new mappers are created, they will be assigned to their classes as their primary mapper."""
+ """Remove all mappers that have been created thus far.
+
+ When new mappers are created, they will be assigned to their
+ classes as their primary mapper.
+ """
+
for mapper in mapper_registry.values():
attribute_manager.reset_class_managed(mapper.class_)
if hasattr(mapper.class_, 'c'):
del mapper.class_.c
mapper_registry.clear()
sautil.ArgSingleton.instances.clear()
-
+
def clear_mapper(m):
- """remove the given mapper from the storage of mappers.
-
- when a new mapper is created for the previous mapper's class, it will be used as that classes'
- new primary mapper."""
+ """Remove the given mapper from the storage of mappers.
+
+ When a new mapper is created for the previous mapper's class, it
+ will be used as that class's new primary mapper.
+ """
+
del mapper_registry[m.class_key]
attribute_manager.reset_class_managed(m.class_)
if hasattr(m.class_, 'c'):
del m.class_.c
m.class_key.dispose()
-
+
def extension(ext):
- """return a MapperOption that will insert the given MapperExtension to the
- beginning of the list of extensions that will be called in the context of the Query.
-
- used with query.options()."""
+ """Return a ``MapperOption`` that will insert the given
+ ``MapperExtension`` to the beginning of the list of extensions
+ that will be called in the context of the ``Query``.
+
+ Used with ``query.options()``.
+ """
+
return ExtensionOption(ext)
-
+
def eagerload(name):
- """return a MapperOption that will convert the property of the given name
- into an eager load.
-
- used with query.options()."""
+ """Return a ``MapperOption`` that will convert the property of the
+ given name into an eager load.
+
+ Used with ``query.options()``.
+ """
+
return strategies.EagerLazyOption(name, lazy=False)
def lazyload(name):
- """return a MapperOption that will convert the property of the given name
- into a lazy load.
-
- used with query.options()."""
+ """Return a ``MapperOption`` that will convert the property of the
+ given name into a lazy load.
+
+ Used with ``query.options()``.
+ """
+
return strategies.EagerLazyOption(name, lazy=True)
def noload(name):
- """return a MapperOption that will convert the property of the given name
- into a non-load.
-
- used with query.options()."""
+ """Return a ``MapperOption`` that will convert the property of the
+ given name into a non-load.
+
+ Used with ``query.options()``.
+ """
+
return strategies.EagerLazyOption(name, lazy=None)
def contains_alias(alias):
- """return a MapperOption that will indicate to the query that the main table
- has been aliased.
-
- "alias" is the string name or Alias object representing the alias.
+ """Return a ``MapperOption`` that will indicate to the query that
+ the main table has been aliased.
+
+ `alias` is the string name or ``Alias`` object representing the
+ alias.
"""
+
class AliasedRow(MapperExtension):
def __init__(self, alias):
self.alias = alias
@@ -144,53 +174,66 @@ def contains_alias(alias):
if c2 and row.has_key(c2):
newrow[c] = row[c2]
return newrow
-
+
return ExtensionOption(AliasedRow(alias))
-
+
def contains_eager(key, alias=None, decorator=None):
- """return a MapperOption that will indicate to the query that the given
- attribute will be eagerly loaded.
-
- used when feeding SQL result sets directly into
- query.instances(). Also bundles an EagerLazyOption to turn on eager loading
- in case it isnt already.
-
- "alias" is the string name of an alias, *or* an sql.Alias object, which represents
- the aliased columns in the query. this argument is optional.
-
- "decorator" is mutually exclusive of "alias" and is a row-processing function which
- will be applied to the incoming row before sending to the eager load handler. use this
- for more sophisticated row adjustments beyond a straight alias."""
+ """Return a ``MapperOption`` that will indicate to the query that
+ the given attribute will be eagerly loaded.
+
+ Used when feeding SQL result sets directly into
+ ``query.instances()``. Also bundles an ``EagerLazyOption`` to
+ turn on eager loading in case it isn't already.
+
+ `alias` is the string name of an alias, **or** an ``sql.Alias``
+ object, which represents the aliased columns in the query. This
+ argument is optional.
+
+ `decorator` is mutually exclusive with `alias` and is a
+ row-processing function which will be applied to the incoming row
+ before sending to the eager load handler. Use this for more
+ sophisticated row adjustments beyond a straight alias.
+ """
+
return (strategies.EagerLazyOption(key, lazy=False), strategies.RowDecorateOption(key, alias=alias, decorator=decorator))
-
+
def defer(name):
- """return a MapperOption that will convert the column property of the given
- name into a deferred load.
-
- used with query.options()"""
+ """Return a ``MapperOption`` that will convert the column property
+ of the given name into a deferred load.
+
+ Used with ``query.options()``"""
return strategies.DeferredOption(name, defer=True)
+
def undefer(name):
- """return a MapperOption that will convert the column property of the given
- name into a non-deferred (regular column) load.
-
- used with query.options()."""
+ """Return a ``MapperOption`` that will convert the column property
+ of the given name into a non-deferred (regular column) load.
+
+ Used with ``query.options()``.
+ """
+
return strategies.DeferredOption(name, defer=False)
-
+
def cascade_mappers(*classes_or_mappers):
- """attempt to create a series of relations() between mappers automatically, via
- introspecting the foreign key relationships of the underlying tables.
-
- given a list of classes and/or mappers, identifies the foreign key relationships
- between the given mappers or corresponding class mappers, and creates relation()
- objects representing those relationships, including a backreference. Attempts to find
- the "secondary" table in a many-to-many relationship as well. The names of the relations
- will be a lowercase version of the related class. In the case of one-to-many or many-to-many,
- the name will be "pluralized", which currently is based on the English language (i.e. an 's' or
- 'es' added to it).
-
- NOTE: this method usually works poorly, and its usage is generally not advised.
+ """Attempt to create a series of ``relations()`` between mappers
+ automatically, via introspecting the foreign key relationships of
+ the underlying tables.
+
+ Given a list of classes and/or mappers, identify the foreign key
+ relationships between the given mappers or corresponding class
+ mappers, and create ``relation()`` objects representing those
+ relationships, including a backreference. Attempt to find the
+ *secondary* table in a many-to-many relationship as well.
+
+ The names of the relations will be a lowercase version of the
+ related class. In the case of one-to-many or many-to-many, the
+ name will be *pluralized*, which currently is based on the English
+ language (i.e. an 's' or 'es' added to it).
+
+ NOTE: this method usually works poorly, and its usage is generally
+ not advised.
"""
+
table_to_mapper = {}
for item in classes_or_mappers:
if isinstance(item, Mapper):
@@ -199,12 +242,14 @@ def cascade_mappers(*classes_or_mappers):
klass = item
m = class_mapper(klass)
table_to_mapper[m.mapped_table] = m
+
def pluralize(name):
# oh crap, do we need locale stuff now
if name[-1] == 's':
return name + "es"
else:
return name + "s"
+
for table,mapper in table_to_mapper.iteritems():
for fk in table.foreign_keys:
if fk.column.table is table:
@@ -229,5 +274,3 @@ def cascade_mappers(*classes_or_mappers):
propname = m2.class_.__name__.lower()
propname2 = pluralize(mapper.class_.__name__.lower())
mapper.add_property(propname, relation(m2, secondary=secondary, backref=propname2))
-
-
\ No newline at end of file
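
Several of the functions above are meant to be combined through
``query.options()``; a compact sketch of that usage (``User`` is an
assumed mapped class with an `addresses` relation and a large `photo`
column; illustrative only)::

    from sqlalchemy.orm import create_session, eagerload, defer

    session = create_session()
    query = session.query(User)

    # eagerly load 'addresses', defer loading of the 'photo' column
    users = query.options(eagerload('addresses'), defer('photo')).select()
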
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 6b2a1ee1a5..af3487dfdb 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -10,11 +10,14 @@ from sqlalchemy import logging, exceptions
import weakref
class InstrumentedAttribute(object):
- """a property object that instruments attribute access on object instances. All methods correspond to
- a single attribute on a particular class."""
-
+ """A property object that instruments attribute access on object instances.
+
+ All methods correspond to a single attribute on a particular
+ class.
+ """
+
PASSIVE_NORESULT = object()
-
+
def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
self.manager = manager
self.key = key
@@ -40,8 +43,10 @@ class InstrumentedAttribute(object):
def __set__(self, obj, value):
self.set(None, obj, value)
+
def __delete__(self, obj):
self.delete(None, obj)
+
def __get__(self, obj, owner):
if obj is None:
return self
@@ -57,71 +62,92 @@ class InstrumentedAttribute(object):
return False
else:
return False
-
-
+
def hasparent(self, item, optimistic=False):
- """return the boolean value of a "hasparent" flag attached to the given item.
-
- the 'optimistic' flag determines what the default return value should be if
- no "hasparent" flag can be located. as this function is used to determine if
- an instance is an "orphan", instances that were loaded from storage should be assumed
- to not be orphans, until a True/False value for this flag is set. an instance attribute
- that is loaded by a callable function will also not have a "hasparent" flag.
+ """Return the boolean value of a `hasparent` flag attached to the given item.
+
+ The `optimistic` flag determines what the default return value
+ should be if no `hasparent` flag can be located.
+
+ As this function is used to determine if an instance is an
+ *orphan*, instances that were loaded from storage should be
+ assumed to not be orphans, until a True/False value for this
+ flag is set.
+
+ An instance attribute that is loaded by a callable function
+ will also not have a `hasparent` flag.
"""
+
return item._state.get(('hasparent', id(self)), optimistic)
-
+
def sethasparent(self, item, value):
- """sets a boolean flag on the given item corresponding to whether or not it is
- attached to a parent object via the attribute represented by this InstrumentedAttribute."""
+ """Set a boolean flag on the given item corresponding to
+ whether or not it is attached to a parent object via the
+ attribute represented by this ``InstrumentedAttribute``.
+ """
+
item._state[('hasparent', id(self))] = value
-
+
def get_history(self, obj, passive=False):
- """return a new AttributeHistory object for the given object/this attribute's key.
-
- if passive is True, then dont execute any callables; if the attribute's value
- can only be achieved via executing a callable, then return None."""
+ """Return a new ``AttributeHistory`` object for the given object/this attribute's key.
+
+ If `passive` is True, then don't execute any callables; if the
+ attribute's value can only be achieved via executing a
+ callable, then return None.
+ """
+
# get the current state. this may trigger a lazy load if
- # passive is False.
+ # passive is False.
current = self.get(obj, passive=passive, raiseerr=False)
if current is InstrumentedAttribute.PASSIVE_NORESULT:
return None
return AttributeHistory(self, obj, current, passive=passive)
def set_callable(self, obj, callable_):
- """set a callable function for this attribute on the given object.
-
- this callable will be executed when the attribute is next accessed,
- and is assumed to construct part of the instances previously stored state. When
- its value or values are loaded, they will be established as part of the
- instance's "committed state". while "trackparent" information will be assembled
- for these instances, attribute-level event handlers will not be fired.
-
- the callable overrides the class level callable set in the InstrumentedAttribute
- constructor.
+ """Set a callable function for this attribute on the given object.
+
+ This callable will be executed when the attribute is next
+ accessed, and is assumed to construct part of the instance's
+ previously stored state. When its value or values are loaded,
+ they will be established as part of the instance's *committed
+ state*. While *trackparent* information will be assembled for
+ these instances, attribute-level event handlers will not be
+ fired.
+
+ The callable overrides the class level callable set in the
+ ``InstrumentedAttribute`` constructor.
"""
+
if callable_ is None:
self.initialize(obj)
else:
obj._state[('callable', self)] = callable_
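
A short sketch of the per-instance callable in action; the import path and class names are hypothetical:

    from sqlalchemy import attributes

    class User(object): pass

    manager = attributes.AttributeManager()
    manager.register_attribute(User, 'addresses', uselist=True)

    u = User()

    def load_addresses():
        # stands in for a database load; the result becomes part of the
        # instance's committed state, and no change events are fired
        return ['ed@example.com', 'fred@example.com']

    User.addresses.set_callable(u, load_addresses)
    print u.addresses   # triggers load_addresses(), then instruments the list
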
def reset(self, obj):
- """removes any per-instance callable functions corresponding to this InstrumentedAttribute's attribute
- from the given object, and removes this InstrumentedAttribute's
- attribute from the given object's dictionary."""
+ """Remove any per-instance callable functions corresponding to
+ this ``InstrumentedAttribute``'s attribute from the given
+ object, and remove this ``InstrumentedAttribute``'s attribute
+ from the given object's dictionary.
+ """
+
try:
del obj._state[('callable', self)]
except KeyError:
pass
self.clear(obj)
-
+
def clear(self, obj):
- """removes this InstrumentedAttribute's attribute from the given object's dictionary. subsequent calls to
- getattr(obj, key) will raise an AttributeError by default."""
+ """Remove this ``InstrumentedAttribute``'s attribute from the given object's dictionary.
+
+ Subsequent calls to ``getattr(obj, key)`` will raise an
+ ``AttributeError`` by default.
+ """
+
try:
del obj.__dict__[self.key]
except KeyError:
pass
-
+
def _get_callable(self, obj):
if obj._state.has_key(('callable', self)):
return obj._state[('callable', self)]
@@ -129,7 +155,7 @@ class InstrumentedAttribute(object):
return self.callable_(obj)
else:
return None
-
+
def _blank_list(self):
if self.typecallable is not None:
return self.typecallable()
@@ -144,12 +170,15 @@ class InstrumentedAttribute(object):
return t
else:
return data
-
+
def initialize(self, obj):
- """initialize this attribute on the given object instance.
-
- if this is a list-based attribute, a new, blank list will be created.
- if a scalar attribute, the value will be initialized to None."""
+ """Initialize this attribute on the given object instance.
+
+ If this is a list-based attribute, a new, blank list will be
+ created. If a scalar attribute, the value will be initialized
+ to None.
+ """
+
if self.uselist:
l = InstrumentedList(self, obj, self._blank_list())
obj.__dict__[self.key] = l
@@ -157,11 +186,16 @@ class InstrumentedAttribute(object):
else:
obj.__dict__[self.key] = None
return None
-
+
def get(self, obj, passive=False, raiseerr=True):
- """retrieves a value from the given object. if a callable is assembled
- on this object's attribute, and passive is False, the callable will be executed
- and the resulting value will be set as the new value for this attribute."""
+ """Retrieve a value from the given object.
+
+ If a callable is assembled on this object's attribute, and
+ `passive` is False, the callable will be executed and the
+ resulting value will be set as the new value for this
+ attribute.
+ """
+
try:
return obj.__dict__[self.key]
except KeyError:
@@ -173,7 +207,7 @@ class InstrumentedAttribute(object):
del state['trigger']
trig()
return self.get(obj, passive=passive, raiseerr=raiseerr)
-
+
if self.uselist:
callable_ = self._get_callable(obj)
if callable_ is not None:
@@ -182,13 +216,13 @@ class InstrumentedAttribute(object):
self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
values = callable_()
l = InstrumentedList(self, obj, self._adapt_list(values), init=False)
-
+
# if a callable was executed, then its part of the "committed state"
# if any, so commit the newly loaded data
orig = state.get('original', None)
if orig is not None:
orig.commit_attribute(self, obj, l)
-
+
else:
# note that we arent raising AttributeErrors, just creating a new
# blank list and setting it.
@@ -215,11 +249,15 @@ class InstrumentedAttribute(object):
# note that we arent raising AttributeErrors, just returning None.
# this might be a good thing to be changeable by options.
return None
-
+
def set(self, event, obj, value):
- """sets a value on the given object. 'event' is the InstrumentedAttribute that
- initiated the set() operation and is used to control the depth of a circular setter
- operation."""
+ """Set a value on the given object.
+
+ `event` is the ``InstrumentedAttribute`` that initiated the
+ ``set()`` operation and is used to control the depth of a
+ circular setter operation.
+ """
+
if event is not self:
state = obj._state
# if an instance-wide "trigger" was set, call that
@@ -245,9 +283,13 @@ class InstrumentedAttribute(object):
old.list_replaced()
def delete(self, event, obj):
- """deletes a value from the given object. 'event' is the InstrumentedAttribute that
- initiated the delete() operation and is used to control the depth of a circular delete
- operation."""
+ """Delete a value from the given object.
+
+ `event` is the ``InstrumentedAttribute`` that initiated the
+ ``delete()`` operation and is used to control the depth of a
+ circular delete operation.
+ """
+
if event is not self:
try:
if not self.uselist and (self.trackparent or len(self.extensions)):
@@ -265,10 +307,17 @@ class InstrumentedAttribute(object):
ext.delete(event or self, obj, old)
def append(self, event, obj, value):
- """appends an element to a list based element or sets a scalar based element to the given value.
- Used by GenericBackrefExtension to "append" an item independent of list/scalar semantics.
- 'event' is the InstrumentedAttribute that initiated the append() operation and is used to control
- the depth of a circular append operation."""
+ """Append an element to a list based element or sets a scalar
+ based element to the given value.
+
+ Used by ``GenericBackrefExtension`` to *append* an item
+ independent of list/scalar semantics.
+
+ `event` is the ``InstrumentedAttribute`` that initiated the
+ ``append()`` operation and is used to control the depth of a
+ circular append operation.
+ """
+
if self.uselist:
if event is not self:
self.get(obj).append_with_event(value, event)
@@ -276,10 +325,17 @@ class InstrumentedAttribute(object):
self.set(event, obj, value)
def remove(self, event, obj, value):
- """removes an element from a list based element or sets a scalar based element to None.
- Used by GenericBackrefExtension to "remove" an item independent of list/scalar semantics.
- 'event' is the InstrumentedAttribute that initiated the remove() operation and is used to control
- the depth of a circular remove operation."""
+ """Remove an element from a list based element or sets a
+ scalar based element to None.
+
+ Used by ``GenericBackrefExtension`` to *remove* an item
+ independent of list/scalar semantics.
+
+ `event` is the ``InstrumentedAttribute`` that initiated the
+ ``remove()`` operation and is used to control the depth of a
+ circular remove operation.
+ """
+
if self.uselist:
if event is not self:
self.get(obj).remove_with_event(value, event)
@@ -287,33 +343,45 @@ class InstrumentedAttribute(object):
self.set(event, obj, None)
def append_event(self, event, obj, value):
- """called by InstrumentedList when an item is appended"""
+ """Called by ``InstrumentedList`` when an item is appended."""
+
obj._state['modified'] = True
if self.trackparent and value is not None:
self.sethasparent(value, True)
for ext in self.extensions:
ext.append(event or self, obj, value)
-
+
def remove_event(self, event, obj, value):
- """called by InstrumentedList when an item is removed"""
+ """Called by ``InstrumentedList`` when an item is removed."""
+
obj._state['modified'] = True
if self.trackparent and value is not None:
self.sethasparent(value, False)
for ext in self.extensions:
ext.delete(event or self, obj, value)
+
InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute)
-
+
class InstrumentedList(object):
- """instruments a list-based attribute. all mutator operations (i.e. append, remove, etc.) will fire off events to the
- InstrumentedAttribute that manages the object's attribute. those events in turn trigger things like
- backref operations and whatever is implemented by do_list_value_changed on InstrumentedAttribute.
-
- note that this list does a lot less than earlier versions of SA list-based attributes, which used HistoryArraySet.
- this list wrapper does *not* maintain setlike semantics, meaning you can add as many duplicates as
- you want (which can break a lot of SQL), and also does not do anything related to history tracking.
-
- Please see ticket #213 for information on the future of this class, where it will be broken out into more
- collection-specific subtypes."""
+ """Instrument a list-based attribute.
+
+ All mutator operations (i.e. append, remove, etc.) will fire off
+ events to the ``InstrumentedAttribute`` that manages the object's
+ attribute. Those events in turn trigger things like backref
+ operations and whatever is implemented by
+ ``do_list_value_changed`` on ``InstrumentedAttribute``.
+
+ Note that this list does a lot less than earlier versions of SA
+ list-based attributes, which used ``HistoryArraySet``. This list
+ wrapper does **not** maintain set-like semantics, meaning you can add
+ as many duplicates as you want (which can break a lot of SQL), and
+ also does not do anything related to history tracking.
+
+ Please see ticket #213 for information on the future of this
+ class, where it will be broken out into more collection-specific
+ subtypes.
+ """
+
def __init__(self, attr, obj, data, init=True):
self.attr = attr
# this weakref is to prevent circular references between the parent object
@@ -321,9 +389,9 @@ class InstrumentedList(object):
self.__obj = weakref.ref(obj)
self.key = attr.key
self.data = data or attr._blank_list()
-
+
# adapt to lists or sets
- # TODO: make three subclasses of InstrumentedList that come off from a
+ # TODO: make three subclasses of InstrumentedList that come off from a
# metaclass, based on the type of data sent in
if hasattr(self.data, 'append'):
self._data_appender = self.data.append
@@ -335,60 +403,82 @@ class InstrumentedList(object):
raise exceptions.ArgumentError("Collection type " + repr(type(self.data)) + " has no append() or add() method")
if isinstance(self.data, dict):
self._clear_data = self._clear_dict
-
+
if init:
for x in self.data:
self.__setrecord(x)
def list_replaced(self):
- """fires off delete event handlers for each item in the list but
- doesnt affect the original data list"""
+ """Fire off delete event handlers for each item in the list
+ without affecting the original data list.
+ """
+
[self.__delrecord(x) for x in self.data]
def clear(self):
- """clears all items in this InstrumentedList and fires off delete event handlers for each item"""
+ """Clear all items in this InstrumentedList and fires off
+ delete event handlers for each item.
+ """
+
self._clear_data()
+
def _clear_dict(self):
[self.__delrecord(x) for x in self.data.values()]
self.data.clear()
+
def _clear_set(self):
[self.__delrecord(x) for x in self.data]
self.data.clear()
+
def _clear_list(self):
self[:] = []
-
+
def __getstate__(self):
- """implemented to allow pickling, since __obj is a weakref, also the InstrumentedAttribute has callables
- attached to it"""
+ """Implemented to allow pickling, since `__obj` is a weakref,
+ also the ``InstrumentedAttribute`` has callables attached to
+ it.
+ """
+
return {'key':self.key, 'obj':self.obj, 'data':self.data}
+
def __setstate__(self, d):
- """implemented to allow pickling, since __obj is a weakref, also the InstrumentedAttribute has callables
- attached to it"""
+ """Implemented to allow pickling, since `__obj` is a weakref,
+ also the ``InstrumentedAttribute`` has callables attached to it.
+ """
+
self.key = d['key']
self.__obj = weakref.ref(d['obj'])
self.data = d['data']
self.attr = getattr(d['obj'].__class__, self.key)
-
+
obj = property(lambda s:s.__obj())
def unchanged_items(self):
- """deprecated"""
+ """Deprecated."""
+
return self.attr.get_history(self.obj).unchanged_items
+
def added_items(self):
- """deprecated"""
+ """Deprecated."""
+
return self.attr.get_history(self.obj).added_items
+
def deleted_items(self):
- """deprecated"""
+ """Deprecated."""
+
return self.attr.get_history(self.obj).deleted_items
def __iter__(self):
return iter(self.data)
+
def __repr__(self):
return repr(self.data)
-
+
def __getattr__(self, attr):
- """proxies unknown methods and attributes to the underlying
- data array. this allows custom list classes to be used."""
+ """Proxie unknown methods and attributes to the underlying
+ data array. This allows custom list classes to be used.
+ """
+
return getattr(self.data, attr)
def __setrecord(self, item, event=None):
@@ -398,36 +488,41 @@ class InstrumentedList(object):
def __delrecord(self, item, event=None):
self.attr.remove_event(event, self.obj, item)
return True
-
+
def append_with_event(self, item, event):
self.__setrecord(item, event)
self._data_appender(item)
-
+
def append_without_event(self, item):
self._data_appender(item)
-
+
def remove_with_event(self, item, event):
self.__delrecord(item, event)
self.data.remove(item)
-
+
def append(self, item, _mapper_nohistory=False):
- """fires off dependent events, and appends the given item to the underlying list.
- _mapper_nohistory is a backwards compatibility hack; call append_without_event instead."""
+ """Fire off dependent events, and appends the given item to the underlying list.
+
+ `_mapper_nohistory` is a backwards compatibility hack; call
+ ``append_without_event`` instead.
+ """
+
if _mapper_nohistory:
self.append_without_event(item)
else:
self.__setrecord(item)
self._data_appender(item)
-
-
+
def __getitem__(self, i):
return self.data[i]
+
def __setitem__(self, i, item):
if isinstance(i, slice):
self.__setslice__(i.start, i.stop, item)
else:
self.__setrecord(item)
self.data[i] = item
+
def __delitem__(self, i):
if isinstance(i, slice):
self.__delslice__(i.start, i.stop)
@@ -436,65 +531,92 @@ class InstrumentedList(object):
del self.data[i]
def __lt__(self, other): return self.data < self.__cast(other)
+
def __le__(self, other): return self.data <= self.__cast(other)
+
def __eq__(self, other): return self.data == self.__cast(other)
+
def __ne__(self, other): return self.data != self.__cast(other)
+
def __gt__(self, other): return self.data > self.__cast(other)
+
def __ge__(self, other): return self.data >= self.__cast(other)
+
def __cast(self, other):
if isinstance(other, InstrumentedList): return other.data
else: return other
+
def __cmp__(self, other):
return cmp(self.data, self.__cast(other))
+
def __contains__(self, item): return item in self.data
+
def __len__(self): return len(self.data)
+
def __setslice__(self, i, j, other):
i = max(i, 0); j = max(j, 0)
[self.__delrecord(x) for x in self.data[i:]]
g = [a for a in list(other) if self.__setrecord(a)]
self.data[i:] = g
+
def __delslice__(self, i, j):
i = max(i, 0); j = max(j, 0)
for a in self.data[i:j]:
self.__delrecord(a)
del self.data[i:j]
- def insert(self, i, item):
+
+ def insert(self, i, item):
if self.__setrecord(item):
self.data.insert(i, item)
+
def pop(self, i=-1):
item = self.data[i]
self.__delrecord(item)
return self.data.pop(i)
- def remove(self, item):
+
+ def remove(self, item):
self.__delrecord(item)
self.data.remove(item)
+
def extend(self, item_list):
for item in item_list:
- self.append(item)
+ self.append(item)
+
def __add__(self, other):
raise NotImplementedError()
+
def __radd__(self, other):
raise NotImplementedError()
+
def __iadd__(self, other):
raise NotImplementedError()
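
The warning above about set-like semantics is easy to demonstrate; the class and attribute names in this sketch are hypothetical:

    from sqlalchemy import attributes

    class Order(object): pass

    manager = attributes.AttributeManager()
    manager.register_attribute(Order, 'items', uselist=True)

    o = Order()
    o.items.append('widget')
    o.items.append('widget')    # no deduplication is performed
    print len(o.items)          # 2 -- duplicates are kept, history and all
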
class AttributeExtension(object):
- """an abstract class which specifies "append", "delete", and "set"
- event handlers to be attached to an object property."""
+ """An abstract class which specifies `append`, `delete`, and `set`
+ event handlers to be attached to an object property.
+ """
+
def append(self, event, obj, child):
pass
+
def delete(self, event, obj, child):
pass
+
def set(self, event, obj, child, oldchild):
pass
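
As a sketch of these three hook points, a subclass that merely records events (the `sqlalchemy.attributes` import path is an assumption):

    from sqlalchemy import attributes

    class LoggingExtension(attributes.AttributeExtension):
        """Record every append/delete/set event on an attribute."""

        def __init__(self):
            self.events = []

        def append(self, event, obj, child):
            self.events.append(('append', child))

        def delete(self, event, obj, child):
            self.events.append(('delete', child))

        def set(self, event, obj, child, oldchild):
            self.events.append(('set', oldchild, child))
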
-
+
class GenericBackrefExtension(AttributeExtension):
- """an extension which synchronizes a two-way relationship. A typical two-way
- relationship is a parent object containing a list of child objects, where each
- child object references the parent. The other are two objects which contain
- scalar references to each other."""
+ """An extension which synchronizes a two-way relationship.
+
+ A typical two-way relationship is a parent object containing a
+ list of child objects, where each child object references the
+ parent. Another is two objects which contain scalar references
+ to each other.
+ """
+
def __init__(self, key):
self.key = key
+
def set(self, event, obj, child, oldchild):
if oldchild is child:
return
@@ -502,25 +624,32 @@ class GenericBackrefExtension(AttributeExtension):
getattr(oldchild.__class__, self.key).remove(event, oldchild, obj)
if child is not None:
getattr(child.__class__, self.key).append(event, child, obj)
+
def append(self, event, obj, child):
getattr(child.__class__, self.key).append(event, child, obj)
+
def delete(self, event, obj, child):
getattr(child.__class__, self.key).remove(event, child, obj)
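
A hedged sketch of the scalar/scalar case from the docstring; passing the extension through `register_attribute`'s keyword arguments is an assumption based on the kwargs passthrough, not something shown here:

    from sqlalchemy import attributes

    class Left(object): pass
    class Right(object): pass

    manager = attributes.AttributeManager()
    manager.register_attribute(Left, 'right', uselist=False,
                               extension=attributes.GenericBackrefExtension('left'))
    manager.register_attribute(Right, 'left', uselist=False,
                               extension=attributes.GenericBackrefExtension('right'))

    l, r = Left(), Right()
    l.right = r             # fires the extension, which appends on the other side
    print r.left is l       # True -- the two scalars stay in sync
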
class CommittedState(object):
- """stores the original state of an object when the commit() method on the attribute manager
- is called."""
+ """Store the original state of an object when the ``commit()`
+ method on the attribute manager is called.
+ """
+
NO_VALUE = object()
-
+
def __init__(self, manager, obj):
self.data = {}
for attr in manager.managed_attributes(obj.__class__):
self.commit_attribute(attr, obj)
def commit_attribute(self, attr, obj, value=NO_VALUE):
- """establish the value of attribute 'attr' on instance 'obj' as "committed".
-
- this corresponds to a previously saved state being restored. """
+ """Establish the value of attribute `attr` on instance `obj`
+ as *committed*.
+
+ This corresponds to a previously saved state being restored.
+ """
+
if value is CommittedState.NO_VALUE:
if obj.__dict__.has_key(attr.key):
value = obj.__dict__[attr.key]
@@ -545,18 +674,21 @@ class CommittedState(object):
obj.__dict__[attr.key] = self.data[attr.key]
else:
del obj.__dict__[attr.key]
-
+
def __repr__(self):
return "CommittedState: %s" % repr(self.data)
class AttributeHistory(object):
- """calculates the "history" of a particular attribute on a particular instance, based on the CommittedState
- associated with the instance, if any."""
+ """Calculate the *history* of a particular attribute on a
+ particular instance, based on the ``CommittedState`` associated
+ with the instance, if any.
+ """
+
def __init__(self, attr, obj, current, passive=False):
self.attr = attr
-
+
# get the "original" value. if a lazy load was fired when we got
- # the 'current' value, this "original" was also populated just
+ # the 'current' value, this "original" was also populated just
# now as well (therefore we have to get it second)
orig = obj._state.get('original', None)
if orig is not None:
@@ -581,7 +713,7 @@ class AttributeHistory(object):
self._added_items.append(a)
for a in s:
if a not in self._unchanged_items:
- self._deleted_items.append(a)
+ self._deleted_items.append(a)
else:
if attr.is_equal(current, original):
self._unchanged_items = [current]
@@ -595,26 +727,41 @@ class AttributeHistory(object):
self._deleted_items = []
self._unchanged_items = []
#print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
+
def __iter__(self):
return iter(self._current)
+
def is_modified(self):
return len(self._deleted_items) > 0 or len(self._added_items) > 0
+
def added_items(self):
return self._added_items
+
def unchanged_items(self):
return self._unchanged_items
+
def deleted_items(self):
return self._deleted_items
+
def hasparent(self, obj):
- """deprecated. this should be called directly from the appropriate InstrumentedAttribute object."""
+ """Deprecated. This should be called directly from the
+ appropriate ``InstrumentedAttribute`` object.
+ """
+
return self.attr.hasparent(obj)
-
+
class AttributeManager(object):
- """allows the instrumentation of object attributes. AttributeManager is stateless, but can be
- overridden by subclasses to redefine some of its factory operations. Also be aware AttributeManager
- will cache attributes for a given class, allowing not to determine those for each objects (used
- in managed_attributes() and noninherited_managed_attributes()). This cache is cleared for a given class
- while calling register_attribute(), and can be cleared using clear_attribute_cache()"""
+ """Allow the instrumentation of object attributes.
+
+ ``AttributeManager`` is stateless, but can be overridden by
+ subclasses to redefine some of its factory operations. Also be
+ aware that ``AttributeManager`` caches attributes for a given
+ class, avoiding recomputing them for each object (used in
+ ``managed_attributes()`` and
+ ``noninherited_managed_attributes()``). This cache is cleared for
+ a given class while calling ``register_attribute()``, and can be
+ cleared using ``clear_attribute_cache()``.
+ """
def __init__(self):
# will cache attributes, indexed by class objects
@@ -623,35 +770,45 @@ class AttributeManager(object):
def clear_attribute_cache(self):
self._attribute_cache.clear()
-
+
def rollback(self, *obj):
- """retrieves the committed history for each object in the given list, and rolls back the attributes
- each instance to their original value."""
+ """Retrieve the committed history for each object in the given
+ list, and roll back the attributes of each instance to their
+ original value.
+ """
+
for o in obj:
orig = o._state.get('original')
if orig is not None:
orig.rollback(self, o)
else:
self._clear(o)
-
+
def _clear(self, obj):
for attr in self.managed_attributes(obj.__class__):
try:
del obj.__dict__[attr.key]
except KeyError:
pass
-
+
def commit(self, *obj):
- """creates a CommittedState instance for each object in the given list, representing
- its "unchanged" state, and associates it with the instance. AttributeHistory objects
- will indicate the modified state of instance attributes as compared to its value in this
- CommittedState object."""
+ """Create a ``CommittedState`` instance for each object in the given list, representing
+ its *unchanged* state, and associate it with the instance.
+
+ ``AttributeHistory`` objects will indicate the modified state of
+ instance attributes as compared to its value in this
+ ``CommittedState`` object.
+ """
+
for o in obj:
o._state['original'] = CommittedState(self, o)
o._state['modified'] = False
def managed_attributes(self, class_):
- """returns an iterator of all InstrumentedAttribute objects associated with the given class."""
+ """Return an iterator of all ``InstrumentedAttribute`` objects
+ associated with the given class.
+ """
+
try:
return self._inherited_attribute_cache[class_]
except KeyError:
@@ -676,23 +833,36 @@ class AttributeManager(object):
if attr.check_mutable_modified(object):
return True
return object._state.get('modified', False)
-
+
def init_attr(self, obj):
- """sets up the __sa_attr_state dictionary on the given instance. This dictionary is
- automatically created when the '_state' attribute of the class is first accessed, but calling
- it here will save a single throw of an AttributeError that occurs in that creation step."""
+ """Sets up the __sa_attr_state dictionary on the given instance.
+
+ This dictionary is automatically created when the `_state`
+ attribute of the class is first accessed, but calling it here
+ will save a single throw of an ``AttributeError`` that occurs
+ in that creation step.
+ """
+
setattr(obj, '_%s__sa_attr_state' % obj.__class__.__name__, {})
def get_history(self, obj, key, **kwargs):
- """returns a new AttributeHistory object for the given attribute on the given object."""
+ """Return a new ``AttributeHistory`` object for the given
+ attribute on the given object.
+ """
+
return getattr(obj.__class__, key).get_history(obj, **kwargs)
def get_as_list(self, obj, key, passive=False):
- """returns an attribute of the given name from the given object. if the attribute
- is a scalar, returns it as a single-item list, otherwise returns the list based attribute.
- if the attribute's value is to be produced by an unexecuted callable,
- the callable will only be executed if the given 'passive' flag is False.
+ """Return an attribute of the given name from the given object.
+
+ If the attribute is a scalar, return it as a single-item list,
+ otherwise return the list-based attribute.
+
+ If the attribute's value is to be produced by an unexecuted
+ callable, the callable will only be executed if the given
+ `passive` flag is False.
"""
+
attr = getattr(obj.__class__, key)
x = attr.get(obj, passive=passive)
if x is InstrumentedAttribute.PASSIVE_NORESULT:
@@ -701,11 +871,14 @@ class AttributeManager(object):
return x
else:
return [x]
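
A small hypothetical illustration of the scalar/list normalization performed by ``get_as_list()``:

    from sqlalchemy import attributes

    class User(object): pass

    manager = attributes.AttributeManager()
    manager.register_attribute(User, 'name', uselist=False)
    manager.register_attribute(User, 'addresses', uselist=True)

    u = User()
    u.name = 'ed'
    u.addresses.append('ed@example.com')

    print manager.get_as_list(u, 'name')        # ['ed'] -- scalar wrapped
    print manager.get_as_list(u, 'addresses')   # ['ed@example.com']
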
-
+
def trigger_history(self, obj, callable):
- """clears all managed object attributes and places the given callable
- as an attribute-wide "trigger", which will execute upon the next attribute access, after
- which the trigger is removed."""
+ """Clear all managed object attributes and places the given
+ `callable` as an attribute-wide *trigger*, which will execute
+ upon the next attribute access, after which the trigger is
+ removed.
+ """
+
self._clear(obj)
try:
del obj._state['original']
@@ -714,49 +887,73 @@ class AttributeManager(object):
obj._state['trigger'] = callable
def untrigger_history(self, obj):
- """removes a trigger function set by trigger_history. does not restore the previous state of the object."""
+ """Remove a trigger function set by trigger_history.
+
+ Does not restore the previous state of the object.
+ """
+
del obj._state['trigger']
-
+
def has_trigger(self, obj):
- """returns True if the given object has a trigger function set by trigger_history()."""
+ """Return True if the given object has a trigger function set
+ by ``trigger_history()``.
+ """
+
return obj._state.has_key('trigger')
-
+
def reset_instance_attribute(self, obj, key):
- """removes any per-instance callable functions corresponding to given attribute key
- from the given object, and removes this attribute from the given object's dictionary."""
+ """Remove any per-instance callable functions corresponding to
+ the given attribute `key` from the given object, and remove this
+ attribute from the given object's dictionary.
+ """
+
attr = getattr(obj.__class__, key)
attr.reset(obj)
-
+
def reset_class_managed(self, class_):
- """removes all InstrumentedAttribute property objects from the given class."""
+ """Remove all ``InstrumentedAttribute`` property objects from
+ the given class.
+ """
+
for attr in self.noninherited_managed_attributes(class_):
delattr(class_, attr.key)
self._inherited_attribute_cache.pop(class_,None)
self._noninherited_attribute_cache.pop(class_,None)
-
+
def is_class_managed(self, class_, key):
- """returns True if the given key correponds to an instrumented property on the given class."""
+ """Return True if the given `key` correponds to an
+ instrumented property on the given class.
+ """
+
return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute)
def init_instance_attribute(self, obj, key, uselist, callable_=None, **kwargs):
- """initializes an attribute on an instance to either a blank value, cancelling
- out any class- or instance-level callables that were present, or if a callable
- is supplied sets the callable to be invoked when the attribute is next accessed."""
+ """Initialize an attribute on an instance to either a blank
+ value, cancelling out any class- or instance-level callables
+ that were present, or, if a `callable_` is supplied, set the
+ callable to be invoked when the attribute is next accessed.
+ """
+
getattr(obj.__class__, key).set_callable(obj, callable_)
-
+
def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs):
- """creates a scalar property object, defaulting to InstrumentedAttribute, which
- will communicate change events back to this AttributeManager."""
+ """Create a scalar property object, defaulting to
+ ``InstrumentedAttribute``, which will communicate change
+ events back to this ``AttributeManager``.
+ """
+
return InstrumentedAttribute(self, key, uselist, callable_, typecallable, **kwargs)
-
+
def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
- """registers an attribute at the class level to be instrumented for all instances
- of the class."""
- # firt invalidate the cache for the given class
+ """Register an attribute at the class level to be instrumented
+ for all instances of the class.
+ """
+
+ # first invalidate the cache for the given class
# (will be reconstituted as needed, while getting managed attributes)
self._inherited_attribute_cache.pop(class_,None)
self._noninherited_attribute_cache.pop(class_,None)
-
+
#print self, "register attribute", key, "for class", class_
if not hasattr(class_, '_state'):
def _get_state(self):
@@ -764,9 +961,8 @@ class AttributeManager(object):
self._sa_attr_state = {}
return self._sa_attr_state
class_._state = property(_get_state)
-
+
typecallable = kwargs.pop('typecallable', None)
if isinstance(typecallable, InstrumentedAttribute):
typecallable = None
setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs))
-
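
Tying the ``AttributeManager`` docstrings above together, a hypothetical end-to-end sketch of the commit/history/rollback cycle (import path and class name assumed):

    from sqlalchemy import attributes

    class User(object): pass

    manager = attributes.AttributeManager()
    manager.register_attribute(User, 'name', uselist=False)

    u = User()
    u.name = 'ed'
    manager.commit(u)               # snapshot becomes the CommittedState

    u.name = 'edward'
    hist = manager.get_history(u, 'name')
    print hist.added_items()        # ['edward']
    print hist.deleted_items()      # ['ed']
    print hist.is_modified()        # True

    manager.rollback(u)             # restore the committed snapshot
    print u.name                    # 'ed'
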
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index a8d8ad507b..5a756f5d27 100644
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -5,8 +5,10 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""bridges the PropertyLoader (i.e. a relation()) and the UOWTransaction
-together to allow processing of scalar- and list-based dependencies at flush time."""
+"""Bridge the ``PropertyLoader`` (i.e. a ``relation()``) and the
+``UOWTransaction`` together to allow processing of scalar- and
+list-based dependencies at flush time.
+"""
from sqlalchemy.orm import sync
from sqlalchemy.orm.sync import ONETOMANY,MANYTOONE,MANYTOMANY
@@ -42,74 +44,109 @@ class DependencyProcessor(object):
self._compile_synchronizers()
def _get_instrumented_attribute(self):
- """return the InstrumentedAttribute handled by this DependencyProecssor"""
+ """Return the ``InstrumentedAttribute`` handled by this
+ ``DependencyProcessor``.
+ """
+
return getattr(self.parent.class_, self.key)
-
+
def register_dependencies(self, uowcommit):
- """tells a UOWTransaction what mappers are dependent on which, with regards
- to the two or three mappers handled by this PropertyLoader.
+ """Tell a ``UOWTransaction`` what mappers are dependent on
+ which, with regards to the two or three mappers handled by
+ this ``PropertyLoader``.
+
+ Also register itself as a *processor* for one of its mappers,
+ which will be executed after that mapper's objects have been
+ saved or before they've been deleted. The process operation
+ manages attributes and dependent operations upon the objects
+ of one of the involved mappers.
+ """
- Also registers itself as a "processor" for one of its mappers, which
- will be executed after that mapper's objects have been saved or before
- they've been deleted. The process operation manages attributes and dependent
- operations upon the objects of one of the involved mappers."""
raise NotImplementedError()
def whose_dependent_on_who(self, obj1, obj2):
- """given an object pair assuming obj2 is a child of obj1, returns a tuple
- with the dependent object second, or None if they are equal.
- used by objectstore's object-level topological sort (i.e. cyclical
- table dependency)."""
+ """Given an object pair assuming `obj2` is a child of `obj1`,
+ return a tuple with the dependent object second, or None if
+ they are equal.
+
+ Used by objectstore's object-level topological sort (i.e. cyclical
+ table dependency).
+ """
+
if obj1 is obj2:
return None
elif self.direction == ONETOMANY:
return (obj1, obj2)
else:
return (obj2, obj1)
-
+
def process_dependencies(self, task, deplist, uowcommit, delete = False):
- """this method is called during a flush operation to synchronize data between a parent and child object.
- it is called within the context of the various mappers and sometimes individual objects sorted according to their
- insert/update/delete order (topological sort)."""
+ """This method is called during a flush operation to
+ synchronize data between a parent and child object.
+
+ It is called within the context of the various mappers and
+ sometimes individual objects sorted according to their
+ insert/update/delete order (topological sort).
+ """
+
raise NotImplementedError()
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
- """used before the flushes' topological sort to traverse through related objects and ensure every
- instance which will require save/update/delete is properly added to the UOWTransaction."""
+ """Used before the flushes' topological sort to traverse
+ through related objects and ensure every instance which will
+ require save/update/delete is properly added to the
+ ``UOWTransaction``.
+ """
+
raise NotImplementedError()
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
- """called during a flush to synchronize primary key identifier values between a parent/child object, as well as
- to an associationrow in the case of many-to-many."""
+ """Called during a flush to synchronize primary key identifier
+ values between a parent/child object, as well as to an
+ associationrow in the case of many-to-many.
+ """
+
raise NotImplementedError()
def _compile_synchronizers(self):
- """assembles a list of 'synchronization rules', which are instructions on how to populate
- the objects on each side of a relationship. This is done when a DependencyProcessor is
+ """Assemble a list of *synchronization rules*, which are
+ instructions on how to populate the objects on each side of a
+ relationship. This is done when a ``DependencyProcessor`` is
first initialized.
- The list of rules is used within commits by the _synchronize() method when dependent
- objects are processed."""
+ The list of rules is used within commits by the ``_synchronize()``
+ method when dependent objects are processed.
+ """
+
self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
if self.direction == sync.MANYTOMANY:
self.syncrules.compile(self.prop.primaryjoin, issecondary=False, foreign_keys=self.foreign_keys)
self.syncrules.compile(self.prop.secondaryjoin, issecondary=True, foreign_keys=self.foreign_keys)
else:
self.syncrules.compile(self.prop.primaryjoin, foreign_keys=self.foreign_keys)
-
+
def get_object_dependencies(self, obj, uowcommit, passive = True):
- """returns the list of objects that are dependent on the given object, as according to the relationship
- this dependency processor represents"""
+ """Return the list of objects that are dependent on the given
+ object, according to the relationship this dependency
+ processor represents.
+ """
+
return sessionlib.attribute_manager.get_history(obj, self.key, passive = passive)
def _conditional_post_update(self, obj, uowcommit, related):
- """execute a post_update call.
-
- for relations that contain the post_update flag, an additional UPDATE statement may be
- associated after an INSERT or before a DELETE in order to resolve circular row dependencies.
- This method will check for the post_update flag being set on a particular relationship, and
- given a target object and list of one or more related objects, and execute the UPDATE if the
- given related object list contains INSERTs or DELETEs."""
+ """Execute a post_update call.
+
+ For relations that contain the post_update flag, an additional
+ ``UPDATE`` statement may be associated after an ``INSERT`` or
+ before a ``DELETE`` in order to resolve circular row
+ dependencies.
+
+ This method will check for the post_update flag being set on a
+ particular relationship and, given a target object and a list of
+ one or more related objects, execute the ``UPDATE`` if the
+ given related object list contains ``INSERT``s or ``DELETE``s.
+ """
+
if obj is not None and self.post_update:
for x in related:
if x is not None:
@@ -127,6 +164,7 @@ class OneToManyDP(DependencyProcessor):
else:
uowcommit.register_dependency(self.parent, self.mapper)
uowcommit.register_processor(self.parent, self, self.parent)
+
def process_dependencies(self, task, deplist, uowcommit, delete = False):
#print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
if delete:
@@ -201,14 +239,14 @@ class OneToManyDP(DependencyProcessor):
uowcommit.register_object(child, isdelete=True)
for c in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
-
+
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
source = obj
dest = child
if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
return
self.syncrules.execute(source, dest, obj, child, clearkeys)
-
+
class ManyToOneDP(DependencyProcessor):
def register_dependencies(self, uowcommit):
if self.post_update:
@@ -220,6 +258,7 @@ class ManyToOneDP(DependencyProcessor):
else:
uowcommit.register_dependency(self.mapper, self.parent)
uowcommit.register_processor(self.mapper, self, self.parent)
+
def process_dependencies(self, task, deplist, uowcommit, delete = False):
#print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
if delete:
@@ -238,6 +277,7 @@ class ManyToOneDP(DependencyProcessor):
for child in childlist.added_items():
self._synchronize(obj, child, None, False, uowcommit)
self._conditional_post_update(obj, uowcommit, childlist.deleted_items() + childlist.unchanged_items() + childlist.added_items())
+
def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
#print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " PRE process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
if self.post_update:
@@ -263,7 +303,7 @@ class ManyToOneDP(DependencyProcessor):
uowcommit.register_object(child, isdelete=True)
for c in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
-
+
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
source = child
dest = obj
@@ -280,14 +320,14 @@ class ManyToManyDP(DependencyProcessor):
# association table.
if self.is_backref:
- # if we are the "backref" half of a two-way backref
+ # if we are the "backref" half of a two-way backref
# relationship, let the other mapper handle inserting the rows
return
stub = MapperStub(self.parent, self.mapper, self.key)
uowcommit.register_dependency(self.parent, stub)
uowcommit.register_dependency(self.mapper, stub)
uowcommit.register_processor(stub, self, self.parent)
-
+
def process_dependencies(self, task, deplist, uowcommit, delete = False):
#print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
connection = uowcommit.transaction.connection(self.mapper)
@@ -333,6 +373,7 @@ class ManyToManyDP(DependencyProcessor):
uowcommit.register_object(child, isdelete=True)
for c in self.mapper.cascade_iterator('delete', child):
uowcommit.register_object(c, isdelete=True)
+
def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
dest = associationrow
source = None
@@ -347,23 +388,31 @@ class AssociationDP(OneToManyDP):
self.cascade.delete_orphan = True
class MapperStub(object):
- """poses as a Mapper representing the association table in a many-to-many
- join, when performing a flush().
+ """Pose as a Mapper representing the association table in a
+ many-to-many join, when performing a ``flush()``.
+
+ The ``Task`` objects in the objectstore module treat it just like
+ any other ``Mapper``, but in fact it only serves as a *dependency*
+ placeholder for the many-to-many update task.
+ """
- The Task objects in the objectstore module treat it just like
- any other Mapper, but in fact it only serves as a "dependency" placeholder
- for the many-to-many update task."""
__metaclass__ = util.ArgSingleton
+
def __init__(self, parent, mapper, key):
self.mapper = mapper
self._inheriting_mappers = []
+
def register_dependencies(self, uowcommit):
pass
+
def save_obj(self, *args, **kwargs):
pass
+
def delete_obj(self, *args, **kwargs):
pass
+
def primary_mapper(self):
return self
+
def base_mapper(self):
- return self
\ No newline at end of file
+ return self
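
The ``_conditional_post_update`` docstring earlier in this file describes when the extra ``UPDATE`` fires; here is a hedged sketch of the sort of mutually-dependent mapping that needs it, using assumed table and class names and the 0.3-era flat namespace:

    from sqlalchemy import *

    meta = MetaData()
    users = Table('users', meta,
        Column('id', Integer, primary_key=True),
        Column('favorite_address_id', Integer, ForeignKey('addresses.id')))
    addresses = Table('addresses', meta,
        Column('id', Integer, primary_key=True),
        Column('user_id', Integer, ForeignKey('users.id')))

    class User(object): pass
    class Address(object): pass

    mapper(Address, addresses)
    mapper(User, users, properties={
        'addresses': relation(Address,
            primaryjoin=users.c.id == addresses.c.user_id),
        # post_update issues the favorite_address_id assignment as a
        # second UPDATE after both rows exist, breaking the circular
        # row dependency between the two tables
        'favorite': relation(Address,
            primaryjoin=users.c.favorite_address_id == addresses.c.id,
            post_update=True),
    })
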
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 0327205ce9..91f58e833e 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -8,65 +8,106 @@
from sqlalchemy import util, logging
class MapperProperty(object):
- """manages the relationship of a Mapper to a single class attribute, as well
- as that attribute as it appears on individual instances of the class, including
- attribute instrumentation, attribute access, loading behavior, and dependency calculations."""
+ """Manage the relationship of a ``Mapper`` to a single class
+ attribute, as well as that attribute as it appears on individual
+ instances of the class, including attribute instrumentation,
+ attribute access, loading behavior, and dependency calculations.
+ """
+
def setup(self, querycontext, **kwargs):
- """called when a statement is being constructed. """
+ """Called when a statement is being constructed."""
+
pass
+
def execute(self, selectcontext, instance, row, identitykey, isnew):
- """called when the mapper receives a row. instance is the parent instance
- corresponding to the row. """
+ """Called when the mapper receives a row.
+
+ `instance` is the parent instance corresponding to the `row`.
+ """
+
raise NotImplementedError()
+
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
return []
+
def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
return []
+
def get_criterion(self, query, key, value):
- """Returns a WHERE clause suitable for this MapperProperty corresponding to the
- given key/value pair, where the key is a column or object property name, and value
- is a value to be matched. This is only picked up by PropertyLoaders.
-
- this is called by a Query's join_by method to formulate a set of key/value pairs into
- a WHERE criterion that spans multiple tables if needed."""
+ """Return a ``WHERE`` clause suitable for this
+ ``MapperProperty`` corresponding to the given key/value pair,
+ where the key is a column or object property name, and value
+ is a value to be matched. This is only picked up by
+ ``PropertyLoaders``.
+
+ This is called by a ``Query``'s ``join_by`` method to
+ formulate a set of key/value pairs into a ``WHERE`` criterion
+ that spans multiple tables if needed.
+ """
+
return None
+
def set_parent(self, parent):
self.parent = parent
+
def init(self, key, parent):
- """called after all mappers are compiled to assemble relationships between
- mappers, establish instrumented class attributes"""
+ """Called after all mappers are compiled to assemble
+ relationships between mappers and establish instrumented class
+ attributes.
+ """
+
self.key = key
self.do_init()
+
def do_init(self):
- """template method for subclasses"""
+ """Template method for subclasses."""
pass
+
def register_dependencies(self, *args, **kwargs):
- """called by the Mapper in response to the UnitOfWork calling the Mapper's
- register_dependencies operation. Should register with the UnitOfWork all
- inter-mapper dependencies as well as dependency processors (see UOW docs for more details)"""
+ """Called by the ``Mapper`` in response to the UnitOfWork
+ calling the ``Mapper``'s register_dependencies operation.
+ Should register with the UnitOfWork all inter-mapper
+ dependencies as well as dependency processors (see UOW docs
+ for more details).
+ """
+
pass
def is_primary(self):
- """return True if this MapperProperty's mapper is the primary mapper for its class.
-
- This flag is used to indicate that the MapperProperty can define attribute instrumentation
- for the class at the class level (as opposed to the individual instance level.)"""
+ """Return True if this ``MapperProperty``'s mapper is the
+ primary mapper for its class.
+
+ This flag is used to indicate that the ``MapperProperty`` can
+ define attribute instrumentation for the class at the class
+ level (as opposed to the individual instance level).
+ """
+
return self.parent._is_primary_mapper()
+
def merge(self, session, source, dest):
- """merges the attribute represented by this MapperProperty from source to destination object"""
+ """Merge the attribute represented by this ``MapperProperty``
+ from source to destination object."""
+
raise NotImplementedError()
+
def compare(self, value):
- """returns a compare operation for the columns represented by this MapperProperty to the given value,
- which may be a column value or an instance."""
+ """Return a compare operation for the columns represented by
+ this ``MapperProperty`` to the given value, which may be a
+ column value or an instance.
+ """
+
raise NotImplementedError()
class SynonymProperty(MapperProperty):
def __init__(self, name, proxy=False):
self.name = name
self.proxy = proxy
+
def setup(self, querycontext, **kwargs):
pass
+
def execute(self, selectcontext, instance, row, identitykey, isnew):
pass
+
def do_init(self):
if not self.proxy:
return
@@ -80,13 +121,19 @@ class SynonymProperty(MapperProperty):
return s
return getattr(obj, self.name)
setattr(self.parent.class_, self.key, SynonymProp())
+
def merge(self, session, source, dest, _recursive):
pass
class StrategizedProperty(MapperProperty):
- """a MapperProperty which uses selectable strategies to affect loading behavior.
- There is a single default strategy selected, and alternate strategies can be selected
- at selection time through the usage of StrategizedOption objects."""
+ """A MapperProperty which uses selectable strategies to affect
+ loading behavior.
+
+ There is a single default strategy selected, and alternate
+ strategies can be selected at selection time through the usage of
+ ``StrategizedOption`` objects.
+ """
+
def _get_context_strategy(self, context):
try:
return context.attributes[id(self)]
@@ -95,6 +142,7 @@ class StrategizedProperty(MapperProperty):
ctx_strategy = self._get_strategy(context.attributes.get((LoaderStrategy, self), self.strategy.__class__))
context.attributes[id(self)] = ctx_strategy
return ctx_strategy
+
def _get_strategy(self, cls):
try:
return self._all_strategies[cls]
@@ -105,10 +153,13 @@ class StrategizedProperty(MapperProperty):
strategy.is_default = False
self._all_strategies[cls] = strategy
return strategy
+
def setup(self, querycontext, **kwargs):
self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs)
+
def execute(self, selectcontext, instance, row, identitykey, isnew):
self._get_context_strategy(selectcontext).process_row(selectcontext, instance, row, identitykey, isnew)
+
def do_init(self):
self._all_strategies = {}
self.strategy = self.create_strategy()
@@ -118,8 +169,12 @@ class StrategizedProperty(MapperProperty):
self.strategy.init_class_attribute()
class OperationContext(object):
- """serves as a context during a query construction or instance loading operation.
- accepts MapperOption objects which may modify its state before proceeding."""
+ """Serve as a context during a query construction or instance
+ loading operation.
+
+ Accept ``MapperOption`` objects which may modify its state before proceeding.
+ """
+
def __init__(self, mapper, options):
self.mapper = mapper
self.options = options
@@ -127,31 +182,42 @@ class OperationContext(object):
self.recursion_stack = util.Set()
for opt in util.flatten_iterator(options):
self.accept_option(opt)
+
def accept_option(self, opt):
pass
class MapperOption(object):
- """describes a modification to an OperationContext."""
+ """Describe a modification to an OperationContext."""
+
def process_query_context(self, context):
pass
+
def process_selection_context(self, context):
pass
+
def process_query(self, query):
pass
-
+
class PropertyOption(MapperOption):
- """a MapperOption that is applied to a property off the mapper
- or one of its child mappers, identified by a dot-separated key."""
+ """A MapperOption that is applied to a property off the mapper or
+ one of its child mappers, identified by a dot-separated key.
+ """
+
def __init__(self, key):
self.key = key
+
def process_query_property(self, context, property):
pass
+
def process_selection_property(self, context, property):
pass
+
def process_query_context(self, context):
self.process_query_property(context, self._get_property(context))
+
def process_selection_context(self, context):
self.process_selection_property(context, self._get_property(context))
+
def _get_property(self, context):
try:
prop = self.__prop
@@ -164,47 +230,64 @@ class PropertyOption(MapperOption):
mapper = getattr(prop, 'mapper', None)
self.__prop = prop
return prop
+
PropertyOption.logger = logging.class_logger(PropertyOption)
class StrategizedOption(PropertyOption):
- """a MapperOption that affects which LoaderStrategy will be used for an operation
- by a StrategizedProperty."""
+ """A MapperOption that affects which LoaderStrategy will be used
+ for an operation by a StrategizedProperty.
+ """
+
def process_query_property(self, context, property):
self.logger.debug("applying option to QueryContext, property key '%s'" % self.key)
context.attributes[(LoaderStrategy, property)] = self.get_strategy_class()
+
def process_selection_property(self, context, property):
self.logger.debug("applying option to SelectionContext, property key '%s'" % self.key)
context.attributes[(LoaderStrategy, property)] = self.get_strategy_class()
+
def get_strategy_class(self):
raise NotImplementedError()
class LoaderStrategy(object):
- """describes the loading behavior of a StrategizedProperty object. The LoaderStrategy
- interacts with the querying process in three ways:
-
- * it controls the configuration of the InstrumentedAttribute placed on a class to
- handle the behavior of the attribute. this may involve setting up class-level callable
- functions to fire off a select operation when the attribute is first accessed (i.e. a lazy load)
-
- * it processes the QueryContext at statement construction time, where it can modify the SQL statement
- that is being produced. simple column attributes may add their represented column to the list of
- selected columns, "eager loading" properties may add LEFT OUTER JOIN clauses to the statement.
-
- * it processes the SelectionContext at row-processing time. This may involve setting instance-level
- lazyloader functions on newly constructed instances, or may involve recursively appending child items
- to a list in response to additionally eager-loaded objects in the query.
+ """Describe the loading behavior of a StrategizedProperty object.
+
+ The ``LoaderStrategy`` interacts with the querying process in three
+ ways:
+
+ * it controls the configuration of the ``InstrumentedAttribute``
+ placed on a class to handle the behavior of the attribute. This
+ may involve setting up class-level callable functions to fire
+ off a select operation when the attribute is first accessed
+ (i.e. a lazy load).
+
+ * it processes the ``QueryContext`` at statement construction time,
+ where it can modify the SQL statement that is being produced.
+ Simple column attributes may add their represented column to the
+ list of selected columns, *eager loading* properties may add
+ ``LEFT OUTER JOIN`` clauses to the statement.
+
+ * it processes the ``SelectionContext`` at row-processing time. This
+ may involve setting instance-level lazyloader functions on newly
+ constructed instances, or may involve recursively appending
+ child items to a list in response to additionally eager-loaded
+ objects in the query.
"""
+
def __init__(self, parent):
self.parent_property = parent
self.is_default = True
+
def init(self):
self.parent = self.parent_property.parent
self.key = self.parent_property.key
+
def init_class_attribute(self):
pass
+
def setup_query(self, context, **kwargs):
pass
+
def process_row(self, selectcontext, instance, row, identitykey, isnew):
pass
-
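
To make the ``LoaderStrategy``/``StrategizedOption`` contract above concrete, a hypothetical no-op strategy and the option that selects it; the names are illustrative only:

    from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption

    class NoisyLoader(LoaderStrategy):
        """A do-nothing strategy that just reports each row it sees."""

        def process_row(self, selectcontext, instance, row, identitykey, isnew):
            print 'row for', instance, 'isnew:', isnew

    class UseNoisyLoader(StrategizedOption):
        """Select NoisyLoader for the property named by `key`."""

        def get_strategy_class(self):
            return NoisyLoader

    # hypothetically selected per-query, e.g.:
    #   query.options(UseNoisyLoader('addresses'))
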
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 9c48682141..7779b99ae1 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -26,22 +26,26 @@ NO_ATTRIBUTE = object()
# returned by a MapperExtension method to indicate a "do nothing" response
EXT_PASS = object()
-
+
# lock used to synchronize the "mapper compile" step
_COMPILE_MUTEX = util.threading.Lock()
-
+
class Mapper(object):
- """Defines the correlation of class attributes to database table columns.
-
- Instances of this class should be constructed via the sqlalchemy.orm.mapper() function."""
- def __init__(self,
- class_,
- local_table,
- properties = None,
- primary_key = None,
+ """Define the correlation of class attributes to database table
+ columns.
+
+ Instances of this class should be constructed via the
+ ``sqlalchemy.orm.mapper()`` function.
+ """
+
+ def __init__(self,
+ class_,
+ local_table,
+ properties = None,
+ primary_key = None,
non_primary = False,
- inherits = None,
- inherit_condition = None,
+ inherits = None,
+ inherit_condition = None,
extension = None,
order_by = False,
allow_column_override = False,
@@ -56,76 +60,130 @@ class Mapper(object):
allow_null_pks=False,
batch=True,
column_prefix=None):
- """construct a new mapper.
-
- All arguments may be sent to the sqlalchemy.orm.mapper() function where they are
- passed through to here.
-
- class_ - the class to be mapped.
-
- local_table - the table to which the class is mapped, or None if this mapper inherits
- from another mapper using concrete table inheritance.
-
- properties - a dictionary mapping the string names of object attributes to MapperProperty
- instances, which define the persistence behavior of that attribute. Note that the columns in the
- mapped table are automatically converted into ColumnProperty instances based on the "key"
- property of each Column (although they can be overridden using this dictionary).
-
- primary_key - a list of Column objects which define the "primary key" to be used against this mapper's
- selectable unit. This is normally simply the primary key of the "local_table", but can be overridden here.
-
- non_primary - construct a Mapper that will define only the selection of instances, not their persistence.
-
- inherits - another Mapper for which this Mapper will have an inheritance relationship with.
-
- inherit_condition - for joined table inheritance, a SQL expression (constructed ClauseElement) which
- will define how the two tables are joined; defaults to a natural join between the two tables.
-
- extension - a MapperExtension instance or list of MapperExtension instances which will be applied to
- all operations by this Mapper.
-
- order_by - a single Column or list of Columns for which selection operations should use as the default
- ordering for entities. Defaults to the OID/ROWID of the table if any, or the first primary key column of the table.
-
- allow_column_override - if True, allows the usage of a `relation()` which has the same name as a column in the mapped table.
- The table column will no longer be mapped.
-
- entity_name - a name to be associated with the class, to allow alternate mappings for a single class.
-
- always_refresh - if True, all query operations for this mapped class will overwrite all data
- within object instances that already exist within the session, erasing any in-memory changes with whatever
- information was loaded from the database.
-
- version_id_col - a Column which must have an integer type that will be used to keep a running "version id" of
- mapped entities in the database. this is used during save operations to ensure that no other thread or process
- has updated the instance during the lifetime of the entity, else a ConcurrentModificationError exception is thrown.
-
- polymorphic_on - used with mappers in an inheritance relationship, a Column which will identify the class/mapper
- combination to be used with a particular row. requires the polymorphic_identity value to be set for all mappers
- in the inheritance hierarchy.
-
- _polymorphic_map - used internally to propigate the full map of polymorphic identifiers to surrogate mappers.
-
- polymorphic_identity - a value which will be stored in the Column denoted by polymorphic_on, corresponding to the
- "class identity" of this mapper.
-
- concrete - if True, indicates this mapper should use concrete table inheritance with its parent mapper.
-
- select_table - a Table or (more commonly) Selectable which will be used to select instances of this mapper's class.
- usually used to provide polymorphic loading among several classes in an inheritance hierarchy.
-
- allow_null_pks - indicates that composite primary keys where one or more (but not all) columns contain NULL is a valid
- primary key. Primary keys which contain NULL values usually indicate that a result row does not contain an entity
- and should be skipped.
-
- batch - indicates that save operations of multiple entities can be batched together for efficiency.
- setting to False indicates that an instance will be fully saved before saving the next instance, which
- includes inserting/updating all table rows corresponding to the entity as well as calling all MapperExtension
- methods corresponding to the save operation.
-
- column_prefix - a string which will be prepended to the "key" name of all Columns when creating column-based
- properties from the given Table. does not affect explicitly specified column-based properties
+ """Construct a new mapper.
+
+ All arguments may be sent to the ``sqlalchemy.orm.mapper()``
+ function where they are passed through to here.
+
+ class_
+ The class to be mapped.
+
+ local_table
+ The table to which the class is mapped, or None if this
+ mapper inherits from another mapper using concrete table
+ inheritance.
+
+ properties
+ A dictionary mapping the string names of object attributes
+ to ``MapperProperty`` instances, which define the
+ persistence behavior of that attribute. Note that the
+ columns in the mapped table are automatically converted into
+ ``ColumnProperty`` instances based on the `key` property of
+ each ``Column`` (although they can be overridden using this
+ dictionary).
+
+ primary_key
+ A list of ``Column`` objects which define the *primary key*
+ to be used against this mapper's selectable unit. This is
+ normally simply the primary key of the `local_table`, but
+ can be overridden here.
+
+ non_primary
+ Construct a ``Mapper`` that will define only the selection
+ of instances, not their persistence.
+
+ inherits
+ Another ``Mapper`` with which this ``Mapper`` will have an
+ inheritance relationship.
+
+ inherit_condition
+ For joined table inheritance, a SQL expression (constructed
+ ``ClauseElement``) which will define how the two tables are
+ joined; defaults to a natural join between the two tables.
+
+ extension
+ A ``MapperExtension`` instance or list of
+ ``MapperExtension`` instances which will be applied to all
+ operations by this ``Mapper``.
+
+ order_by
+ A single ``Column`` or list of ``Columns`` which
+ selection operations should use as the default ordering for
+ entities. Defaults to the OID/ROWID of the table if any, or
+ the first primary key column of the table.
+
+ allow_column_override
+ If True, allows the usage of a ``relation()`` which has the
+ same name as a column in the mapped table. The table column
+ will no longer be mapped.
+
+ entity_name
+ A name to be associated with the `class`, to allow alternate
+ mappings for a single class.
+
+ always_refresh
+ If True, all query operations for this mapped class will
+ overwrite all data within object instances that already
+ exist within the session, erasing any in-memory changes with
+ whatever information was loaded from the database.
+
+ version_id_col
+ A ``Column`` which must have an integer type that will be
+ used to keep a running *version id* of mapped entities in
+ the database. This is used during save operations to ensure
+ that no other thread or process has updated the instance
+ during the lifetime of the entity, else a
+ ``ConcurrentModificationError`` exception is thrown.
+
+ polymorphic_on
+ Used with mappers in an inheritance relationship, a ``Column``
+ which will identify the class/mapper combination to be used
+ with a particular row. Requires the polymorphic_identity
+ value to be set for all mappers in the inheritance
+ hierarchy.
+
+ _polymorphic_map
+ Used internally to propagate the full map of polymorphic
+ identifiers to surrogate mappers.
+
+ polymorphic_identity
+ A value which will be stored in the Column denoted by
+ polymorphic_on, corresponding to the *class identity* of
+ this mapper.
+
+ concrete
+ If True, indicates this mapper should use concrete table
+ inheritance with its parent mapper.
+
+ select_table
+ A ``Table`` or (more commonly) ``Selectable`` which will be
+ used to select instances of this mapper's class. Usually
+ used to provide polymorphic loading among several classes in
+ an inheritance hierarchy.
+
+ allow_null_pks
+ Indicates that a composite primary key where one or more (but
+ not all) columns contain NULL is a valid primary key.
+ Primary keys which contain NULL values usually indicate that
+ a result row does not contain an entity and should be
+ skipped.
+
+ batch
+ Indicates that save operations of multiple entities can be
+ batched together for efficiency. Setting to False indicates
+ that an instance will be fully saved before saving the next
+ instance, which includes inserting/updating all table rows
+ corresponding to the entity as well as calling all
+ ``MapperExtension`` methods corresponding to the save
+ operation.
+
+ column_prefix
+ A string which will be prepended to the `key` name of all
+ Columns when creating column-based properties from the given
+ Table. Does not affect explicitly specified column-based
+ properties.
"""
+
if not issubclass(class_, object):
raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
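
As a usage sketch for the arguments documented above (the table, class
and column names here are hypothetical, and only a few of the keyword
arguments are shown):

    from sqlalchemy import Table, Column, Integer, String, MetaData
    from sqlalchemy.orm import mapper

    meta = MetaData()
    users = Table('users', meta,
        Column('user_id', Integer, primary_key=True),
        Column('user_name', String(40)))

    class User(object):
        pass

    # column_prefix='_' maps users.c.user_name to the attribute
    # User._user_name; always_refresh behaves as described above
    mapper(User, users, column_prefix='_', always_refresh=True)
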
@@ -133,7 +191,7 @@ class Mapper(object):
if table is not None and isinstance(table, sql._SelectBaseMixin):
# some db's, noteably postgres, dont want to select from a select
# without an alias. also if we make our own alias internally, then
- # the configured properties on the mapper are not matched against the alias
+ # the configured properties on the mapper are not matched against the alias
# we make, theres workarounds but it starts to get really crazy (its crazy enough
# the SQL that gets generated) so just require an alias
raise exceptions.ArgumentError("Mapping against a Select object requires that it has a name. Use an alias to give it a name, i.e. s = select(...).alias('myselect')")
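
The workaround named in this error message looks roughly like the
following sketch (the `users` table and `UserView` class are
hypothetical, continuing the names from the sketch above):

    # a bare select() has no name; alias() gives it one so that it
    # can be mapped like a table
    s = select([users]).alias('myselect')
    mapper(UserView, s)
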
@@ -159,16 +217,16 @@ class Mapper(object):
self.delete_orphans = []
self.batch = batch
self.column_prefix = column_prefix
- # a Column which is used during a select operation to retrieve the
+ # a Column which is used during a select operation to retrieve the
# "polymorphic identity" of the row, which indicates which Mapper should be used
# to construct a new object instance from that row.
self.polymorphic_on = polymorphic_on
self._eager_loaders = util.Set()
-
+
# our 'polymorphic identity', a string name that when located in a result set row
# indicates this Mapper should be used to construct the object instance for that row.
self.polymorphic_identity = polymorphic_identity
-
+
# a dictionary of 'polymorphic identity' names, associating those names with
# Mappers that will be used to construct object instances upon a select operation.
if _polymorphic_map is None:
@@ -183,44 +241,44 @@ class Mapper(object):
self.compile()
return s.__dict__['_data']
_data = property(_get_data)
-
+
self.columns = LOrderedProp()
self.c = self.columns
-
+
# each time the options() method is called, the resulting Mapper is
# stored in this dictionary based on the given options for fast re-access
self._options = {}
-
+
# a set of all mappers which inherit from this one.
self._inheriting_mappers = util.Set()
-
+
# a second mapper that is used for selecting, if the "select_table" argument
# was sent to this mapper.
self.__surrogate_mapper = None
-
+
# whether or not our compile() method has been called already.
self.__is_compiled = False
# if this mapper is to be a primary mapper (i.e. the non_primary flag is not set),
# associate this Mapper with the given class_ and entity name. subsequent
- # calls to class_mapper() for the class_/entity name combination will return this
+ # calls to class_mapper() for the class_/entity name combination will return this
# mapper.
self._compile_class()
self.__should_log_debug = logging.is_debug_enabled(self.logger)
self.__log("constructed")
-
+
# uncomment to compile at construction time (the old way)
# this will break mapper setups that arent declared in the order
# of dependency
#self.compile()
-
+
def __log(self, msg):
self.logger.info("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.name or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") + ") " + msg)
-
+
def __log_debug(self, msg):
self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.name or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "") + ") " + msg)
-
+
def _is_orphan(self, obj):
optimistic = has_identity(obj)
for (key,klass) in self.delete_orphans:
@@ -231,24 +289,28 @@ class Mapper(object):
if not has_identity(obj):
raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
(
- obj,
+ obj,
", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in self.delete_orphans])
))
else:
return True
else:
return False
-
+
def _get_props(self):
self.compile()
return self.__props
- props = property(_get_props, doc="compiles this mapper if needed, and returns the \
- dictionary of MapperProperty objects associated with this mapper.")
-
+
+ props = property(_get_props, doc="compiles this mapper if needed, and returns the "
+ "dictionary of MapperProperty objects associated with this mapper.")
+
def compile(self):
- """compile this mapper into its final internal format.
-
- this is the 'external' version of the method which is not reentrant."""
+ """Compile this mapper into its final internal format.
+
+ This is the *external* version of the method which is not
+ reentrant.
+ """
+
if self.__is_compiled:
return self
_COMPILE_MUTEX.acquire()
@@ -257,39 +319,40 @@ class Mapper(object):
if self.__is_compiled:
return self
self._compile_all()
-
+
# if we're not primary, compile us
if self.non_primary:
self._do_compile()
self._initialize_properties()
-
+
return self
finally:
_COMPILE_MUTEX.release()
-
+
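
As a brief usage sketch, mappers normally compile lazily on first use,
but the external compile() can be called once all mappers are defined:

    m = mapper(User, users)
    # resolves properties and inheritance now; returns the mapper itself
    m.compile()
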
def _compile_all(self):
# compile all primary mappers
for mapper in mapper_registry.values():
if not mapper.__is_compiled:
mapper._do_compile()
-
+
# initialize properties on all mappers
for mapper in mapper_registry.values():
if not mapper.__props_init:
mapper._initialize_properties()
-
+
def _check_compile(self):
if self.non_primary:
self._do_compile()
self._initialize_properties()
return self
-
+
def _do_compile(self):
- """compile this mapper into its final internal format.
-
- this is the 'internal' version of the method which is assumed to be called within compile()
- and is reentrant.
+ """Compile this mapper into its final internal format.
+
+ This is the *internal* version of the method which is assumed
+ to be called within compile() and is reentrant.
"""
+
if self.__is_compiled:
return self
self.__log("_do_compile() started")
@@ -302,10 +365,13 @@ class Mapper(object):
self._compile_selectable()
self.__log("_do_compile() complete")
return self
-
+
def _compile_extensions(self):
- """goes through the global_extensions list as well as the list of MapperExtensions
- specified for this Mapper and creates a linked list of those extensions."""
+ """Go through the global_extensions list as well as the list
+ of ``MapperExtensions`` specified for this ``Mapper`` and
+ create a linked list of those extensions.
+ """
+
extlist = util.Set()
for ext_class in global_extensions:
if isinstance(ext_class, MapperExtension):
@@ -321,12 +387,17 @@ class Mapper(object):
self.extension = _ExtensionCarrier()
for ext in extlist:
self.extension.append(ext)
-
+
def _compile_inheritance(self):
- """determines if this Mapper inherits from another mapper, and if so calculates the mapped_table
- for this Mapper taking the inherited mapper into account. for joined table inheritance, creates
- a SyncRule that will synchronize column values between the joined tables. also initializes polymorphic variables
- used in polymorphic loads."""
+ """Determine if this Mapper inherits from another mapper, and
+ if so calculate the mapped_table for this Mapper, taking the
+ inherited mapper into account.
+
+ For joined table inheritance, create a ``SyncRule`` that will
+ synchronize column values between the joined tables. Also
+ initialize polymorphic variables used in polymorphic loads.
+ """
+
if self.inherits is not None:
if isinstance(self.inherits, type):
self.inherits = class_mapper(self.inherits, compile=False)._do_compile()
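
A sketch of the joined table inheritance arrangement this method
reconciles, using the inherits, polymorphic_on and polymorphic_identity
arguments documented earlier (the schema is hypothetical; Table, Column
and ForeignKey as imported in the earlier sketch):

    employees = Table('employees', meta,
        Column('employee_id', Integer, primary_key=True),
        Column('type', String(20)))
    engineers = Table('engineers', meta,
        Column('employee_id', Integer,
               ForeignKey('employees.employee_id'), primary_key=True),
        Column('language', String(40)))

    class Employee(object): pass
    class Engineer(Employee): pass

    # inherit_condition defaults to the natural join of the two tables
    mapper(Employee, employees, polymorphic_on=employees.c.type,
           polymorphic_identity='employee')
    mapper(Engineer, engineers, inherits=Employee,
           polymorphic_identity='engineer')
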
@@ -348,11 +419,11 @@ class Mapper(object):
else:
if self.inherit_condition is None:
# figure out inherit condition from our table to the immediate table
- # of the inherited mapper, not its full table which could pull in other
+ # of the inherited mapper, not its full table which could pull in other
# stuff we dont want (allows test/inheritance.InheritTest4 to pass)
self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause
self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)
- # generate sync rules. similarly to creating the on clause, specify a
+ # generate sync rules. similarly to creating the on clause, specify a
# stricter set of tables to create "sync rules" by,based on the immediate
# inherited table, rather than all inherited tables
self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
@@ -377,21 +448,25 @@ class Mapper(object):
if self.mapped_table is None:
raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))
-
+
# convert polymorphic class associations to mappers
for key in self.polymorphic_map.keys():
if isinstance(self.polymorphic_map[key], type):
self.polymorphic_map[key] = class_mapper(self.polymorphic_map[key])
def _add_polymorphic_mapping(self, key, class_or_mapper, entity_name=None):
- """adds a Mapper to our 'polymorphic map' """
+ """Add a Mapper to our *polymorphic map*."""
+
if isinstance(class_or_mapper, type):
class_or_mapper = class_mapper(class_or_mapper, entity_name=entity_name)
self.polymorphic_map[key] = class_or_mapper
def _compile_tables(self):
- """after the inheritance relationships have been reconciled, sets up some more table-based instance
- variables and determines the "primary key" columns for all tables represented by this Mapper."""
+ """After the inheritance relationships have been reconciled,
+ set up some more table-based instance variables and determine
+ the *primary key* columns for all tables represented by this
+ ``Mapper``.
+ """
# summary of the various Selectable units:
# mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table)
@@ -416,7 +491,7 @@ class Mapper(object):
if self.primary_key is not None:
# determine primary keys using user-given list of primary key columns as a guide
#
- # TODO: this might not work very well for joined-table and/or polymorphic
+ # TODO: this might not work very well for joined-table and/or polymorphic
# inheritance mappers since local_table isnt taken into account nor is select_table
# need to test custom primary key columns used with inheriting mappers
for k in self.primary_key:
@@ -444,18 +519,24 @@ class Mapper(object):
self.primary_key = self.pks_by_table[self.mapped_table]
def _compile_properties(self):
- """inspects the properties dictionary sent to the Mapper's constructor as well as the mapped_table, and creates
- MapperProperty objects corresponding to each mapped column and relation. also grabs MapperProperties from the
- inherited mapper, if any, and creates copies of them to attach to this Mapper."""
+ """Inspect the properties dictionary sent to the Mapper's
+ constructor as well as the mapped_table, and create
+ ``MapperProperty`` objects corresponding to each mapped column
+ and relation.
+
+ Also grab ``MapperProperties`` from the inherited mapper, if
+ any, and create copies of them to attach to this Mapper.
+ """
+
# object attribute names mapped to MapperProperty objects
self.__props = {}
# table columns mapped to lists of MapperProperty objects
- # using a list allows a single column to be defined as
+ # using a list allows a single column to be defined as
# populating multiple object attributes
self.columntoproperty = mapperutil.TranslatingDict(self.mapped_table)
- # load custom properties
+ # load custom properties
if self.properties is not None:
for key, prop in self.properties.iteritems():
self._compile_property(key, prop, False)
@@ -500,9 +581,14 @@ class Mapper(object):
def _initialize_properties(self):
- """calls the init() method on all MapperProperties attached to this mapper. this happens
- after all mappers have completed compiling everything else up until this point, so that all
- dependencies are fully available."""
+ """Call the ``init()`` method on all ``MapperProperties``
+ attached to this mapper.
+
+ This happens after all mappers have completed compiling
+ everything else up until this point, so that all dependencies
+ are fully available.
+ """
+
self.__log("_initialize_properties() started")
l = [(key, prop) for key, prop in self.__props.iteritems()]
for key, prop in l:
@@ -510,13 +596,19 @@ class Mapper(object):
prop.init(key, self)
self.__log("_initialize_properties() complete")
self.__props_init = True
-
+
def _compile_selectable(self):
- """if the 'select_table' keyword argument was specified,
- set up a second "surrogate mapper" that will be used for select operations.
- the columns of select_table should encompass all the columns of the mapped_table either directly
- or through proxying relationships. Currently, non-column properties are *not* copied. this implies
- that a polymorphic mapper cant do any eager loading right now."""
+ """If the 'select_table' keyword argument was specified, set
+ up a second *surrogate mapper* that will be used for select
+ operations.
+
+ The columns of `select_table` should encompass all the columns
+ of the `mapped_table` either directly or through proxying
+ relationships. Currently, non-column properties are **not**
+ copied. This implies that a polymorphic mapper can't do any
+ eager loading right now.
+ """
+
if self.select_table is not self.mapped_table:
props = {}
if self.properties is not None:
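
Continuing the hypothetical inheritance schema above, a hedged sketch of
the select_table argument which triggers this surrogate mapper; the
exact construction of the selectable is illustrative only:

    pjoin = select([employees, engineers],
                   from_obj=[employees.outerjoin(engineers)],
                   use_labels=True).alias('pjoin')
    mapper(Employee, employees, select_table=pjoin,
           polymorphic_on=pjoin.c.employees_type,
           polymorphic_identity='employee')
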
@@ -528,18 +620,24 @@ class Mapper(object):
self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on))
def _compile_class(self):
- """if this mapper is to be a primary mapper (i.e. the non_primary flag is not set),
- associate this Mapper with the given class_ and entity name. subsequent
- calls to class_mapper() for the class_/entity name combination will return this
- mapper. also decorates the __init__ method on the mapped class to include optional auto-session attachment logic."""
+ """If this mapper is to be a primary mapper (i.e. the
+ non_primary flag is not set), associate this Mapper with the
+ given class_ and entity name.
+
+ Subsequent calls to ``class_mapper()`` for the class_/entity
+ name combination will return this mapper. Also decorate the
+ `__init__` method on the mapped class to include optional
+ auto-session attachment logic.
+ """
+
if self.non_primary:
return
-
+
if not self.non_primary and (mapper_registry.has_key(self.class_key)):
raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper, or to create a new primary mapper, remove this mapper first via sqlalchemy.orm.clear_mapper(mapper), or preferably sqlalchemy.orm.clear_mappers() to clear all mappers." % (self.class_, self.entity_name))
attribute_manager.reset_class_managed(self.class_)
-
+
oldinit = self.class_.__init__
def init(self, *args, **kwargs):
entity_name = kwargs.pop('_sa_entity_name', None)
@@ -568,7 +666,7 @@ class Mapper(object):
if session is not None and mapper is not None:
self._entity_name = entity_name
session._register_pending(self)
-
+
if oldinit is not None:
try:
oldinit(self, *args, **kwargs)
@@ -589,20 +687,23 @@ class Mapper(object):
mapper_registry[self.class_key] = self
if self.entity_name is None:
self.class_.c = self.c
-
+
def base_mapper(self):
- """return the ultimate base mapper in an inheritance chain"""
+ """Return the ultimate base mapper in an inheritance chain."""
+
if self.inherits is not None:
return self.inherits.base_mapper()
else:
return self
-
+
def common_parent(self, other):
- """return true if the given mapper shares a common inherited parent as this mapper"""
+ """Return true if the given mapper shares a common inherited parent as this mapper."""
+
return self.base_mapper() is other.base_mapper()
-
+
def isa(self, other):
- """return True if the given mapper inherits from this mapper"""
+ """Return True if the given mapper inherits from this mapper."""
+
m = other
while m is not self and m.inherits is not None:
m = m.inherits
@@ -613,26 +714,35 @@ class Mapper(object):
while m is not None:
yield m
m = m.inherits
-
+
def polymorphic_iterator(self):
- """iterates through the collection including this mapper and all descendant mappers.
-
- this includes not just the immediately inheriting mappers but all their inheriting mappers as well.
-
- To iterate through an entire hierarchy, use mapper.base_mapper().polymorphic_iterator()."""
+ """Iterate through the collection including this mapper and
+ all descendant mappers.
+
+ This includes not just the immediately inheriting mappers but
+ all their inheriting mappers as well.
+
+ To iterate through an entire hierarchy, use
+ ``mapper.base_mapper().polymorphic_iterator()``.
+ """
+
def iterate(m):
yield m
for mapper in m._inheriting_mappers:
for x in iterate(mapper):
yield x
return iterate(self)
-
+
def _get_inherited_column_equivalents(self):
- """return a map of all 'equivalent' columns, based on traversing the full set of inherit_conditions across
- all inheriting mappers and determining column pairs that are equated to one another.
-
- this is used when relating columns to those of a polymorphic selectable, as the selectable usually only contains
- one of two columns that are equated to one another."""
+ """Return a map of all *equivalent* columns, based on
+ traversing the full set of inherit_conditions across all
+ inheriting mappers and determining column pairs that are
+ equated to one another.
+
+ This is used when relating columns to those of a polymorphic
+ selectable, as the selectable usually only contains one of two
+ columns that are equated to one another.
+ """
+
result = {}
def visit_binary(binary):
if binary.operator == '=':
@@ -643,24 +753,30 @@ class Mapper(object):
if mapper.inherit_condition is not None:
mapper.inherit_condition.accept_visitor(vis)
return result
-
+
def add_properties(self, dict_of_properties):
- """adds the given dictionary of properties to this mapper, using add_property."""
+ """Add the given dictionary of properties to this mapper,
+ using `add_property`.
+ """
+
for key, value in dict_of_properties.iteritems():
self.add_property(key, value)
def add_property(self, key, prop):
- """add an indiviual MapperProperty to this mapper.
-
- If the mapper has not been compiled yet, just adds the property to the initial
- properties dictionary sent to the constructor. if this Mapper
- has already been compiled, then the given MapperProperty is compiled immediately."""
+ """Add an indiviual MapperProperty to this mapper.
+
+ If the mapper has not been compiled yet, just adds the
+ property to the initial properties dictionary sent to the
+ constructor. If this Mapper has already been compiled, then
+ the given MapperProperty is compiled immediately.
+ """
+
self.properties[key] = prop
if self.__is_compiled:
# if we're compiled, make sure all the other mappers are compiled too
self._compile_all()
self._compile_property(key, prop, init=True)
-
+
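
A short sketch of add_property after construction (the Address class and
its relation are hypothetical):

    from sqlalchemy.orm import relation

    m = mapper(User, users)
    # if m were already compiled, the new property would be compiled
    # immediately, per the docstring above
    m.add_property('addresses', relation(Address))
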
def _create_prop_from_column(self, column, skipmissing=False):
if sql.is_column(column):
try:
@@ -686,15 +802,19 @@ class Mapper(object):
if not self.concrete:
self._compile_property(key, prop, init=False, setparent=False)
# TODO: concrete properties dont adapt at all right now....will require copies of relations() etc.
-
+
def _compile_property(self, key, prop, init=True, skipmissing=False, setparent=True):
- """add a MapperProperty to this or another Mapper, including configuration of the property.
-
- The properties' parent attribute will be set, and the property will also be
- copied amongst the mappers which inherit from this one.
-
- if the given prop is a Column or list of Columns, a ColumnProperty will be created.
+ """Add a ``MapperProperty`` to this or another ``Mapper``,
+ including configuration of the property.
+
+ The properties' parent attribute will be set, and the property
+ will also be copied amongst the mappers which inherit from
+ this one.
+
+ If the given `prop` is a ``Column`` or list of Columns, a
+ ``ColumnProperty`` will be created.
"""
+
self.__log("_compile_property(%s, %s)" % (key, prop.__class__.__name__))
if not isinstance(prop, MapperProperty):
@@ -706,7 +826,7 @@ class Mapper(object):
self.__props[key] = prop
if setparent:
prop.set_parent(self)
-
+
if isinstance(prop, ColumnProperty):
# relate the mapper's "select table" to the given ColumnProperty
col = self.select_table.corresponding_column(prop.columns[0], keys_ok=True, raiseerr=False)
@@ -727,84 +847,112 @@ class Mapper(object):
def __str__(self):
return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.name or str(self.local_table)) + (not self._is_primary_mapper() and "|non-primary" or "")
-
+
def _is_primary_mapper(self):
- """returns True if this mapper is the primary mapper for its class key (class + entity_name)"""
+ """Return True if this mapper is the primary mapper for its class key (class + entity_name)."""
return mapper_registry.get(self.class_key, None) is self
def primary_mapper(self):
- """returns the primary mapper corresponding to this mapper's class key (class + entity_name)"""
+ """Return the primary mapper corresponding to this mapper's class key (class + entity_name)."""
return mapper_registry[self.class_key]
def is_assigned(self, instance):
- """return True if this mapper handles the given instance.
-
- this is dependent not only on class assignment but the optional "entity_name" parameter as well."""
+ """Return True if this mapper handles the given instance.
+
+ This is dependent not only on class assignment but also on the
+ optional `entity_name` parameter.
+ """
+
return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name
def _assign_entity_name(self, instance):
- """assign this Mapper's entity name to the given instance.
-
- subsequent Mapper lookups for this instance will return the primary
- mapper corresponding to this Mapper's class and entity name."""
+ """Assign this Mapper's entity name to the given instance.
+
+ Subsequent Mapper lookups for this instance will return the
+ primary mapper corresponding to this Mapper's class and entity
+ name.
+ """
+
instance._entity_name = self.entity_name
-
+
def get_session(self):
- """return the contextual session provided by the mapper extension chain, if any.
-
- raises InvalidRequestError if a session cannot be retrieved from the extension chain
+ """Return the contextual session provided by the mapper
+ extension chain, if any.
+
+ Raise ``InvalidRequestError`` if a session cannot be retrieved
+ from the extension chain.
"""
+
self.compile()
s = self.extension.get_session()
if s is EXT_PASS:
raise exceptions.InvalidRequestError("No contextual Session is established. Use a MapperExtension that implements get_session or use 'import sqlalchemy.mods.threadlocal' to establish a default thread-local contextual session.")
return s
-
+
def has_eager(self):
- """return True if one of the properties attached to this Mapper is eager loading"""
+ """Return True if one of the properties attached to this
+ Mapper is eager loading.
+ """
+
return len(self._eager_loaders) > 0
-
+
def instances(self, cursor, session, *mappers, **kwargs):
- """return a list of mapped instances corresponding to the rows in a given ResultProxy."""
+ """Return a list of mapped instances corresponding to the rows
+ in a given ResultProxy.
+ """
+
import sqlalchemy.orm.query
return sqlalchemy.orm.Query(self, session).instances(cursor, *mappers, **kwargs)
def identity_key_from_row(self, row):
- """return an identity-map key for use in storing/retrieving an item from the identity map.
+ """Return an identity-map key for use in storing/retrieving an
+ item from the identity map.
- row - a sqlalchemy.engine.base.RowProxy instance or a dictionary corresponding result-set
- ColumnElement instances to their values within a row.
+
+ row
+ A ``sqlalchemy.engine.base.RowProxy`` instance or a
+ dictionary mapping result-set ``ColumnElement``
+ instances to their values within a row.
"""
return (self.class_, tuple([row[column] for column in self.pks_by_table[self.mapped_table]]), self.entity_name)
-
+
def identity_key_from_primary_key(self, primary_key):
- """return an identity-map key for use in storing/retrieving an item from an identity map.
-
- primary_key - a list of values indicating the identifier.
+ """Return an identity-map key for use in storing/retrieving an
+ item from an identity map.
+
+ primary_key
+ A list of values indicating the identifier.
"""
return (self.class_, tuple(util.to_list(primary_key)), self.entity_name)
def identity_key_from_instance(self, instance):
- """return the identity key for the given instance, based on its primary key attributes.
-
- this value is typically also found on the instance itself under the attribute name '_instance_key'.
+ """Return the identity key for the given instance, based on
+ its primary key attributes.
+
+ This value is typically also found on the instance itself
+ under the attribute name `_instance_key`.
"""
return self.identity_key_from_primary_key(self.primary_key_from_instance(instance))
def primary_key_from_instance(self, instance):
- """return the list of primary key values for the given instance."""
+ """Return the list of primary key values for the given
+ instance.
+ """
+
return [self.get_attr_by_column(instance, column) for column in self.pks_by_table[self.mapped_table]]
def instance_key(self, instance):
- """deprecated. a synonym for identity_key_from_instance."""
+ """Deprecated. A synonym for `identity_key_from_instance`."""
+
return self.identity_key_from_instance(instance)
def identity_key(self, primary_key):
- """deprecated. a synonym for identity_key_from_primary_key."""
+ """Deprecated. A synonym for `identity_key_from_primary_key`."""
+
return self.identity_key_from_primary_key(primary_key)
def identity(self, instance):
- """deprecated. a synoynm for primary_key_from_instance."""
+ """Deprecated. A synoynm for `primary_key_from_instance`."""
+
return self.primary_key_from_instance(instance)
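
The deprecated synonyms above and their replacements relate as in this
sketch, assuming u is a loaded instance of the hypothetical User class
with a single integer primary key value of 5:

    m = object_mapper(u)
    m.primary_key_from_instance(u)        # [5]
    m.identity_key_from_primary_key([5])  # (User, (5,), None) - the
                                          # entity_name here is None
    m.identity_key_from_instance(u)       # the same tuple as above
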
def _getpropbycolumn(self, column, raiseerror=True):
@@ -821,9 +969,10 @@ class Mapper(object):
return None
raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
return prop[0]
-
+
def get_attr_by_column(self, obj, column, raiseerror=True):
- """return an instance attribute using a Column as the key."""
+ """Return an instance attribute using a Column as the key."""
+
prop = self._getpropbycolumn(column, raiseerror)
if prop is None:
return NO_ATTRIBUTE
@@ -831,28 +980,33 @@ class Mapper(object):
return prop.getattr(obj)
def set_attr_by_column(self, obj, column, value):
- """set the value of an instance attribute using a Column as the key."""
+ """Set the value of an instance attribute using a Column as the key."""
+
self.columntoproperty[column][0].setattr(obj, value)
-
+
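
For example, against the hypothetical users mapping from the earlier
sketches:

    name = m.get_attr_by_column(u, users.c.user_name)
    m.set_attr_by_column(u, users.c.user_name, 'ed')
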
def save_obj(self, objects, uowtransaction, postupdate=False, post_update_cols=None, single=False):
- """issue INSERT and/or UPDATE statements for a list of objects.
-
- this is called within the context of a UOWTransaction during a flush operation.
-
- save_obj issues SQL statements not just for instances mapped directly by this mapper, but
- for instances mapped by all inheriting mappers as well. This is to maintain proper insert
- ordering among a polymorphic chain of instances. Therefore save_obj is typically
- called only on a "base mapper", or a mapper which does not inherit from any other mapper."""
-
+ """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation.
+
+ `save_obj` issues SQL statements not just for instances mapped
+ directly by this mapper, but for instances mapped by all
+ inheriting mappers as well. This is to maintain proper insert
+ ordering among a polymorphic chain of instances. Therefore
+ `save_obj` is typically called only on a *base mapper*, or a
+ mapper which does not inherit from any other mapper.
+ """
+
if self.__should_log_debug:
self.__log_debug("save_obj() start, " + (single and "non-batched" or "batched"))
-
+
# if batch=false, call save_obj separately for each object
if not single and not self.batch:
for obj in objects:
self.save_obj([obj], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
return
-
+
connection = uowtransaction.transaction.connection(self)
if not postupdate:
@@ -866,7 +1020,7 @@ class Mapper(object):
for obj in objects:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
- # and another instance with the same identity key already exists as persistent. convert to an
+ # and another instance with the same identity key already exists as persistent. convert to an
# UPDATE if so.
mapper = object_mapper(obj)
instance_key = mapper.instance_key(obj)
@@ -881,10 +1035,10 @@ class Mapper(object):
if has_identity(obj):
if obj._instance_key != instance_key:
raise exceptions.FlushError("Can't change the identity of instance %s in session (existing identity: %s; new identity: %s)" % (mapperutil.instance_str(obj), obj._instance_key, instance_key))
-
+
inserted_objects = util.Set()
updated_objects = util.Set()
-
+
table_to_mapper = {}
for mapper in self.base_mapper().polymorphic_iterator():
for t in mapper.tables:
@@ -894,7 +1048,7 @@ class Mapper(object):
# two lists to store parameters for each table/object pair located
insert = []
update = []
-
+
for obj in objects:
mapper = object_mapper(obj)
if table not in mapper.tables or not mapper._has_pks(table):
@@ -920,7 +1074,7 @@ class Mapper(object):
# matching the bindparam we are creating below, i.e. "_"
params[col._label] = mapper.get_attr_by_column(obj, col)
else:
- # doing an INSERT, primary key col ?
+ # doing an INSERT, primary key col ?
# if the primary key values are not populated,
# leave them out of the INSERT altogether, since PostGres doesn't want
# them to be present for SERIAL to take effect. A SQLEngine that uses
@@ -957,9 +1111,9 @@ class Mapper(object):
params[col.key] = a[0]
hasdata = True
else:
- # doing an INSERT, non primary key col ?
- # add the attribute's value to the
- # bind parameters, unless its None and the column has a
+ # doing an INSERT, non primary key col ?
+ # add the attribute's value to the
+ # bind parameters, unless its None and the column has a
# default. if its None and theres no default, we still might
# not want to put it in the col list but SQLIte doesnt seem to like that
# if theres no columns at all
@@ -1021,7 +1175,7 @@ class Mapper(object):
mapper.set_attr_by_column(obj, col, primary_key[i])
i+=1
mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
-
+
# synchronize newly inserted ids from one table to the next
# TODO: this fires off more than needed, try to organize syncrules
# per table
@@ -1032,7 +1186,7 @@ class Mapper(object):
if mapper._synchronizer is not None:
mapper._synchronizer.execute(obj, obj)
sync(mapper)
-
+
inserted_objects.add(obj)
if not postupdate:
for obj in inserted_objects:
@@ -1043,9 +1197,13 @@ class Mapper(object):
mapper.extension.after_update(mapper, connection, obj)
def _postfetch(self, connection, table, obj, resultproxy, params):
- """after an INSERT or UPDATE, asks the returned result if PassiveDefaults fired off on the database side
- which need to be post-fetched, *or* if pre-exec defaults like ColumnDefaults were fired off
- and should be populated into the instance. this is only for non-primary key columns."""
+ """After an ``INSERT`` or ``UPDATE``, ask the returned result
+ if ``PassiveDefaults`` fired off on the database side which
+ need to be post-fetched, **or** if pre-exec defaults like
+ ``ColumnDefaults`` were fired off and should be populated into
+ the instance. This is only for non-primary key columns.
+ """
+
if resultproxy.lastrow_has_defaults():
clause = sql.and_()
for p in self.pks_by_table[table]:
@@ -1065,9 +1223,11 @@ class Mapper(object):
self.set_attr_by_column(obj, c, params.get_original(c.name))
def delete_obj(self, objects, uowtransaction):
- """issue DELETE statements for a list of objects.
-
- this is called within the context of a UOWTransaction during a flush operation."""
+ """Issue ``DELETE`` statements for a list of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation.
+ """
if self.__should_log_debug:
self.__log_debug("delete_obj() start")
@@ -1108,7 +1268,7 @@ class Mapper(object):
c = connection.execute(statement, delete)
if c.supports_sane_rowcount() and c.rowcount != len(delete):
raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
-
+
[self.extension.after_delete(self, connection, obj) for obj in deleted_objects]
def _has_pks(self, table):
@@ -1120,23 +1280,35 @@ class Mapper(object):
return True
except KeyError:
return False
-
+
def register_dependencies(self, uowcommit, *args, **kwargs):
- """register DependencyProcessor instances with a unitofwork.UOWTransaction.
-
- this calls register_dependencies on all attached MapperProperty instances."""
+ """Register ``DependencyProcessor`` instances with a
+ ``unitofwork.UOWTransaction``.
+
+ This calls `register_dependencies` on all attached
+ ``MapperProperty`` instances.
+ """
+
for prop in self.__props.values():
prop.register_dependencies(uowcommit, *args, **kwargs)
-
+
def cascade_iterator(self, type, object, recursive=None, halt_on=None):
- """iterate each element in an object graph, for all relations taht meet the given cascade rule.
-
- type - the name of the cascade rule (i.e. save-update, delete, etc.)
-
- object - the lead object instance. child items will be processed per the relations
- defined for this object's mapper.
-
- recursive - used by the function for internal context during recursive calls, leave as None."""
+ """Iterate each element in an object graph, for all relations
+ that meet the given cascade rule.
+
+ type
+ The name of the cascade rule (i.e. save-update, delete,
+ etc.)
+
+ object
+ The lead object instance. Child items will be processed per
+ the relations defined for this object's mapper.
+
+ recursive
+ Used by the function for internal context during recursive
+ calls; leave as None.
+ """
+
if recursive is None:
recursive=util.Set()
for prop in self.__props.values():
@@ -1144,33 +1316,46 @@ class Mapper(object):
yield c
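
A usage sketch, assuming m and u from the earlier sketches; this walks
every child object reachable from u under the save-update cascade rule:

    for child in m.cascade_iterator('save-update', u):
        print child   # Python 2-era print statement
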
def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
- """execute a callable for each element in an object graph, for all relations that meet the given cascade rule.
-
- type - the name of the cascade rule (i.e. save-update, delete, etc.)
-
- object - the lead object instance. child items will be processed per the relations
- defined for this object's mapper.
-
- callable_ - the callable function.
-
- recursive - used by the function for internal context during recursive calls, leave as None."""
+ """Execute a callable for each element in an object graph, for
+ all relations that meet the given cascade rule.
+
+ type
+ The name of the cascade rule (i.e. save-update, delete, etc.)
+
+ object
+ The lead object instance. Child items will be processed per
+ the relations defined for this object's mapper.
+
+ callable_
+ The callable function.
+
+ recursive
+ Used by the function for internal context during recursive
+ calls; leave as None.
+ """
+
if recursive is None:
recursive=util.Set()
for prop in self.__props.values():
prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on)
-
def get_select_mapper(self):
- """return the mapper used for issuing selects.
-
- this mapper is the same mapper as 'self' unless the select_table argument was specified for this mapper."""
+ """Return the mapper used for issuing selects.
+
+ This mapper is the same mapper as `self` unless the
+ select_table argument was specified for this mapper.
+ """
+
return self.__surrogate_mapper or self
-
+
def _instance(self, context, row, result = None, skip_polymorphic=False):
- """pulls an object instance from the given row and appends it to the given result
- list. if the instance already exists in the given identity map, its not added. in
- either case, executes all the property loaders on the instance to also process extra
- information in the row."""
+ """Pull an object instance from the given row and append it to
+ the given result list.
+
+ If the instance already exists in the given identity map, it's
+ not added. In either case, execute all the property loaders
+ on the instance to also process extra information in the row.
+ """
# apply ExtensionOptions applied to the Query to this mapper,
# but only if our mapper matches.
@@ -1179,7 +1364,7 @@ class Mapper(object):
extension = context.extension
else:
extension = self.extension
-
+
ret = extension.translate_row(self, context, row)
if ret is not EXT_PASS:
row = ret
@@ -1191,11 +1376,11 @@ class Mapper(object):
if mapper is not self:
row = self.translate_row(mapper, row)
return mapper._instance(context, row, result=result, skip_polymorphic=True)
-
+
# look in main identity map. if its there, we dont do anything to it,
# including modifying any of its related items lists, as its already
# been exposed to being modified by the application.
-
+
populate_existing = context.populate_existing or self.always_refresh
identitykey = self.identity_key_from_row(row)
if context.session.has_key(identitykey):
@@ -1205,7 +1390,7 @@ class Mapper(object):
isnew = False
if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
-
+
if populate_existing or context.session.is_expired(instance, unexpire=True):
if not context.identity_map.has_key(identitykey):
context.identity_map[identitykey] = instance
@@ -1220,23 +1405,23 @@ class Mapper(object):
if self.__should_log_debug:
self.__log_debug("_instance(): identity key %s not in session" % str(identitykey) + repr([mapperutil.instance_str(x) for x in context.session]))
# look in result-local identitymap for it.
- exists = context.identity_map.has_key(identitykey)
+ exists = context.identity_map.has_key(identitykey)
if not exists:
if self.allow_null_pks:
- # check if *all* primary key cols in the result are None - this indicates
- # an instance of the object is not present in the row.
+ # check if *all* primary key cols in the result are None - this indicates
+ # an instance of the object is not present in the row.
for col in self.pks_by_table[self.mapped_table]:
if row[col] is not None:
break
else:
return None
else:
- # otherwise, check if *any* primary key cols in the result are None - this indicates
- # an instance of the object is not present in the row.
+ # otherwise, check if *any* primary key cols in the result are None - this indicates
+ # an instance of the object is not present in the row.
for col in self.pks_by_table[self.mapped_table]:
if row[col] is None:
return None
-
+
# plugin point
instance = extension.create_instance(self, context, row, self.class_)
if instance is EXT_PASS:
@@ -1251,7 +1436,7 @@ class Mapper(object):
instance = context.identity_map[identitykey]
isnew = False
- # call further mapper properties on the row, to pull further
+ # call further mapper properties on the row, to pull further
# instances from the row and possibly populate this item.
if extension.populate_instance(self, context, row, instance, identitykey, isnew) is EXT_PASS:
self.populate_instance(context, instance, row, identitykey, isnew)
@@ -1263,7 +1448,7 @@ class Mapper(object):
def _create_instance(self, session):
obj = self.class_.__new__(self.class_)
obj._entity_name = self.entity_name
-
+
# this gets the AttributeManager to do some pre-initialization,
# in order to save on KeyErrors later on
attribute_manager.init_attr(obj)
@@ -1271,21 +1456,23 @@ class Mapper(object):
return obj
def translate_row(self, tomapper, row):
- """translate the column keys of a row into a new or proxied row that
- can be understood by another mapper.
+ """Translate the column keys of a row into a new or proxied
+ row that can be understood by another mapper.
+
+ This can be used in conjunction with populate_instance to
+ populate an instance using an alternate mapper.
+ """
- This can be used in conjunction with populate_instance to populate
- an instance using an alternate mapper."""
newrow = util.DictDecorator(row)
for c in tomapper.mapped_table.c:
c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True)
if row.has_key(c2):
newrow[c] = row[c2]
return newrow
-
+
def populate_instance(self, selectcontext, instance, row, identitykey, isnew):
"""populate an instance from a result row.
-
+
This method iterates through the list of MapperProperty objects attached to this Mapper
and calls each properties execute() method."""
for prop in self.__props.values():
@@ -1295,162 +1482,254 @@ Mapper.logger = logging.class_logger(Mapper)
class MapperExtension(object):
- """base implementation for an object that provides overriding behavior to various
- Mapper functions. For each method in MapperExtension, a result of EXT_PASS indicates
- the functionality is not overridden."""
+ """Base implementation for an object that provides overriding
+ behavior to various Mapper functions. For each method in
+ MapperExtension, a result of EXT_PASS indicates the functionality
+ is not overridden.
+ """
+
def get_session(self):
- """retrieve a contextual Session instance with which to register a new object.
-
- Note: this is not called if a session is provided with the __init__ params (i.e. _sa_session)"""
+ """Retrieve a contextual Session instance with which to
+ register a new object.
+
+ Note: this is not called if a session is provided with the
+ `__init__` params (i.e. `_sa_session`).
+ """
+
return EXT_PASS
+
def load(self, query, *args, **kwargs):
- """override the load method of the Query object.
+ """Override the `load` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.load()`` if the value is anything other than EXT_PASS.
+ """
- the return value of this method is used as the result of query.load() if the
- value is anything other than EXT_PASS."""
return EXT_PASS
+
def get(self, query, *args, **kwargs):
- """override the get method of the Query object.
+ """Override the `get` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.get()`` if the value is anything other than EXT_PASS.
+ """
- the return value of this method is used as the result of query.get() if the
- value is anything other than EXT_PASS."""
return EXT_PASS
+
def get_by(self, query, *args, **kwargs):
- """override the get_by method of the Query object.
+ """Override the `get_by` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.get_by()`` if the value is anything other than
+ EXT_PASS.
+ """
- the return value of this method is used as the result of query.get_by() if the
- value is anything other than EXT_PASS."""
return EXT_PASS
+
def select_by(self, query, *args, **kwargs):
- """override the select_by method of the Query object.
-
- the return value of this method is used as the result of query.select_by() if the
- value is anything other than EXT_PASS."""
+ """Override the `select_by` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.select_by()`` if the value is anything other than
+ EXT_PASS.
+ """
+
return EXT_PASS
+
def select(self, query, *args, **kwargs):
- """override the select method of the Query object.
-
- the return value of this method is used as the result of query.select() if the
- value is anything other than EXT_PASS."""
+ """Override the `select` method of the Query object.
+
+ The return value of this method is used as the result of
+ ``query.select()`` if the value is anything other than
+ EXT_PASS.
+ """
+
return EXT_PASS
-
+
+
def translate_row(self, mapper, context, row):
- """perform pre-processing on the given result row and return a new row instance.
-
- this is called as the very first step in the _instance() method."""
+ """Perform pre-processing on the given result row and return a
+ new row instance.
+
+ This is called as the very first step in the ``_instance()``
+ method.
+ """
+
return EXT_PASS
-
+
def create_instance(self, mapper, selectcontext, row, class_):
- """receieve a row when a new object instance is about to be created from that row.
- the method can choose to create the instance itself, or it can return
- None to indicate normal object creation should take place.
-
- mapper - the mapper doing the operation
-
- selectcontext - SelectionContext corresponding to the instances() call
-
- row - the result row from the database
-
- class_ - the class we are mapping.
+ """Receive a row when a new object instance is about to be
+ created from that row.
+
+ The method can choose to create the instance itself, or it can
+ return None to indicate normal object creation should take
+ place.
+
+ mapper
+ The mapper doing the operation
+
+ selectcontext
+ SelectionContext corresponding to the instances() call
+
+ row
+ The result row from the database
+
+ class_
+ The class we are mapping.
"""
+
return EXT_PASS
+
def append_result(self, mapper, selectcontext, row, instance, identitykey, result, isnew):
- """receive an object instance before that instance is appended to a result list.
-
- If this method returns EXT_PASS, result appending will proceed normally.
- if this method returns any other value or None, result appending will not proceed for
- this instance, giving this extension an opportunity to do the appending itself, if desired.
-
- mapper - the mapper doing the operation
-
- selectcontext - SelectionContext corresponding to the instances() call
-
- row - the result row from the database
-
- instance - the object instance to be appended to the result
-
- identitykey - the identity key of the instance
-
- result - list to which results are being appended
-
- isnew - indicates if this is the first time we have seen this object instance in the current result
- set. if you are selecting from a join, such as an eager load, you might see the same object instance
- many times in the same result set.
+ """Receive an object instance before that instance is appended
+ to a result list.
+
+ If this method returns EXT_PASS, result appending will proceed
+ normally. If this method returns any other value or None,
+ result appending will not proceed for this instance, giving
+ this extension an opportunity to do the appending itself, if
+ desired.
+
+ mapper
+ The mapper doing the operation.
+
+ selectcontext
+ SelectionContext corresponding to the instances() call.
+
+ row
+ The result row from the database.
+
+ instance
+ The object instance to be appended to the result.
+
+ identitykey
+ The identity key of the instance.
+
+ result
+ List to which results are being appended.
+
+ isnew
+ Indicates if this is the first time we have seen this object
+ instance in the current result set. If you are selecting
+ from a join, such as an eager load, you might see the same
+ object instance many times in the same result set.
"""
+
return EXT_PASS
+
def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
- """receive a newly-created instance before that instance has its attributes populated.
-
+ """Receive a newly-created instance before that instance has
+ its attributes populated.
+
The normal population of attributes is according to each
- attribute's corresponding MapperProperty (which includes column-based attributes as well
- as relationships to other classes). If this method returns EXT_PASS, instance population
- will proceed normally. If any other value or None is returned, instance population
- will not proceed, giving this extension an opportunity to populate the instance itself,
- if desired.
+ attribute's corresponding MapperProperty (which includes
+ column-based attributes as well as relationships to other
+ classes). If this method returns EXT_PASS, instance
+ population will proceed normally. If any other value or None
+ is returned, instance population will not proceed, giving this
+ extension an opportunity to populate the instance itself, if
+ desired.
"""
+
return EXT_PASS
+
def before_insert(self, mapper, connection, instance):
- """receive an object instance before that instance is INSERTed into its table.
-
- this is a good place to set up primary key values and such that arent handled otherwise."""
+ """Receive an object instance before that instance is INSERTed
+ into its table.
+
+ This is a good place to set up primary key values and such
+ that aren't handled otherwise.
+ """
+
return EXT_PASS
+
def before_update(self, mapper, connection, instance):
- """receive an object instance before that instance is UPDATEed."""
+ """Receive an object instance before that instance is UPDATEed."""
+
return EXT_PASS
+
def after_update(self, mapper, connection, instance):
- """receive an object instance after that instance is UPDATEed."""
+ """Receive an object instance after that instance is UPDATEed."""
+
return EXT_PASS
+
def after_insert(self, mapper, connection, instance):
- """receive an object instance after that instance is INSERTed."""
+ """Receive an object instance after that instance is INSERTed."""
+
return EXT_PASS
+
def before_delete(self, mapper, connection, instance):
- """receive an object instance before that instance is DELETEed."""
+ """Receive an object instance before that instance is DELETEed."""
+
return EXT_PASS
+
def after_delete(self, mapper, connection, instance):
- """receive an object instance after that instance is DELETEed."""
+ """Receive an object instance after that instance is DELETEed."""
+
return EXT_PASS
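
As a sketch of the extension protocol described above, a subclass
overrides only the hooks it needs; the others return EXT_PASS via the
base class. The created_flag attribute here is hypothetical:

    class AuditExtension(MapperExtension):
        def before_insert(self, mapper, connection, instance):
            # mark the instance just before its INSERT is issued
            instance.created_flag = True
            return EXT_PASS

    mapper(User, users, extension=AuditExtension())
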
class _ExtensionCarrier(MapperExtension):
def __init__(self):
self.__elements = []
+
def insert(self, extension):
- """insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
+ """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
+
self.__elements.insert(0, extension)
+
def append(self, extension):
- """append a MapperExtension at the end of this ExtensionCarrier's list."""
+ """Append a MapperExtension at the end of this ExtensionCarrier's list."""
+
self.__elements.append(extension)
+
def get_session(self, *args, **kwargs):
return self._do('get_session', *args, **kwargs)
+
def load(self, *args, **kwargs):
return self._do('load', *args, **kwargs)
+
def get(self, *args, **kwargs):
return self._do('get', *args, **kwargs)
+
def get_by(self, *args, **kwargs):
return self._do('get_by', *args, **kwargs)
+
def select_by(self, *args, **kwargs):
return self._do('select_by', *args, **kwargs)
+
def select(self, *args, **kwargs):
return self._do('select', *args, **kwargs)
+
def translate_row(self, *args, **kwargs):
return self._do('translate_row', *args, **kwargs)
+
def create_instance(self, *args, **kwargs):
return self._do('create_instance', *args, **kwargs)
+
def append_result(self, *args, **kwargs):
return self._do('append_result', *args, **kwargs)
+
def populate_instance(self, *args, **kwargs):
return self._do('populate_instance', *args, **kwargs)
+
def before_insert(self, *args, **kwargs):
return self._do('before_insert', *args, **kwargs)
+
def before_update(self, *args, **kwargs):
return self._do('before_update', *args, **kwargs)
+
def after_update(self, *args, **kwargs):
return self._do('after_update', *args, **kwargs)
+
def after_insert(self, *args, **kwargs):
return self._do('after_insert', *args, **kwargs)
+
def before_delete(self, *args, **kwargs):
return self._do('before_delete', *args, **kwargs)
+
def after_delete(self, *args, **kwargs):
return self._do('after_delete', *args, **kwargs)
+
def _do(self, funcname, *args, **kwargs):
for elem in self.__elements:
ret = getattr(elem, funcname)(*args, **kwargs)
@@ -1458,51 +1737,64 @@ class _ExtensionCarrier(MapperExtension):
return ret
else:
return EXT_PASS
-
+
class ExtensionOption(MapperOption):
def __init__(self, ext):
self.ext = ext
+
def process_query(self, query):
query._insert_extension(self.ext)
-
+
class ClassKey(object):
- """keys a class and an entity name to a mapper, via the mapper_registry."""
+ """Key a class and an entity name to a mapper, via the mapper_registry."""
+
__metaclass__ = util.ArgSingleton
+
def __init__(self, class_, entity_name):
self.class_ = class_
self.entity_name = entity_name
+
def __hash__(self):
return hash((self.class_, self.entity_name))
+
def __eq__(self, other):
return self is other
+
def __repr__(self):
return "ClassKey(%s, %s)" % (repr(self.class_), repr(self.entity_name))
+
def dispose(self):
type(self).dispose_static(self.class_, self.entity_name)
-
+
def has_identity(object):
return hasattr(object, '_instance_key')
-
+
def has_mapper(object):
- """return True if the given object has had a mapper association set up, either through loading,
- or via insertion in a session."""
- return hasattr(object, '_entity_name')
+ """Return True if the given object has had a mapper association
+ set up, either through loading, or via insertion in a session.
+ """
+ return hasattr(object, '_entity_name')
-
def object_mapper(object, raiseerror=True):
- """given an object, returns the primary Mapper associated with the object instance"""
+ """Given an object, return the primary Mapper associated with the
+ object instance.
+ """
+
try:
mapper = mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', None))]
- except (KeyError, AttributeError):
+ except (KeyError, AttributeError):
if raiseerror:
raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None)))
else:
return None
return mapper.compile()
-
+
def class_mapper(class_, entity_name=None, compile=True):
- """given a ClassKey, returns the primary Mapper associated with the key."""
+ """Given a ClassKey, return the primary Mapper associated with the
+ key.
+ """
+
try:
mapper = mapper_registry[ClassKey(class_, entity_name)]
except (KeyError, AttributeError):
@@ -1511,5 +1803,3 @@ def class_mapper(class_, entity_name=None, compile=True):
return mapper.compile()
else:
return mapper
-
-
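# A quick sketch of the two lookup helpers above (illustrative only);
# `User` and `sess` are assumed to be a mapped class and an open Session.
from sqlalchemy.orm.mapper import object_mapper, class_mapper

m = class_mapper(User)          # primary Mapper for the class
u = sess.query(User).get(1)
assert object_mapper(u) is m    # the instance maps back to the same Mapper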
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 930ea10413..cc11e193d7 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -4,10 +4,12 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""defines a set of mapper.MapperProperty objects, including basic column properties as
-well as relationships. the objects rely upon the LoaderStrategy objects in the strategies.py
-module to handle load operations. PropertyLoader also relies upon the dependency.py module
-to handle flush-time dependency sorting and processing."""
+"""Defines a set of mapper.MapperProperty objects, including basic
+column properties as well as relationships. The objects rely upon the
+LoaderStrategy objects in the strategies.py module to handle load
+operations. PropertyLoader also relies upon the dependency.py module
+to handle flush-time dependency sorting and processing.
+"""
from sqlalchemy import sql, schema, util, exceptions, sql_util, logging
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
@@ -15,39 +17,51 @@ from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
import sets, random
from sqlalchemy.orm.interfaces import *
-
+
class ColumnProperty(StrategizedProperty):
- """describes an object attribute that corresponds to a table column."""
+ """Describes an object attribute that corresponds to a table column."""
+
def __init__(self, *columns, **kwargs):
- """the list of columns describes a single object property. if there
- are multiple tables joined together for the mapper, this list represents
- the equivalent column as it appears across each table."""
+ """The list of `columns` describes a single object
+ property. If there are multiple tables joined together for the
+ mapper, this list represents the equivalent column as it
+ appears across each table.
+ """
+
self.columns = list(columns)
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
+
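# A sketch of the multi-table case described above (illustrative only):
# when a mapper spans a join, a list of the equivalent columns maps them
# to a single property. The `people`/`engineers` tables and Engineer
# class are hypothetical, using the 0.3 convention that a bare list of
# Columns in `properties` builds a ColumnProperty.
j = people.join(engineers, people.c.id == engineers.c.person_id)
mapper(Engineer, j, properties={
    # one attribute covering the synonymous column in each table
    'id': [people.c.id, engineers.c.person_id],
})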
def create_strategy(self):
if self.deferred:
return strategies.DeferredColumnLoader(self)
else:
return strategies.ColumnLoader(self)
+
def getattr(self, object):
return getattr(object, self.key)
+
def setattr(self, object, value):
setattr(object, self.key, value)
+
def get_history(self, obj, passive=False):
return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive)
+
def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
+
def compare(self, value):
return self.columns[0] == value
-
+
ColumnProperty.logger = logging.class_logger(ColumnProperty)
-
+
mapper.ColumnProperty = ColumnProperty
class PropertyLoader(StrategizedProperty):
- """describes an object property that holds a single item or list of items that correspond
- to a related database table."""
+ """Describes an object property that holds a single item or list
+ of items that correspond to a related database table.
+ """
+
def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None):
self.uselist = uselist
self.argument = argument
@@ -64,7 +78,7 @@ class PropertyLoader(StrategizedProperty):
self.passive_deletes = passive_deletes
self.remote_side = util.to_set(remote_side)
self._parent_join_cache = {}
-
+
if cascade is not None:
self.cascade = mapperutil.CascadeOptions(cascade)
else:
@@ -87,10 +101,10 @@ class PropertyLoader(StrategizedProperty):
else:
self.backref = backref
self.is_backref = is_backref
-
+
def compare(self, value):
return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
-
+
private = property(lambda s:s.cascade.delete_orphan)
def create_strategy(self):
@@ -100,7 +114,7 @@ class PropertyLoader(StrategizedProperty):
return strategies.EagerLoader(self)
elif self.lazy is None:
return strategies.NoLoader(self)
-
+
def __str__(self):
return str(self.parent.class_.__name__) + "." + self.key + " (" + str(self.mapper.class_.__name__) + ")"
@@ -123,7 +137,7 @@ class PropertyLoader(StrategizedProperty):
setattr(dest, self.key, session.merge(current, _recursive=_recursive))
finally:
_recursive.remove(source)
-
+
def cascade_iterator(self, type, object, recursive, halt_on=None):
if not type in self.cascade:
return
@@ -141,7 +155,7 @@ class PropertyLoader(StrategizedProperty):
def cascade_callable(self, type, object, callable_, recursive, halt_on=None):
if not type in self.cascade:
return
-
+
mapper = self.mapper.primary_mapper()
passive = type != 'delete' or self.passive_deletes
for c in sessionlib.attribute_manager.get_as_list(object, self.key, passive=passive):
@@ -153,12 +167,15 @@ class PropertyLoader(StrategizedProperty):
mapper.cascade_callable(type, c, callable_, recursive)
def _get_target_class(self):
- """return the target class of the relation, even if the property has not been initialized yet."""
+ """Return the target class of the relation, even if the
+ property has not been initialized yet.
+ """
+
if isinstance(self.argument, type):
return self.argument
else:
return self.argument.class_
-
+
def do_init(self):
self._determine_targets()
self._determine_joins()
@@ -167,7 +184,7 @@ class PropertyLoader(StrategizedProperty):
self._determine_remote_side()
self._create_polymorphic_joins()
self._post_init()
-
+
def _determine_targets(self):
if isinstance(self.argument, type):
self.mapper = mapper.class_mapper(self.argument, compile=False)._check_compile()
@@ -175,14 +192,14 @@ class PropertyLoader(StrategizedProperty):
self.mapper = self.argument._check_compile()
else:
raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
-
+
# insure the "select_mapper", if different from the regular target mapper, is compiled.
self.mapper.get_select_mapper()._check_compile()
-
+
if self.association is not None:
if isinstance(self.association, type):
self.association = mapper.class_mapper(self.association, compile=False)._check_compile()
-
+
self.target = self.mapper.mapped_table
self.select_mapper = self.mapper.get_select_mapper()
self.select_table = self.mapper.select_table
@@ -192,7 +209,7 @@ class PropertyLoader(StrategizedProperty):
if self.parent.class_ is self.mapper.class_:
raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade rule on a self-referential relationship. You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
-
+
def _determine_joins(self):
if self.secondaryjoin is not None and self.secondary is None:
raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
@@ -256,9 +273,12 @@ class PropertyLoader(StrategizedProperty):
raise exceptions.ArgumentError("Cant locate any foreign key columns in primary join condition '%s' for relationship '%s'. Specify 'foreign_keys' argument to indicate which columns in the join condition are foreign." %(str(self.primaryjoin), str(self)))
if self.secondaryjoin is not None:
self.secondaryjoin.accept_visitor(mapperutil.BinaryVisitor(visit_binary))
-
+
def _determine_direction(self):
- """determines our 'direction', i.e. do we represent one to many, many to many, etc."""
+ """Determine our *direction*, i.e. do we represent one to
+ many, many to many, etc.
+ """
+
if self.secondaryjoin is not None:
self.direction = sync.MANYTOMANY
elif self._is_self_referential():
@@ -271,7 +291,7 @@ class PropertyLoader(StrategizedProperty):
self.direction = sync.ONETOMANY
else:
self.direction = sync.MANYTOONE
-
+
elif len(self.remote_side):
for f in self.foreign_keys:
if f in self.remote_side:
@@ -297,7 +317,7 @@ class PropertyLoader(StrategizedProperty):
if len(self.remote_side):
return
self.remote_side = util.Set()
-
+
if self.direction is sync.MANYTOONE:
for c in self._opposite_side:
self.remote_side.add(c)
@@ -305,18 +325,18 @@ class PropertyLoader(StrategizedProperty):
for c in self.foreign_keys:
self.remote_side.add(c)
- def _create_polymorphic_joins(self):
+ def _create_polymorphic_joins(self):
# get ready to create "polymorphic" primary/secondary join clauses.
# these clauses represent the same join between parent/child tables that the primary
# and secondary join clauses represent, except they reference ColumnElements that are specifically
# in the "polymorphic" selectables. these are used to construct joins for both Query as well as
# eager loading, and also are used to calculate "lazy loading" clauses.
-
- # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
+
+ # as we will be using the polymorphic selectables (i.e. select_table argument to Mapper) to figure this out,
# first create maps of all the "equivalent" columns, since polymorphic selectables will often munge
# several "equivalent" columns (such as parent/child fk cols) into just one column.
target_equivalents = self.mapper._get_inherited_column_equivalents()
-
+
# if the target mapper loads polymorphically, adapt the clauses to the target's selectable
if self.loads_polymorphic:
if self.secondaryjoin:
@@ -350,33 +370,33 @@ class PropertyLoader(StrategizedProperty):
self.logger.info(str(self) + " foreign keys " + str([str(c) for c in self.foreign_keys]))
self.logger.info(str(self) + " remote columns " + str([str(c) for c in self.remote_side]))
self.logger.info(str(self) + " relation direction " + (self.direction is sync.ONETOMANY and "one-to-many" or (self.direction is sync.MANYTOONE and "many-to-one" or "many-to-many")))
-
+
if self.uselist is None and self.direction is sync.MANYTOONE:
self.uselist = False
if self.uselist is None:
self.uselist = True
-
+
if not self.viewonly:
self._dependency_processor = dependency.create_dependency_processor(self)
-
+
# primary property handler, set up class attributes
if self.is_primary():
- # if a backref name is defined, set up an extension to populate
+ # if a backref name is defined, set up an extension to populate
# attributes in the other direction
if self.backref is not None:
self.attributeext = self.backref.get_extension()
-
+
if self.backref is not None:
self.backref.compile(self)
elif not sessionlib.attribute_manager.is_class_managed(self.parent.class_, self.key):
raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
super(PropertyLoader, self).do_init()
-
+
def _is_self_referential(self):
return self.parent.mapped_table is self.target or self.parent.select_table is self.target
-
+
def get_join(self, parent):
try:
return self._parent_join_cache[parent]
@@ -400,22 +420,27 @@ class PropertyLoader(StrategizedProperty):
j = primaryjoin
self._parent_join_cache[parent] = j
return j
-
+
def register_dependencies(self, uowcommit):
if not self.viewonly:
self._dependency_processor.register_dependencies(uowcommit)
-
+
PropertyLoader.logger = logging.class_logger(PropertyLoader)
class BackRef(object):
- """stores the name of a backreference property as well as options to
- be used on the resulting PropertyLoader."""
+ """Stores the name of a backreference property as well as options
+ to be used on the resulting PropertyLoader.
+ """
+
def __init__(self, key, **kwargs):
self.key = key
self.kwargs = kwargs
+
def compile(self, prop):
- """called by the owning PropertyLoader to set up a backreference on the
- PropertyLoader's mapper."""
+ """Called by the owning PropertyLoader to set up a
+ backreference on the PropertyLoader's mapper.
+ """
+
# try to set a LazyLoader on our mapper referencing the parent mapper
mapper = prop.mapper.primary_mapper()
if not mapper.props.has_key(self.key):
@@ -438,7 +463,8 @@ class BackRef(object):
prop.is_backref=True
if not prop.viewonly:
prop._dependency_processor.is_backref=True
+
def get_extension(self):
- """returns an attribute extension to use with this backreference."""
- return attributes.GenericBackrefExtension(self.key)
+ """Return an attribute extension to use with this backreference."""
+ return attributes.GenericBackrefExtension(self.key)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index dd3196c202..da1354c242 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -11,7 +11,8 @@ from sqlalchemy.orm.interfaces import OperationContext, SynonymProperty
__all__ = ['Query', 'QueryContext', 'SelectionContext']
class Query(object):
- """encapsulates the object-fetching operations provided by Mappers."""
+ """Encapsulates the object-fetching operations provided by Mappers."""
+
def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, with_options=None, extension=None, **kwargs):
if isinstance(class_or_mapper, type):
self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
@@ -36,24 +37,29 @@ class Query(object):
self._get_clause = self.mapper._get_clause
for opt in util.flatten_iterator(self.with_options):
opt.process_query(self)
-
+
def _insert_extension(self, ext):
self.extension.insert(ext)
-
+
def _get_session(self):
if self._session is None:
return self.mapper.get_session()
else:
return self._session
+
table = property(lambda s:s.select_mapper.mapped_table)
primary_key_columns = property(lambda s:s.select_mapper.pks_by_table[s.select_mapper.mapped_table])
session = property(_get_session)
-
+
def get(self, ident, **kwargs):
- """return an instance of the object based on the given identifier, or None if not found.
-
- The ident argument is a scalar or tuple of primary key column values
- in the order of the table def's primary key columns."""
+ """Return an instance of the object based on the given
+ identifier, or None if not found.
+
+ The `ident` argument is a scalar or tuple of primary key
+ column values in the order of the table def's primary key
+ columns.
+ """
+
ret = self.extension.get(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
@@ -61,11 +67,16 @@ class Query(object):
return self._get(key, ident, **kwargs)
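# The two identifier forms get() accepts, per the docstring above
# (illustrative only; `query` is assumed to be a Query over a mapped class).
u = query.get(5)           # single-column primary key
pv = query.get((5, 10))    # composite key, in table-def column order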
def load(self, ident, **kwargs):
- """return an instance of the object based on the given identifier.
-
- If not found, raises an exception. The method will *remove all pending changes* to the object
- already existing in the Session. The ident argument is a scalar or tuple of primary
- key column values in the order of the table def's primary key columns."""
+ """Return an instance of the object based on the given
+ identifier.
+
+ If not found, raises an exception. The method will **remove
+ all pending changes** to the object already existing in the
+ Session. The `ident` argument is a scalar or tuple of primary
+ key column values in the order of the table def's primary key
+ columns.
+ """
+
ret = self.extension.load(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
@@ -74,21 +85,27 @@ class Query(object):
if instance is None:
raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
return instance
-
+
def get_by(self, *args, **params):
- """return a single object instance based on the given key/value criterion.
-
- this is either the first value in the result list, or None if the list is
- empty.
-
- the keys are mapped to property or column names mapped by this mapper's Table, and the values
- are coerced into a WHERE clause separated by AND operators. If the local property/column
- names dont contain the key, a search will be performed against this mapper's immediate
- list of relations as well, forming the appropriate join conditions if a matching property
- is located.
-
- e.g. u = usermapper.get_by(user_name = 'fred')
+ """Return a single object instance based on the given
+ key/value criterion.
+
+ This is either the first value in the result list, or None if
+ the list is empty.
+
+ The keys are mapped to property or column names mapped by this
+ mapper's Table, and the values are coerced into a ``WHERE``
+ clause separated by ``AND`` operators. If the local
+ property/column names don't contain the key, a search will be
+ performed against this mapper's immediate list of relations as
+ well, forming the appropriate join conditions if a matching
+ property is located.
+
+ E.g.::
+
+ u = usermapper.get_by(user_name = 'fred')
"""
+
ret = self.extension.get_by(self, *args, **params)
if ret is not mapper.EXT_PASS:
return ret
@@ -99,30 +116,45 @@ class Query(object):
return None
def select_by(self, *args, **params):
- """return an array of object instances based on the given clauses and key/value criterion.
+ """Return an array of object instances based on the given
+ clauses and key/value criterion.
- *args is a list of zero or more ClauseElements which will be connected by AND operators.
+ `*args` is a list of zero or more ``ClauseElements`` which will be
+ connected by ``AND`` operators.
- **params is a set of zero or more key/value parameters which are converted into ClauseElements.
- the keys are mapped to property or column names mapped by this mapper's Table, and the values
- are coerced into a WHERE clause separated by AND operators. If the local property/column
- names dont contain the key, a search will be performed against this mapper's immediate
- list of relations as well, forming the appropriate join conditions if a matching property
- is located.
+ `**params` is a set of zero or more key/value parameters which
+ are converted into ``ClauseElements``. The keys are mapped to
+ property or column names mapped by this mapper's Table, and
+ the values are coerced into a ``WHERE`` clause separated by
+ ``AND`` operators. If the local property/column names don't
+ contain the key, a search will be performed against this
+ mapper's immediate list of relations as well, forming the
+ appropriate join conditions if a matching property is located.
- e.g. result = usermapper.select_by(user_name = 'fred')
+ E.g.::
+
+ result = usermapper.select_by(user_name = 'fred')
"""
+
ret = self.extension.select_by(self, *args, **params)
if ret is not mapper.EXT_PASS:
return ret
return self.select_whereclause(self.join_by(*args, **params))
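# select_by() combining both criterion styles described above
# (illustrative only): the `users` table and the related
# 'email_address' property are assumptions; the latter exercises the
# automatic relation join.
result = query.select_by(users.c.age > 30, user_name='fred')
result = query.select_by(email_address='fred@fred.com')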
def join_by(self, *args, **params):
- """return a ClauseElement representing the WHERE clause that would normally be sent to select_whereclause() by select_by()."""
+ """Return a ``ClauseElement`` representing the ``WHERE``
+ clause that would normally be sent to ``select_whereclause()``
+ by ``select_by()``.
+ """
+
return self._join_by(args, params)
def _join_by(self, args, params, start=None):
- """return a ClauseElement representing the WHERE clause that would normally be sent to select_whereclause() by select_by()."""
+ """Return a ``ClauseElement`` representing the ``WHERE``
+ clause that would normally be sent to ``select_whereclause()``
+ by ``select_by()``.
+ """
+
clause = None
for arg in args:
if clause is None:
@@ -135,7 +167,7 @@ class Query(object):
c = prop.compare(value) & self.join_via(keys)
if clause is None:
clause = c
- else:
+ else:
clause &= c
return clause
@@ -170,17 +202,24 @@ class Query(object):
return [keys, p]
def join_to(self, key):
- """given the key name of a property, will recursively descend through all child properties
- from this Query's mapper to locate the property, and will return a ClauseElement
- representing a join from this Query's mapper to the endmost mapper."""
+ """Given the key name of a property, will recursively descend
+ through all child properties from this Query's mapper to
+ locate the property, and will return a ClauseElement
+ representing a join from this Query's mapper to the endmost
+ mapper.
+ """
+
[keys, p] = self._locate_prop(key)
return self.join_via(keys)
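# Composing a WHERE clause by hand from join_via(), as the docstrings
# above suggest (illustrative only; the 'orders'/'items' relation names
# and the `items` table are assumptions).
clause = query.join_via(['orders', 'items']) & (items.c.price > 10)
result = query.select_whereclause(clause)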
def join_via(self, keys):
- """given a list of keys that represents a path from this Query's mapper to a related mapper
- based on names of relations from one mapper to the next, returns a
- ClauseElement representing a join from this Query's mapper to the endmost mapper.
+ """Given a list of keys that represents a path from this
+ Query's mapper to a related mapper based on names of relations
+ from one mapper to the next, return a ClauseElement
+ representing a join from this Query's mapper to the endmost
+ mapper.
"""
+
mapper = self.mapper
clause = None
for key in keys:
@@ -194,12 +233,18 @@ class Query(object):
return clause
def selectfirst_by(self, *args, **params):
- """works like select_by(), but only returns the first result by itself, or None if no
- objects returned. Synonymous with get_by()"""
+ """Like ``select_by()``, but only return the first result by
+ itself, or None if no objects returned. Synonymous with
+ ``get_by()``.
+ """
+
return self.get_by(*args, **params)
def selectone_by(self, *args, **params):
- """works like selectfirst_by(), but throws an error if not exactly one result was returned."""
+ """Like ``selectfirst_by()``, but throws an error if not
+ exactly one result was returned.
+ """
+
ret = self.select_whereclause(self.join_by(*args, **params), limit=2)
if len(ret) == 1:
return ret[0]
@@ -209,13 +254,20 @@ class Query(object):
raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by')
def count_by(self, *args, **params):
- """returns the count of instances based on the given clauses and key/value criterion.
- The criterion is constructed in the same way as the select_by() method."""
+ """Return the count of instances based on the given clauses
+ and key/value criterion.
+
+ The criterion is constructed in the same way as the
+ ``select_by()`` method.
+ """
+
return self.count(self.join_by(*args, **params))
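# count_by() and its hand-rolled equivalent via join_by()/count(),
# per the docstring above (illustrative only).
n = query.count_by(user_name='fred')
n = query.count(query.join_by(user_name='fred'))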
def selectfirst(self, arg=None, **kwargs):
- """works like select(), but only returns the first result by itself, or None if no
- objects returned."""
+ """Like ``select()``, but only return the first result by
+ itself, or None if no objects returned.
+ """
+
if isinstance(arg, sql.FromClause) and arg.supports_execution():
ret = self.select_statement(arg, **kwargs)
else:
@@ -227,7 +279,10 @@ class Query(object):
return None
def selectone(self, arg=None, **kwargs):
- """works like selectfirst(), but throws an error if not exactly one result was returned."""
+ """Like ``selectfirst()``, but throw an error if not exactly
+ one result was returned.
+ """
+
ret = list(self.select(arg, **kwargs)[0:2])
if len(ret) == 1:
return ret[0]
@@ -237,15 +292,19 @@ class Query(object):
raise exceptions.InvalidRequestError('Multiple rows returned for selectone')
def select(self, arg=None, **kwargs):
- """selects instances of the object from the database.
+ """Select instances of the object from the database.
- arg can be any ClauseElement, which will form the criterion with which to
- load the objects.
+ `arg` can be any ClauseElement, which will form the criterion
+ with which to load the objects.
- For more advanced usage, arg can also be a Select statement object, which
- will be executed and its resulting rowset used to build new object instances.
- in this case, the developer must ensure that an adequate set of columns exists in the
- rowset with which to build new object instances."""
+ For more advanced usage, `arg` can also be a Select statement
+ object, which will be executed and its resulting rowset used
+ to build new object instances.
+
+ In this case, the developer must ensure that an adequate set
+ of columns exists in the rowset with which to build new object
+ instances.
+ """
ret = self.extension.select(self, arg=arg, **kwargs)
if ret is not mapper.EXT_PASS:
@@ -256,18 +315,23 @@ class Query(object):
return self.select_whereclause(whereclause=arg, **kwargs)
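# The "advanced" form noted above (illustrative only): feeding select()
# a full Select statement whose result columns cover the mapped table;
# the `users` table is an assumption.
s = users.select(users.c.user_name.like('f%'), limit=10)
result = query.select(s)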
def select_whereclause(self, whereclause=None, params=None, **kwargs):
- """given a WHERE criterion, create a SELECT statement, execute and return the resulting instances."""
+ """Given a ``WHERE`` criterion, create a ``SELECT`` statement,
+ execute and return the resulting instances.
+ """
+
statement = self.compile(whereclause, **kwargs)
return self._select_statement(statement, params=params)
def count(self, whereclause=None, params=None, **kwargs):
- """given a WHERE criterion, create a SELECT COUNT statement, execute and return the resulting count value."""
+ """Given a ``WHERE`` criterion, create a ``SELECT COUNT``
+ statement, execute and return the resulting count value.
+ """
from_obj = kwargs.pop('from_obj', [])
alltables = []
for l in [sql_util.TableFinder(x) for x in from_obj]:
alltables += l
-
+
if self.table not in alltables:
from_obj.append(self.table)
@@ -279,22 +343,32 @@ class Query(object):
return self.session.scalar(self.mapper, s, params=params)
def select_statement(self, statement, **params):
- """given a ClauseElement-based statement, execute and return the resulting instances."""
+ """Given a ``ClauseElement``-based statement, execute and
+ return the resulting instances.
+ """
+
return self._select_statement(statement, params=params)
def select_text(self, text, **params):
- """given a literal string-based statement, execute and return the resulting instances."""
+ """Given a literal string-based statement, execute and return
+ the resulting instances.
+ """
+
t = sql.text(text)
return self.execute(t, params=params)
def options(self, *args, **kwargs):
- """return a new Query object, applying the given list of MapperOptions."""
+ """Return a new Query object, applying the given list of
+ MapperOptions.
+ """
+
return Query(self.mapper, self._session, with_options=args)
-
+
def with_lockmode(self, mode):
- """return a new Query object with the specified locking mode."""
+ """Return a new Query object with the specified locking mode."""
+
return Query(self.mapper, self._session, lockmode=mode)
-
+
def __getattr__(self, key):
if (key.startswith('select_by_')):
key = key[10:]
@@ -310,10 +384,16 @@ class Query(object):
raise AttributeError(key)
def execute(self, clauseelement, params=None, *args, **kwargs):
- """execute the given ClauseElement-based statement against this Query's session/mapper, return the resulting list of instances.
-
- After execution, closes the ResultProxy and its underlying resources.
- This method is one step above the instances() method, which takes the executed statement's ResultProxy directly."""
+ """Execute the given ClauseElement-based statement against
+ this Query's session/mapper, return the resulting list of
+ instances.
+
+ After execution, close the ResultProxy and its underlying
+ resources. This method is one step above the ``instances()``
+ method, which takes the executed statement's ResultProxy
+ directly.
+ """
+
result = self.session.execute(self.mapper, clauseelement, params=params)
try:
return self.instances(result, **kwargs)
@@ -321,11 +401,14 @@ class Query(object):
result.close()
def instances(self, cursor, *mappers, **kwargs):
- """return a list of mapped instances corresponding to the rows in a given "cursor" (i.e. ResultProxy)."""
+ """Return a list of mapped instances corresponding to the rows
+ in a given *cursor* (i.e. ``ResultProxy``).
+ """
+
self.__log_debug("instances()")
session = self.session
-
+
context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
result = util.UniqueAppender([])
@@ -350,7 +433,7 @@ class Query(object):
else:
return result.data
-
+
def _get(self, key, ident=None, reload=False, lockmode=None):
lockmode = lockmode or self.lockmode
if not reload and not self.always_refresh and lockmode is None:
@@ -367,8 +450,8 @@ class Query(object):
params = {}
for primary_key in self.primary_key_columns:
params[primary_key._label] = ident[i]
- # if there are not enough elements in the given identifier, then
- # use the previous identifier repeatedly. this is a workaround for the issue
+ # if there are not enough elements in the given identifier, then
+ # use the previous identifier repeatedly. this is a workaround for the issue
# in [ticket:185], where a mapper that uses joined table inheritance needs to specify
# all primary keys of the joined relationship, which includes even if the join is joining
# two primary key (and therefore synonymous) columns together, the usual case for joined table inheritance.
@@ -387,26 +470,32 @@ class Query(object):
return self.execute(statement, params=params, **kwargs)
def _should_nest(self, querycontext):
- """return True if the given statement options indicate that we should "nest" the
- generated query as a subquery inside of a larger eager-loading query. this is used
- with keywords like distinct, limit and offset and the mapper defines eager loads."""
+ """Return True if the given statement options indicate that we
+ should *nest* the generated query as a subquery inside of a
+ larger eager-loading query. This is used when keywords like
+ distinct, limit and offset are present and the mapper defines
+ eager loads.
+ """
+
return (
len(querycontext.eager_loaders) > 0
and self._nestable(**querycontext.select_args())
)
def _nestable(self, **kwargs):
- """return true if the given statement options imply it should be nested."""
+ """Return true if the given statement options imply it should be nested."""
+
return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
-
+
def compile(self, whereclause = None, **kwargs):
- """given a WHERE criterion, produce a ClauseElement-based statement suitable for usage in the execute() method."""
-
+ """Given a WHERE criterion, produce a ClauseElement-based
+ statement suitable for usage in the execute() method.
+ """
+
if whereclause is not None and self.is_polymorphic:
# adapt the given WHERECLAUSE to adjust instances of this query's mapped table to be that of our select_table,
# which may be the "polymorphic" selectable used by our mapper.
whereclause.accept_visitor(sql_util.ClauseAdapter(self.table))
-
+
context = kwargs.pop('query_context', None)
if context is None:
context = QueryContext(self, kwargs)
@@ -426,17 +515,17 @@ class Query(object):
for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[lockmode]
except KeyError:
raise exceptions.ArgumentError("Unknown lockmode '%s'" % lockmode)
-
+
if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None:
whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_(*[m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()]))
-
+
alltables = []
for l in [sql_util.TableFinder(x) for x in from_obj]:
alltables += l
-
+
if self.table not in alltables:
from_obj.append(self.table)
-
+
if self._should_nest(context):
# if theres an order by, add those columns to the column list
# of the "rowcount" query we're going to make
@@ -447,7 +536,7 @@ class Query(object):
o.accept_visitor(cf)
else:
cf = []
-
+
s2 = sql.select(self.table.primary_key + list(cf), whereclause, use_labels=True, from_obj=from_obj, **context.select_args())
if order_by:
s2.order_by(*util.to_list(order_by))
@@ -473,8 +562,8 @@ class Query(object):
# TODO: doing this off the select_mapper. if its the polymorphic mapper, then
# it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads)
for value in self.select_mapper.props.values():
- value.setup(context)
-
+ value.setup(context)
+
return statement
def __log_debug(self, msg):
@@ -483,8 +572,11 @@ class Query(object):
Query.logger = logging.class_logger(Query)
class QueryContext(OperationContext):
- """created within the Query.compile() method to store and share
- state among all the Mappers and MapperProperty objects used in a query construction."""
+ """Created within the ``Query.compile()`` method to store and
+ share state among all the Mappers and MapperProperty objects used
+ in a query construction.
+ """
+
def __init__(self, query, kwargs):
self.query = query
self.order_by = kwargs.pop('order_by', False)
@@ -496,38 +588,53 @@ class QueryContext(OperationContext):
self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
self.statement = None
super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs)
+
def select_args(self):
- """return a dictionary of attributes from this QueryContext that can be applied to a sql.Select statement."""
+ """Return a dictionary of attributes from this
+ ``QueryContext`` that can be applied to a ``sql.Select``
+ statement.
+ """
return {'limit':self.limit, 'offset':self.offset, 'distinct':self.distinct}
+
def accept_option(self, opt):
- """accept a MapperOption which will process (modify) the state of this QueryContext."""
+ """Accept a ``MapperOption`` which will process (modify) the
+ state of this ``QueryContext``.
+ """
+
opt.process_query_context(self)
class SelectionContext(OperationContext):
- """created within the query.instances() method to store and share
- state among all the Mappers and MapperProperty objects used in a load operation.
+ """Created within the ``query.instances()`` method to store and share
+ state among all the Mappers and MapperProperty objects used in a
+ load operation.
- SelectionContext contains these attributes:
+ SelectionContext contains these attributes::
- mapper - the Mapper which originated the instances() call.
+ mapper
+ The Mapper which originated the instances() call.
- session - the Session that is relevant to the instances call.
+ session
+ The Session that is relevant to the instances call.
- identity_map - a dictionary which stores newly created instances that have
- not yet been added as persistent to the Session.
+ identity_map
+ A dictionary which stores newly created instances that have not
+ yet been added as persistent to the Session.
- attributes - a dictionary to store arbitrary data; eager loaders use it to
- store additional result lists
+ attributes
+ A dictionary to store arbitrary data; eager loaders use it to
+ store additional result lists.
- populate_existing - indicates if its OK to overwrite the attributes of instances
- that were already in the Session
-
- version_check - indicates if mappers that have version_id columns should verify
- that instances existing already within the Session should have this attribute compared
- to the freshly loaded value
+ populate_existing
+ Indicates if it's OK to overwrite the attributes of instances
+ that were already in the Session.
+ version_check
+ Indicates if mappers that have version_id columns should verify
+ that instances already existing within the Session have this
+ attribute compared to the freshly loaded value.
"""
+
def __init__(self, mapper, session, extension, **kwargs):
self.populate_existing = kwargs.pop('populate_existing', False)
self.version_check = kwargs.pop('version_check', False)
@@ -535,7 +642,10 @@ class SelectionContext(OperationContext):
self.extension = extension
self.identity_map = {}
super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs)
+
def accept_option(self, opt):
- """accept a MapperOption which will process (modify) the state of this SelectionContext."""
+ """Accept a MapperOption which will process (modify) the state
+ of this SelectionContext.
+ """
+
opt.process_selection_context(self)
-
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 231be29b8b..a6e1e9ee25 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -11,18 +11,22 @@ from sqlalchemy.orm.mapper import class_mapper as _class_mapper
import weakref
import sqlalchemy
-
class SessionTransaction(object):
- """represents a Session-level Transaction. This corresponds to one or
- more sqlalchemy.engine.Transaction instances behind the scenes, with one
- Transaction per Engine in use.
-
- the SessionTransaction object is **not** threadsafe."""
+ """Represents a Session-level Transaction.
+
+ This corresponds to one or more sqlalchemy.engine.Transaction
+ instances behind the scenes, with one Transaction per Engine in
+ use.
+
+ The SessionTransaction object is **not** threadsafe.
+ """
+
def __init__(self, session, parent=None, autoflush=True):
self.session = session
self.connections = {}
self.parent = parent
self.autoflush = autoflush
+
def connection(self, mapper_or_class, entity_name=None):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
@@ -30,14 +34,17 @@ class SessionTransaction(object):
return self.parent.connection(mapper_or_class)
engine = self.session.get_bind(mapper_or_class)
return self.get_or_add(engine)
+
def _begin(self):
return SessionTransaction(self.session, self)
+
def add(self, connectable):
if self.connections.has_key(connectable.engine):
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
return self.get_or_add(connectable)
+
def get_or_add(self, connectable):
- # we reference the 'engine' attribute on the given object, which in the case of
+ # we reference the 'engine' attribute on the given object, which in the case of
# Connection, ProxyEngine, Engine, whatever, should return the original
# "Engine" object that is handling the connection.
if self.connections.has_key(connectable.engine):
@@ -47,6 +54,7 @@ class SessionTransaction(object):
if not self.connections.has_key(e):
self.connections[e] = (c, c.begin(), c is not connectable)
return self.connections[e][0]
+
def commit(self):
if self.parent is not None:
return
@@ -55,6 +63,7 @@ class SessionTransaction(object):
for t in self.connections.values():
t[1].commit()
self.close()
+
def rollback(self):
if self.parent is not None:
self.parent.rollback()
@@ -62,6 +71,7 @@ class SessionTransaction(object):
for k, t in self.connections.iteritems():
t[1].rollback()
self.close()
+
def close(self):
if self.parent is not None:
return
@@ -69,8 +79,10 @@ class SessionTransaction(object):
if t[2]:
t[0].close()
self.session.transaction = None
+
def __enter__(self):
return self
+
def __exit__(self, type, value, traceback):
if self.session.transaction is None:
return
@@ -80,16 +92,19 @@ class SessionTransaction(object):
self.rollback()
class Session(object):
- """encapsulates a set of objects being operated upon within an object-relational operation.
-
- The Session object is **not** threadsafe. For thread-management of Sessions, see the
- sqlalchemy.ext.sessioncontext module."""
+ """Encapsulates a set of objects being operated upon within an
+ object-relational operation.
+
+ The Session object is **not** threadsafe. For thread-management
+ of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module.
+ """
+
def __init__(self, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False):
if import_session is not None:
self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map)
else:
self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
-
+
self.bind_to = bind_to
self.binds = {}
self.echo_uow = echo_uow
@@ -103,96 +118,147 @@ class Session(object):
def _get_echo_uow(self):
return self.uow.echo
+
def _set_echo_uow(self, value):
self.uow.echo = value
echo_uow = property(_get_echo_uow,_set_echo_uow)
-
+
def create_transaction(self, **kwargs):
- """returns a new SessionTransaction corresponding to an existing or new transaction.
- if the transaction is new, the returned SessionTransaction will have commit control
- over the underlying transaction, else will have rollback control only."""
+ """Return a new SessionTransaction corresponding to an
+ existing or new transaction.
+
+ If the transaction is new, the returned SessionTransaction
+ will have commit control over the underlying transaction, else
+ will have rollback control only.
+ """
+
if self.transaction is not None:
return self.transaction._begin()
else:
self.transaction = SessionTransaction(self, **kwargs)
return self.transaction
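# The commit-control rule described above, sketched (illustrative only):
# only the outermost SessionTransaction actually commits.
t = sess.create_transaction()
try:
    sess.flush()        # participates in the enclosing transaction
    t.commit()
except:
    t.rollback()
    raise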
+
def connect(self, mapper=None, **kwargs):
- """returns a unique connection corresponding to the given mapper. this connection
- will not be part of any pre-existing transactional context."""
+ """Return a unique connection corresponding to the given
+ mapper.
+
+ This connection will not be part of any pre-existing
+ transactional context.
+ """
+
return self.get_bind(mapper).connect(**kwargs)
+
def connection(self, mapper, **kwargs):
- """returns a Connection corresponding to the given mapper. used by the execute()
- method which performs select operations for Mapper and Query.
- if this Session is transactional,
- the connection will be in the context of this session's transaction. otherwise, the connection
- is returned by the contextual_connect method, which some Engines override to return a thread-local
- connection, and will have close_with_result set to True.
-
- the given **kwargs will be sent to the engine's contextual_connect() method, if no transaction is in progress."""
+ """Return a Connection corresponding to the given mapper.
+
+ Used by the ``execute()`` method which performs select
+ operations for Mapper and Query.
+
+ If this Session is transactional, the connection will be in
+ the context of this session's transaction. Otherwise, the
+ connection is returned by the contextual_connect method, which
+ some Engines override to return a thread-local connection, and
+ will have close_with_result set to True.
+
+ The given `**kwargs` will be sent to the engine's
+ ``contextual_connect()`` method, if no transaction is in
+ progress.
+ """
+
if self.transaction is not None:
return self.transaction.connection(mapper)
else:
return self.get_bind(mapper).contextual_connect(**kwargs)
+
def execute(self, mapper, clause, params, **kwargs):
- """using the given mapper to identify the appropriate Engine or Connection to be used for statement execution,
- executes the given ClauseElement using the provided parameter dictionary. Returns a ResultProxy corresponding
- to the execution's results. If this method allocates a new Connection for the operation, then the ResultProxy's close()
- method will release the resources of the underlying Connection, otherwise its a no-op.
+ """Using the given mapper to identify the appropriate Engine
+ or Connection to be used for statement execution, execute the
+ given ClauseElement using the provided parameter dictionary.
+
+ Return a ResultProxy corresponding to the execution's results.
+
+ If this method allocates a new Connection for the operation,
+ then the ResultProxy's ``close()`` method will release the
+ resources of the underlying Connection, otherwise it's a no-op.
"""
return self.connection(mapper, close_with_result=True).execute(clause, params, **kwargs)
+
def scalar(self, mapper, clause, params, **kwargs):
- """works like execute() but returns a scalar result."""
+ """Like execute() but return a scalar result."""
+
return self.connection(mapper, close_with_result=True).scalar(clause, params, **kwargs)
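# execute() vs. scalar(), sketched (illustrative only): `usermapper`
# selects the bound Engine as described above; the `users` table and
# imports are assumptions.
from sqlalchemy import select, func

rows = sess.execute(usermapper, users.select(), {})
total = sess.scalar(usermapper, select([func.count(users.c.id)]), {})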
-
+
def close(self):
- """closes this Session.
- """
+ """Close this Session."""
+
self.clear()
if self.transaction is not None:
self.transaction.close()
def clear(self):
- """removes all object instances from this Session. this is equivalent to calling expunge() for all
- objects in this Session."""
+ """Remove all object instances from this Session.
+
+ This is equivalent to calling ``expunge()`` for all objects in
+ this Session.
+ """
+
for instance in self:
self._unattach(instance)
echo = self.uow.echo
self.uow = unitofwork.UnitOfWork(weak_identity_map=self.weak_identity_map)
self.uow.echo = echo
-
+
def mapper(self, class_, entity_name=None):
- """given an Class, return the primary Mapper responsible for persisting it"""
+ """Given an Class, return the primary Mapper responsible for
+ persisting it."""
+
return _class_mapper(class_, entity_name = entity_name)
+
def bind_mapper(self, mapper, bindto):
- """bind the given Mapper to the given Engine or Connection.
-
- All subsequent operations involving this Mapper will use the given bindto."""
+ """Bind the given `Mapper` to the given Engine or Connection.
+
+ All subsequent operations involving this Mapper will use the
+ given `bindto`.
+ """
+
self.binds[mapper] = bindto
+
def bind_table(self, table, bindto):
- """bind the given Table to the given Engine or Connection.
-
- All subsequent operations involving this Table will use the given bindto."""
+ """Bind the given `Table` to the given Engine or Connection.
+
+ All subsequent operations involving this Table will use the
+ given `bindto`.
+ """
+
self.binds[table] = bindto
+
def get_bind(self, mapper):
- """return the Engine or Connection which is used to execute statements on behalf of the given Mapper.
-
- Calling connect() on the return result will always result in a Connection object. This method
- disregards any SessionTransaction that may be in progress.
-
+ """Return the Engine or Connection which is used to execute
+ statements on behalf of the given `Mapper`.
+
+ Calling ``connect()`` on the return result will always result
+ in a Connection object. This method disregards any
+ SessionTransaction that may be in progress.
+
The order of searching is as follows:
-
- if an Engine or Connection was bound to this Mapper specifically within this Session, returns that
- Engine or Connection.
-
- if an Engine or Connection was bound to this Mapper's underlying Table within this Session
- (i.e. not to the Table directly), returns that Engine or Conneciton.
-
- if an Engine or Connection was bound to this Session, returns that Engine or Connection.
-
- finally, returns the Engine which was bound directly to the Table's MetaData object.
-
+
+ 1. if an Engine or Connection was bound to this Mapper
+ specifically within this Session, return that Engine or
+ Connection.
+
+ 2. if an Engine or Connection was bound to this Mapper's
+ underlying Table within this Session (i.e. not to the Table
+ directly), return that Engine or Connection.
+
+ 3. if an Engine or Connection was bound to this Session,
+ return that Engine or Connection.
+
+ 4. finally, return the Engine which was bound directly to the
+ Table's MetaData object.
+
If no Engine is bound to the Table, an exception is raised.
"""
+
if mapper is None:
return self.bind_to
elif self.binds.has_key(mapper):
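# The search order above, sketched from most to least specific
# (illustrative only; the engine names are assumptions).
sess = create_session(bind_to=default_engine)   # step 3 fallback
sess.bind_table(users, table_engine)            # step 2
sess.bind_mapper(usermapper, mapper_engine)     # step 1 wins
assert sess.get_bind(usermapper) is mapper_engine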
@@ -275,75 +341,94 @@ class Session(object):
attribute_manager.trigger_history(obj, exp)
def is_expired(self, obj, unexpire=False):
- """return True if the given object has been marked as expired."""
+ """Return True if the given object has been marked as expired."""
+
ret = attribute_manager.has_trigger(obj)
if ret and unexpire:
attribute_manager.untrigger_history(obj)
return ret
def expunge(self, object):
- """remove the given object from this Session.
-
- this will free all internal references to the object. cascading will be applied according to the
- 'expunge' cascade rule."""
+ """Remove the given object from this Session.
+
+ This will free all internal references to the object.
+ Cascading will be applied according to the *expunge* cascade
+ rule.
+ """
+
for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)):
self.uow._remove_deleted(c)
self._unattach(c)
-
+
def save(self, object, entity_name=None):
+ """Add a transient (unsaved) instance to this Session.
+
+ This operation cascades the `save_or_update` method to
+ associated instances if the relation is mapped with
+ ``cascade="save-update"``.
+
+ The `entity_name` keyword argument will further qualify the
+ specific Mapper used to handle this instance.
"""
- Add a transient (unsaved) instance to this Session.
-
- This operation cascades the "save_or_update" method to associated instances if the
- relation is mapped with cascade="save-update".
-
- The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this
- instance.
- """
+
self._save_impl(object, entity_name=entity_name)
_object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self)
def update(self, object, entity_name=None):
- """Bring the given detached (saved) instance into this Session.
-
- If there is a persistent instance with the same identifier already associated
- with this Session, an exception is thrown.
+ """Bring the given detached (saved) instance into this
+ Session.
+
+ If there is a persistent instance with the same identifier
+ already associated with this Session, an exception is thrown.
+
+ This operation cascades the `save_or_update` method to
+ associated instances if the relation is mapped with
+ ``cascade="save-update"``.
+ """
- This operation cascades the "save_or_update" method to associated instances if the relation is mapped
- with cascade="save-update"."""
self._update_impl(object, entity_name=entity_name)
_object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self)
def save_or_update(self, object, entity_name=None):
- """save or update the given object into this Session.
-
- The presence of an '_instance_key' attribute on the instance determines whether to
- save() or update() the instance."""
+ """Save or update the given object into this Session.
+
+ The presence of an `_instance_key` attribute on the instance
+ determines whether to ``save()`` or ``update()`` the instance.
+ """
+
self._save_or_update_impl(object, entity_name=entity_name)
_object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self)
-
+
def _save_or_update_impl(self, object, entity_name=None):
key = getattr(object, '_instance_key', None)
if key is None:
self._save_impl(object, entity_name=entity_name)
else:
self._update_impl(object, entity_name=entity_name)
-
+
def delete(self, object, entity_name=None):
- """mark the given instance as deleted.
-
- the delete operation occurs upon flush()."""
+ """Mark the given instance as deleted.
+
+ The delete operation occurs upon ``flush()``.
+ """
+
for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)):
self.uow.register_deleted(c)
def merge(self, object, entity_name=None, _recursive=None):
- """copy the state of the given object onto the persistent object with the same identifier.
-
- If there is no persistent instance currently associated with the session, it will be loaded.
- Return the persistent instance. If the given instance is unsaved, save a copy of and return it as
- a newly persistent instance. The given instance does not become associated with the session.
- This operation cascades to associated instances if the association is mapped with cascade="merge".
+ """Copy the state of the given object onto the persistent
+ object with the same identifier.
+
+ If there is no persistent instance currently associated with
+ the session, it will be loaded. Return the persistent
+ instance. If the given instance is unsaved, save a copy of it
+ and return the copy as a newly persistent instance. The given
+ instance does not become associated with the session.
+
+ This operation cascades to associated instances if the
+ association is mapped with ``cascade="merge"``.
"""
+
if _recursive is None:
_recursive = util.Set()
mapper = _object_mapper(object)
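# merge() per the docstring above, sketched (illustrative only):
# state is copied onto the session's own persistent instance, and the
# given instance itself stays unassociated with `sess`.
other = sess2.query(User).get(5)    # owned by a different session
merged = sess.merge(other)
assert merged is not other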
@@ -360,19 +445,19 @@ class Session(object):
if key is None:
self.save(merged)
return merged
-
+
def _save_impl(self, object, **kwargs):
if hasattr(object, '_instance_key'):
if not self.identity_map.has_key(object._instance_key):
raise exceptions.InvalidRequestError("Instance '%s' is a detached instance or is already persistent in a different Session" % repr(object))
else:
m = _class_mapper(object.__class__, entity_name=kwargs.get('entity_name', None))
-
- # this would be a nice exception to raise...however this is incompatible with a contextual
+
+ # this would be a nice exception to raise...however this is incompatible with a contextual
# session which puts all objects into the session upon construction.
#if m._is_orphan(object):
# raise exceptions.InvalidRequestError("Instance '%s' is an orphan, and must be attached to a parent object to be saved" % (repr(object)))
-
+
m._assign_entity_name(object)
self._register_pending(object)
@@ -386,21 +471,24 @@ class Session(object):
def _register_pending(self, obj):
self._attach(obj)
self.uow.register_new(obj)
+
def _register_persistent(self, obj):
self._attach(obj)
self.uow.register_clean(obj)
+
def _register_deleted(self, obj):
self._attach(obj)
self.uow.register_deleted(obj)
-
+
def _attach(self, obj):
"""Attach the given object to this Session."""
+
if getattr(obj, '_sa_session_id', None) != self.hash_key:
old = getattr(obj, '_sa_session_id', None)
if old is not None and _sessions.has_key(old):
raise exceptions.InvalidRequestError("Object '%s' is already attached to session '%s' (this is '%s')" % (repr(obj), old, id(self)))
-
- # auto-removal from the old session is disabled. but if we decide to
+
+ # auto-removal from the old session is disabled. but if we decide to
# turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
# and it might be affected by other threads
#try:
@@ -413,51 +501,67 @@ class Session(object):
if key is not None:
self.identity_map[key] = obj
obj._sa_session_id = self.hash_key
-
+
def _unattach(self, obj):
self._validate_attached(obj)
del obj._sa_session_id
-
+
def _validate_attached(self, obj):
- """validate that the given object is either pending or persistent within this Session."""
+ """Validate that the given object is either pending or
+ persistent within this Session.
+ """
+
if not self._is_attached(obj):
raise exceptions.InvalidRequestError("Instance '%s' not attached to this Session" % repr(obj))
+
def _validate_persistent(self, obj):
- """validate that the given object is persistent within this Session."""
+ """Validate that the given object is persistent within this
+ Session.
+ """
+
self.uow._validate_obj(obj)
+
def _is_attached(self, obj):
return getattr(obj, '_sa_session_id', None) == self.hash_key
+
def __contains__(self, obj):
return self._is_attached(obj) and (obj in self.uow.new or self.identity_map.has_key(obj._instance_key))
+
def __iter__(self):
return iter(list(self.uow.new) + self.uow.identity_map.values())
+
def _get(self, key):
return self.identity_map[key]
+
def has_key(self, key):
return self.identity_map.has_key(key)
-
- dirty = property(lambda s:s.uow.locate_dirty(), doc="a Set of all objects marked as 'dirty' within this Session")
- deleted = property(lambda s:s.uow.deleted, doc="a Set of all objects marked as 'deleted' within this Session")
- new = property(lambda s:s.uow.new, doc="a Set of all objects marked as 'new' within this Session.")
- identity_map = property(lambda s:s.uow.identity_map, doc="a dictionary consisting of all objects within this Session keyed to their _instance_key value.")
-
+
+ dirty = property(lambda s:s.uow.locate_dirty(),
+ doc="A Set of all objects marked as 'dirty' within this Session")
+ deleted = property(lambda s:s.uow.deleted,
+ doc="A Set of all objects marked as 'deleted' within this Session")
+ new = property(lambda s:s.uow.new,
+ doc="A Set of all objects marked as 'new' within this Session.")
+ identity_map = property(lambda s:s.uow.identity_map,
+ doc="A dictionary consisting of all objects within this Session keyed to their _instance_key value.")
+
def import_instance(self, *args, **kwargs):
- """deprecated; a synynom for merge()"""
+ """Deprecated. A synynom for ``merge()``."""
return self.merge(*args, **kwargs)
-
# this is the AttributeManager instance used to provide attribute behavior on objects.
# to all the "global variable police" out there: its a stateless object.
attribute_manager = unitofwork.attribute_manager
-# this dictionary maps the hash key of a Session to the Session itself, and
+# this dictionary maps the hash key of a Session to the Session itself, and
# acts as a Registry with which to locate Sessions. this is to enable
# object instances to be associated with Sessions without having to attach the
# actual Session object directly to the object instance.
_sessions = weakref.WeakValueDictionary()
def object_session(obj):
- """return the Session to which the given object is bound, or None if none."""
+ """Return the Session to which the given object is bound, or None if none."""
+
hashkey = getattr(obj, '_sa_session_id', None)
if hashkey is not None:
return _sessions.get(hashkey)
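The merge() contract documented above can be sketched as follows (a
minimal sketch against the 0.3-era API; the User mapping, the session
construction, and the import locations are assumptions):

    from sqlalchemy.orm import create_session, object_session

    sess = create_session()
    detached = User(name='ed')        # unsaved instance, in no session
    merged = sess.merge(detached)     # a copy becomes pending in sess
    assert merged is not detached     # the copy is returned, not the original
    assert object_session(merged) is sess
    assert object_session(detached) is None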
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index ef96b03976..115b53bfd0 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -41,8 +41,12 @@ class ColumnLoader(LoaderStrategy):
ColumnLoader.logger = logging.class_logger(ColumnLoader)
class DeferredColumnLoader(LoaderStrategy):
- """describes an object attribute that corresponds to a table column, which also
- will "lazy load" its value from the table. this is per-column lazy loading."""
+ """Describes an object attribute that corresponds to a table
+ column, which also will *lazy load* its value from the table.
+
+ This is per-column lazy loading.
+ """
+
def init(self):
super(DeferredColumnLoader, self).init()
self.columns = self.parent_property.columns
@@ -120,6 +124,7 @@ class DeferredOption(StrategizedOption):
def __init__(self, key, defer=False):
super(DeferredOption, self).__init__(key)
self.defer = defer
+
def get_strategy_class(self):
if self.defer:
return DeferredColumnLoader
@@ -143,6 +148,7 @@ class AbstractRelationLoader(LoaderStrategy):
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):
self._register_attribute(self.parent.class_)
+
def process_row(self, selectcontext, instance, row, identitykey, isnew):
if isnew:
if not self.is_default or len(selectcontext.options):
@@ -266,6 +272,7 @@ class LazyLoader(AbstractRelationLoader):
def bind_label():
return "lazy_" + hex(random.randint(0, 65535))[2:]
+
def visit_binary(binary):
leftcol = find_column_in_expr(binary.left)
rightcol = find_column_in_expr(binary.right)
@@ -299,9 +306,9 @@ class LazyLoader(AbstractRelationLoader):
LazyLoader.logger = logging.class_logger(LazyLoader)
-
class EagerLoader(AbstractRelationLoader):
- """loads related objects inline with a parent query."""
+ """Loads related objects inline with a parent query."""
+
def init(self):
super(EagerLoader, self).init()
if self.parent.isa(self.mapper):
@@ -312,15 +319,18 @@ class EagerLoader(AbstractRelationLoader):
self.clauses_by_lead_mapper = {}
class AliasedClauses(object):
- """defines a set of join conditions and table aliases which are aliased on a randomly-generated
- alias name, corresponding to the connection of an optional parent AliasedClauses object and a
- target mapper.
+ """Defines a set of join conditions and table aliases which
+ are aliased on a randomly-generated alias name, corresponding
+ to the connection of an optional parent AliasedClauses object
+ and a target mapper.
- EagerLoader has a distinct AliasedClauses object per parent AliasedClauses object,
- so that all paths from one mapper to another across a chain of eagerloaders generates a distinct
- chain of joins. The AliasedClauses objects are generated and cached on an as-needed basis.
+ EagerLoader has a distinct AliasedClauses object per parent
+ AliasedClauses object, so that all paths from one mapper to
+ another across a chain of eagerloaders generate a distinct
+ chain of joins. The AliasedClauses objects are generated and
+ cached on an as-needed basis.
- e.g.:
+ E.g.::
mapper A -->
(EagerLoader 'items') -->
@@ -328,7 +338,7 @@ class EagerLoader(AbstractRelationLoader):
(EagerLoader 'keywords') -->
mapper C
- will generate:
+ will generate::
EagerLoader 'items' --> {
None : AliasedClauses(items, None, alias_suffix='AB34') # mappera JOIN mapperb_AB34
@@ -340,6 +350,7 @@ class EagerLoader(AbstractRelationLoader):
AliasedClauses(keywords, items, alias_suffix='8F44') # mapperb_AB34 JOIN mapperc_8F44
]
"""
+
def __init__(self, eagerloader, parentclauses=None):
self.parent = eagerloader
self.target = eagerloader.select_table
@@ -408,7 +419,8 @@ class EagerLoader(AbstractRelationLoader):
self.parent_property._get_strategy(LazyLoader).init_class_attribute()
def setup_query(self, context, eagertable=None, parentclauses=None, parentmapper=None, **kwargs):
- """add a left outer join to the statement thats being constructed"""
+ """Add a left outer join to the statement thats being constructed."""
+
if parentmapper is None:
localparent = context.mapper
else:
@@ -468,9 +480,13 @@ class EagerLoader(AbstractRelationLoader):
value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper)
def _create_row_processor(self, selectcontext, row):
- """create a 'row processing' function that will apply eager aliasing to the row.
+ """Create a *row processing* function that will apply eager
+ aliasing to the row.
+
+ Also check that an identity key can be retrieved from the row,
+ else return None.
+ """
- also check that an identity key can be retrieved from the row, else return None."""
# check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option)
if selectcontext.attributes.has_key((EagerLoader, self.parent_property)):
# custom row decoration function, placed in the selectcontext by the
@@ -501,8 +517,12 @@ class EagerLoader(AbstractRelationLoader):
return None
def process_row(self, selectcontext, instance, row, identitykey, isnew):
- """receive a row. tell our mapper to look for a new object instance in the row, and attach
- it to a list on the parent instance."""
+ """Receive a row.
+
+ Tell our mapper to look for a new object instance in the row,
+ and attach it to a list on the parent instance.
+ """
+
if self in selectcontext.recursion_stack:
return
@@ -563,6 +583,7 @@ class EagerLazyOption(StrategizedOption):
def __init__(self, key, lazy=True):
super(EagerLazyOption, self).__init__(key)
self.lazy = lazy
+
def process_query_property(self, context, prop):
if self.lazy:
if prop in context.eager_loaders:
@@ -570,6 +591,7 @@ class EagerLazyOption(StrategizedOption):
else:
context.eager_loaders.add(prop)
super(EagerLazyOption, self).process_query_property(context, prop)
+
def get_strategy_class(self):
if self.lazy:
return LazyLoader
@@ -577,6 +599,7 @@ class EagerLazyOption(StrategizedOption):
return EagerLoader
elif self.lazy is None:
return NoLoader
+
EagerLazyOption.logger = logging.class_logger(EagerLazyOption)
class RowDecorateOption(PropertyOption):
@@ -584,6 +607,7 @@ class RowDecorateOption(PropertyOption):
super(RowDecorateOption, self).__init__(key)
self.decorator = decorator
self.alias = alias
+
def process_selection_property(self, context, property):
if self.alias is not None and self.decorator is None:
if isinstance(self.alias, basestring):
@@ -595,6 +619,7 @@ class RowDecorateOption(PropertyOption):
return d
self.decorator = decorate
context.attributes[(EagerLoader, property)] = self.decorator
+
RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
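For context, the loader strategies and options above are selected at
query time; a hedged sketch (the 0.3-era option functions eagerload()
and defer() are assumptions here):

    from sqlalchemy.orm import eagerload, defer

    q = session.query(User)
    q = q.options(eagerload('addresses'))   # EagerLazyOption -> EagerLoader:
                                            # LEFT OUTER JOIN in the parent query
    q = q.options(defer('biography'))       # DeferredOption -> DeferredColumnLoader:
                                            # per-column lazy loading
    users = q.select()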
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 4b9b0c35ce..68fa9cee16 100644
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py
@@ -10,21 +10,30 @@ from sqlalchemy import sql, schema, exceptions
from sqlalchemy import logging
from sqlalchemy.orm import util as mapperutil
-"""contains the ClauseSynchronizer class, which is used to map attributes between two objects
-in a manner corresponding to a SQL clause that compares column values."""
+"""Contains the ClauseSynchronizer class, which is used to map
+attributes between two objects in a manner corresponding to a SQL
+clause that compares column values.
+"""
ONETOMANY = 0
MANYTOONE = 1
MANYTOMANY = 2
class ClauseSynchronizer(object):
- """Given a SQL clause, usually a series of one or more binary
- expressions between columns, and a set of 'source' and 'destination' mappers, compiles a set of SyncRules
- corresponding to that information. The ClauseSynchronizer can then be executed given a set of parent/child
- objects or destination dictionary, which will iterate through each of its SyncRules and execute them.
- Each SyncRule will copy the value of a single attribute from the parent
- to the child, corresponding to the pair of columns in a particular binary expression, using the source and
- destination mappers to map those two columns to object attributes within parent and child."""
+ """Given a SQL clause, usually a series of one or more binary
+ expressions between columns, and a set of 'source' and
+ 'destination' mappers, compiles a set of SyncRules corresponding
+ to that information.
+
+ The ClauseSynchronizer can then be executed given a set of
+ parent/child objects or destination dictionary, which will iterate
+ through each of its SyncRules and execute them. Each SyncRule
+ will copy the value of a single attribute from the parent to the
+ child, corresponding to the pair of columns in a particular binary
+ expression, using the source and destination mappers to map those
+ two columns to object attributes within parent and child.
+ """
+
def __init__(self, parent_mapper, child_mapper, direction):
self.parent_mapper = parent_mapper
self.child_mapper = child_mapper
@@ -33,17 +42,18 @@ class ClauseSynchronizer(object):
def compile(self, sqlclause, foreign_keys=None, issecondary=None):
def compile_binary(binary):
- """assemble a SyncRule given a single binary condition"""
+ """Assemble a SyncRule given a single binary condition."""
+
if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
return
source_column = None
dest_column = None
-
+
if foreign_keys is None:
if binary.left.table == binary.right.table:
raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync")
-
+
if binary.left in [f.column for f in binary.right.foreign_keys]:
dest_column = binary.right
source_column = binary.left
@@ -57,8 +67,8 @@ class ClauseSynchronizer(object):
elif binary.right in foreign_keys:
source_column = binary.left
dest_column = binary.right
-
- if source_column and dest_column:
+
+ if source_column and dest_column:
if self.direction == ONETOMANY:
self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper))
elif self.direction == MANYTOONE:
@@ -74,40 +84,45 @@ class ClauseSynchronizer(object):
sqlclause.accept_visitor(processor)
if len(self.syncrules) == rules_added:
raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
-
+
def dest_columns(self):
return [r.dest_column for r in self.syncrules if r.dest_column is not None]
def execute(self, source, dest, obj=None, child=None, clearkeys=None):
for rule in self.syncrules:
rule.execute(source, dest, obj, child, clearkeys)
-
+
class SyncRule(object):
- """An instruction indicating how to populate the objects on each side of a relationship.
- i.e. if table1 column A is joined against
- table2 column B, and we are a one-to-many from table1 to table2, a syncrule would say
- 'take the A attribute from object1 and assign it to the B attribute on object2'.
-
- A rule contains the source mapper, the source column, destination column,
- destination mapper in the case of a one/many relationship, and
- the integer direction of this mapper relative to the association in the case
- of a many to many relationship.
+ """An instruction indicating how to populate the objects on each
+ side of a relationship.
+
+ In other words, if table1 column A is joined against table2 column
+ B, and we are a one-to-many from table1 to table2, a syncrule
+ would say *take the A attribute from object1 and assign it to the
+ B attribute on object2*.
+
+ A rule contains the source mapper, the source column, destination
+ column, destination mapper in the case of a one/many relationship,
+ and the integer direction of this mapper relative to the
+ association in the case of a many to many relationship.
"""
+
def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None):
self.source_mapper = source_mapper
self.source_column = source_column
self.issecondary = issecondary
self.dest_mapper = dest_mapper
self.dest_column = dest_column
-
+
#print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
+
def dest_primary_key(self):
try:
return self._dest_primary_key
except AttributeError:
self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper.pks_by_table[self.dest_column.table]
return self._dest_primary_key
-
+
def execute(self, source, dest, obj, child, clearkeys):
if source is None:
if self.issecondary is False:
@@ -124,16 +139,16 @@ class SyncRule(object):
else:
if clearkeys and self.dest_primary_key():
raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.instance_str(dest)))
-
+
if logging.is_debug_enabled(self.logger):
self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value))
self.dest_mapper.set_attr_by_column(dest, self.dest_column, value)
SyncRule.logger = logging.class_logger(SyncRule)
-
+
class BinaryVisitor(sql.ClauseVisitor):
def __init__(self, func):
self.func = func
+
def visit_binary(self, binary):
self.func(binary)
-
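Concretely, a single one-to-many SyncRule reduces to the attribute copy
described in its docstring; for a join condition parent.id ==
child.parent_id it behaves like this sketch (the attribute names are
illustrative):

    def apply_one_to_many_rule(parent, child):
        # copy the parent's primary key value onto the child's
        # foreign key attribute before the child row is written
        child.parent_id = parent.id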
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index e77f583560..587c405e65 100644
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -4,14 +4,19 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""the internals for the Unit Of Work system. includes hooks into the attributes package
-enabling the routing of change events to Unit Of Work objects, as well as the flush() mechanism
-which creates a dependency structure that executes change operations.
-
-a Unit of Work is essentially a system of maintaining a graph of in-memory objects and their
-modified state. Objects are maintained as unique against their primary key identity using
-an "identity map" pattern. The Unit of Work then maintains lists of objects that are new,
-dirty, or deleted and provides the capability to flush all those changes at once.
+"""The internals for the Unit Of Work system.
+
+Includes hooks into the attributes package enabling the routing of
+change events to Unit Of Work objects, as well as the flush()
+mechanism which creates a dependency structure that executes change
+operations.
+
+A Unit of Work is essentially a system of maintaining a graph of
+in-memory objects and their modified state. Objects are maintained as
+unique against their primary key identity using an *identity map*
+pattern. The Unit of Work then maintains lists of objects that are
+new, dirty, or deleted and provides the capability to flush all those
+changes at once.
"""
from sqlalchemy import util, logging, topological
@@ -23,11 +28,15 @@ import weakref
import sets
class UOWEventHandler(attributes.AttributeExtension):
- """an event handler added to all class attributes which handles session operations."""
+ """An event handler added to all class attributes which handles
+ session operations.
+ """
+
def __init__(self, key, class_, cascade=None):
self.key = key
self.class_ = class_
self.cascade = cascade
+
def append(self, event, obj, item):
# process "save_update" cascade rules for when an instance is appended to the list of another instance
sess = object_session(obj)
@@ -54,25 +63,35 @@ class UOWEventHandler(attributes.AttributeExtension):
sess.save_or_update(newvalue, entity_name=ename)
class UOWProperty(attributes.InstrumentedAttribute):
- """override InstrumentedAttribute to provide an extra AttributeExtension to all managed attributes
- as well as the 'property' property."""
+ """Override ``InstrumentedAttribute`` to provide an extra
+ ``AttributeExtension`` to all managed attributes as well as the
+ `property` property.
+ """
+
def __init__(self, manager, class_, key, uselist, callable_, typecallable, cascade=None, extension=None, **kwargs):
extension = util.to_list(extension or [])
extension.insert(0, UOWEventHandler(key, class_, cascade=cascade))
super(UOWProperty, self).__init__(manager, key, uselist, callable_, typecallable, extension=extension,**kwargs)
self.class_ = class_
-
+
property = property(lambda s:class_mapper(s.class_).props[s.key], doc="returns the MapperProperty object associated with this property")
-
+
class UOWAttributeManager(attributes.AttributeManager):
- """override AttributeManager to provide the UOWProperty instance for all InstrumentedAttributes."""
+ """Override ``AttributeManager`` to provide the ``UOWProperty``
+ instance for all ``InstrumentedAttributes``.
+ """
+
def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs):
return UOWProperty(self, class_, key, uselist, callable_, typecallable, **kwargs)
class UnitOfWork(object):
- """main UOW object which stores lists of dirty/new/deleted objects.
- provides top-level "flush" functionality as well as the transaction
- boundaries with the SQLEngine(s) involved in a write operation."""
+ """Main UOW object which stores lists of dirty/new/deleted objects.
+
+ Provides top-level *flush* functionality as well as the
+ transaction boundaries with the SQLEngine(s) involved in a write
+ operation.
+ """
+
def __init__(self, identity_map=None, weak_identity_map=False):
if identity_map is not None:
self.identity_map = identity_map
@@ -81,7 +100,7 @@ class UnitOfWork(object):
self.identity_map = weakref.WeakValueDictionary()
else:
self.identity_map = {}
-
+
self.new = util.Set() #OrderedSet()
self.deleted = util.Set()
self.logger = logging.instance_logger(self)
@@ -91,7 +110,7 @@ class UnitOfWork(object):
def _remove_deleted(self, obj):
if hasattr(obj, "_instance_key"):
del self.identity_map[obj._instance_key]
- try:
+ try:
self.deleted.remove(obj)
except KeyError:
pass
@@ -104,20 +123,20 @@ class UnitOfWork(object):
if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \
(not hasattr(obj, '_instance_key') and obj not in self.new):
raise InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj))
-
+
def _is_valid(self, obj):
if (hasattr(obj, '_instance_key') and not self.identity_map.has_key(obj._instance_key)) or \
(not hasattr(obj, '_instance_key') and obj not in self.new):
return False
else:
return True
-
+
def register_attribute(self, class_, key, uselist, **kwargs):
attribute_manager.register_attribute(class_, key, uselist, **kwargs)
def register_callable(self, obj, key, func, uselist, **kwargs):
attribute_manager.set_callable(obj, key, func, uselist, **kwargs)
-
+
def register_clean(self, obj):
if obj in self.new:
self.new.remove(obj)
@@ -128,29 +147,29 @@ class UnitOfWork(object):
delattr(obj, '_sa_insert_order')
self.identity_map[obj._instance_key] = obj
attribute_manager.commit(obj)
-
+
def register_new(self, obj):
if hasattr(obj, '_instance_key'):
raise InvalidRequestError("Object '%s' already has an identity - it cant be registered as new" % repr(obj))
if obj not in self.new:
self.new.add(obj)
obj._sa_insert_order = len(self.new)
-
+
def register_deleted(self, obj):
if obj not in self.deleted:
self._validate_obj(obj)
- self.deleted.add(obj)
-
+ self.deleted.add(obj)
+
def locate_dirty(self):
return util.Set([x for x in self.identity_map.values() if x not in self.deleted and attribute_manager.is_modified(x)])
-
+
def flush(self, session, objects=None):
# this context will track all the objects we want to save/update/delete,
# and organize a hierarchical dependency structure. it also handles
# communication with the mappers and relationships to fire off SQL
# and synchronize attributes between related objects.
echo = logging.is_info_enabled(self.logger)
-
+
flush_context = UOWTransaction(self, session)
# create the set of all objects we want to operate upon
@@ -166,8 +185,8 @@ class UnitOfWork(object):
# store objects whose fate has been decided
processed = util.Set()
-
-
+
+
# put all saves/updates into the flush context. detect orphans and throw them into deleted.
for obj in self.new.union(dirty).intersection(objset).difference(self.deleted):
if obj in processed:
@@ -181,13 +200,13 @@ class UnitOfWork(object):
else:
flush_context.register_object(obj)
processed.add(obj)
-
+
# put all remaining deletes into the flush context.
for obj in self.deleted:
if (objset is not None and not obj in objset) or obj in processed:
continue
flush_context.register_object(obj, isdelete=True)
-
+
trans = session.create_transaction(autoflush=False)
flush_context.transaction = trans
try:
@@ -196,12 +215,14 @@ class UnitOfWork(object):
trans.rollback()
raise
trans.commit()
-
+
flush_context.post_exec()
-
+
class UOWTransaction(object):
- """handles the details of organizing and executing transaction tasks
- during a UnitOfWork object's flush() operation."""
+ """Handles the details of organizing and executing transaction
+ tasks during a UnitOfWork object's flush() operation.
+ """
+
def __init__(self, uow, session):
self.uow = uow
self.session = session
@@ -211,37 +232,41 @@ class UOWTransaction(object):
self.tasks = {}
self.logger = logging.instance_logger(self)
self.echo = uow.echo
-
+
echo = logging.echo_property()
-
+
def register_object(self, obj, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs):
- """adds an object to this UOWTransaction to be updated in the database.
+ """Add an object to this UOWTransaction to be updated in the database.
+
+ `isdelete` indicates whether the object is to be deleted or
+ saved (updated/inserted).
- 'isdelete' indicates whether the object is to be deleted or saved (update/inserted).
+ `listonly` indicates that only this object's dependency
+ relationships should be refreshed/updated to reflect a recent
+ save/upcoming delete operation, but not a full save/delete
+ operation on the object itself, unless an additional
+ save/delete registration is entered for the object.
+ """
- 'listonly', indicates that only this object's dependency relationships should be
- refreshed/updated to reflect a recent save/upcoming delete operation, but not a full
- save/delete operation on the object itself, unless an additional save/delete
- registration is entered for the object."""
#print "REGISTER", repr(obj), repr(getattr(obj, '_instance_key', None)), str(isdelete), str(listonly)
-
+
# if object is not in the overall session, do nothing
if not self.uow._is_valid(obj):
return
-
+
mapper = object_mapper(obj)
self.mappers.add(mapper)
task = self.get_task_by_mapper(mapper)
if postupdate:
task.append_postupdate(obj, post_update_cols)
return
-
+
# for a cyclical task, things need to be sorted out already,
# so this object should have already been added to the appropriate sub-task
# can put an assertion here to make sure....
if task.circular:
return
-
+
task.append(obj, listonly, isdelete=isdelete, **kwargs)
def unregister_object(self, obj):
@@ -250,19 +275,21 @@ class UOWTransaction(object):
task = self.get_task_by_mapper(mapper)
if obj in task.objects:
task.delete(obj)
-
-
+
def is_deleted(self, obj):
mapper = object_mapper(obj)
task = self.get_task_by_mapper(mapper)
return task.is_deleted(obj)
-
+
def get_task_by_mapper(self, mapper, dontcreate=False):
- """every individual mapper involved in the transaction has a single
- corresponding UOWTask object, which stores all the operations involved
- with that mapper as well as operations dependent on those operations.
- this method returns or creates the single per-transaction instance of
- UOWTask that exists for that mapper."""
+ """Every individual mapper involved in the transaction has a
+ single corresponding UOWTask object, which stores all the
+ operations involved with that mapper as well as operations
+ dependent on those operations. This method returns or creates
+ the single per-transaction instance of UOWTask that exists for
+ that mapper.
+ """
+
try:
return self.tasks[mapper]
except KeyError:
@@ -271,45 +298,52 @@ class UOWTransaction(object):
task = UOWTask(self, mapper)
task.mapper.register_dependencies(self)
return task
-
+
def register_dependency(self, mapper, dependency):
- """called by mapper.PropertyLoader to register the objects handled by
- one mapper being dependent on the objects handled by another."""
+ """Called by ``mapper.PropertyLoader`` to register the objects
+ handled by one mapper being dependent on the objects handled
+ by another.
+ """
+
# correct for primary mapper (the mapper officially associated with the class)
# also convert to the "base mapper", the parentmost task at the top of an inheritance chain
# dependency sorting is done via non-inheriting mappers only, dependencies between mappers
# in the same inheritance chain is done at the per-object level
mapper = mapper.primary_mapper().base_mapper()
dependency = dependency.primary_mapper().base_mapper()
-
+
self.dependencies[(mapper, dependency)] = True
def register_processor(self, mapper, processor, mapperfrom):
- """called by mapper.PropertyLoader to register itself as a "processor", which
- will be associated with a particular UOWTask, and be given a list of "dependent"
- objects corresponding to another UOWTask to be processed, either after that secondary
- task saves its objects or before it deletes its objects."""
+ """Called by ``mapper.PropertyLoader`` to register itself as a
+ *processor*, which will be associated with a particular
+ UOWTask, and be given a list of *dependent* objects
+ corresponding to another UOWTask to be processed, either after
+ that secondary task saves its objects or before it deletes its
+ objects.
+ """
+
# when the task from "mapper" executes, take the objects from the task corresponding
# to "mapperfrom"'s list of save/delete objects, and send them to "processor"
# for dependency processing
-
+
#print "registerprocessor", str(mapper), repr(processor), repr(processor.key), str(mapperfrom)
-
+
# correct for primary mapper (the mapper officially associated with the class)
mapper = mapper.primary_mapper()
mapperfrom = mapperfrom.primary_mapper()
-
+
task = self.get_task_by_mapper(mapper)
targettask = self.get_task_by_mapper(mapperfrom)
up = UOWDependencyProcessor(processor, targettask)
task.dependencies.add(up)
def execute(self):
- # ensure that we have a UOWTask for every mapper that will be involved
+ # ensure that we have a UOWTask for every mapper that will be involved
# in the topological sort
[self.get_task_by_mapper(m) for m in self._get_noninheriting_mappers()]
-
- # pre-execute dependency processors. this process may
+
+ # pre-execute dependency processors. this process may
# result in new tasks, objects and/or dependency processors being added,
# particularly with 'delete-orphan' cascade rules.
# keep running through the full list of tasks until all
@@ -322,7 +356,7 @@ class UOWTransaction(object):
ret = True
if not ret:
break
-
+
head = self._sort_dependencies()
if self.echo:
if head is None:
@@ -332,11 +366,13 @@ class UOWTransaction(object):
if head is not None:
head.execute(self)
self.logger.info("Execute Complete")
-
+
def post_exec(self):
- """after an execute/flush is completed, all of the objects and lists that have
- been flushed are updated in the parent UnitOfWork object to mark them as clean."""
-
+ """After an execute/flush is completed, all of the objects and
+ lists that have been flushed are updated in the parent
+ UnitOfWork object to mark them as clean.
+ """
+
for task in self.tasks.values():
for elem in task.objects.values():
if elem.isdelete:
@@ -345,8 +381,14 @@ class UOWTransaction(object):
self.uow.register_clean(elem.obj)
def _sort_dependencies(self):
- """creates a hierarchical tree of dependent tasks. the root node is returned.
- when the root node is executed, it also executes its child tasks recursively."""
+ """Create a hierarchical tree of dependent tasks.
+
+ The root node is returned.
+
+ When the root node is executed, it also executes its child
+ tasks recursively.
+ """
+
def sort_hier(node):
if node is None:
return None
@@ -361,7 +403,7 @@ class UOWTransaction(object):
if t is not None:
task.childtasks.append(t)
return task
-
+
mappers = self._get_noninheriting_mappers()
head = DependencySorter(self.dependencies, list(mappers)).sort(allow_all_cycles=True)
self.logger.debug("Dependency sort:\n"+ str(head))
@@ -369,16 +411,24 @@ class UOWTransaction(object):
return task
def _get_noninheriting_mappers(self):
- """returns a list of UOWTasks whose mappers are not inheriting from the mapper of another UOWTask.
- i.e., this returns the root UOWTasks for all the inheritance hierarchies represented in this UOWTransaction."""
+ """Return a list of UOWTasks whose mappers are not inheriting
+ from the mapper of another UOWTask.
+
+ I.e., this returns the root UOWTasks for all the inheritance
+ hierarchies represented in this UOWTransaction.
+ """
+
mappers = util.Set()
for task in self.tasks.values():
base = task.mapper.base_mapper()
mappers.add(base)
return mappers
-
+
class UOWTask(object):
- """represents the full list of objects that are to be saved/deleted by a specific Mapper."""
+ """Represents the full list of objects that are to be
+ saved/deleted by a specific Mapper.
+ """
+
def __init__(self, uowtransaction, mapper, circular_parent=None):
if not circular_parent:
uowtransaction.tasks[mapper] = self
@@ -390,7 +440,7 @@ class UOWTask(object):
self.mapper = mapper
# a dictionary mapping object instances to a corresponding UOWTaskElement.
- # Each UOWTaskElement represents one instance which is to be saved or
+ # Each UOWTaskElement represents one instance which is to be saved or
# deleted by this UOWTask's Mapper.
# in the case of the row-based "circular sort", the UOWTaskElement may
# also reference further UOWTasks which are dependent on that UOWTaskElement.
@@ -400,7 +450,7 @@ class UOWTask(object):
# before deletes, to synchronize data to dependent objects
self.dependencies = util.Set()
- # a list of UOWTasks that are dependent on this UOWTask, which
+ # a list of UOWTasks that are dependent on this UOWTask, which
# are to be executed after this UOWTask performs saves and post-save
# dependency processing, and before pre-delete processing and deletes
self.childtasks = []
@@ -423,12 +473,19 @@ class UOWTask(object):
return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0
def append(self, obj, listonly = False, childtask = None, isdelete = False):
- """appends an object to this task, to be either saved or deleted depending on the
- 'isdelete' attribute of this UOWTask. 'listonly' indicates that the object should
- only be processed as a dependency and not actually saved/deleted. if the object
- already exists with a 'listonly' flag of False, it is kept as is. 'childtask' is used
- internally when creating a hierarchical list of self-referential tasks, to assign
- dependent operations at the per-object instead of per-task level. """
+ """Append an object to this task, to be either saved or deleted depending on the
+ 'isdelete' attribute of this UOWTask.
+
+ `listonly` indicates that the object should only be processed
+ as a dependency and not actually saved/deleted. if the object
+ already exists with a `listonly` flag of False, it is kept as
+ is.
+
+ `childtask` is used internally when creating a hierarchical
+ list of self-referential tasks, to assign dependent operations
+ at the per-object instead of per-task level.
+ """
+
try:
rec = self.objects[obj]
retval = False
@@ -464,16 +521,24 @@ class UOWTask(object):
task.mapper.delete_obj(task.todelete_objects, trans)
def execute(self, trans):
- """executes this UOWTask. saves objects to be saved, processes all dependencies
- that have been registered, and deletes objects to be deleted. """
+ """Execute this UOWTask.
+
+ Save objects to be saved, process all dependencies that have
+ been registered, and delete objects to be deleted.
+ """
UOWExecutor().execute(trans, self)
def polymorphic_tasks(self):
- """returns an iteration consisting of this UOWTask, and all UOWTasks whose
- mappers are inheriting descendants of this UOWTask's mapper. UOWTasks are returned in order
- of their hierarchy to each other, meaning if UOWTask B's mapper inherits from UOWTask A's
- mapper, then UOWTask B will appear after UOWTask A in the iteration."""
+ """Return an iteration consisting of this UOWTask, and all
+ UOWTasks whose mappers are inheriting descendants of this
+ UOWTask's mapper.
+
+ UOWTasks are returned in order of their hierarchy to each
+ other, meaning if UOWTask B's mapper inherits from UOWTask A's
+ mapper, then UOWTask B will appear after UOWTask A in the
+ iteration.
+ """
# first us
yield self
@@ -525,22 +590,38 @@ class UOWTask(object):
for rec in self.objects.values():
yield rec
- polymorphic_tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if not rec.isdelete])
- polymorphic_todelete_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if rec.isdelete])
- tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=False) if not rec.isdelete])
- todelete_elements = property(lambda self:[rec for rec in self.get_elements(polymorphic=False) if rec.isdelete])
- tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
- todelete_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False) if rec.obj is not None and not rec.listonly and rec.isdelete is True])
- polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=True) if rec.obj is not None and not rec.listonly and rec.isdelete is False])
+ polymorphic_tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True)
+ if not rec.isdelete])
+
+ polymorphic_todelete_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True)
+ if rec.isdelete])
+
+ tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=False)
+ if not rec.isdelete])
+
+ todelete_elements = property(lambda self:[rec for rec in self.get_elements(polymorphic=False)
+ if rec.isdelete])
+
+ tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False)
+ if rec.obj is not None and not rec.listonly and rec.isdelete is False])
+
+ todelete_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=False)
+ if rec.obj is not None and not rec.listonly and rec.isdelete is True])
+
+ polymorphic_tosave_objects = property(lambda self:[rec.obj for rec in self.get_elements(polymorphic=True)
+ if rec.obj is not None and not rec.listonly and rec.isdelete is False])
def _sort_circular_dependencies(self, trans, cycles):
- """for a single task, creates a hierarchical tree of "subtasks" which associate
- specific dependency actions with individual objects. This is used for a
- "cyclical" task, or a task where elements
- of its object list contain dependencies on each other.
+ """For a single task, create a hierarchical tree of *subtasks*
+ which associate specific dependency actions with individual
+ objects. This is used for a *cyclical* task, or a task where
+ elements of its object list contain dependencies on each
+ other.
+
+ This is not the normal case; this logic only kicks in when
+ something like a hierarchical tree is being represented.
+ """
- this is not the normal case; this logic only kicks in when something like
- a hierarchical tree is being represented."""
allobjects = []
for task in cycles:
allobjects += [e.obj for e in task.get_elements(polymorphic=True)]
@@ -560,6 +641,7 @@ class UOWTask(object):
# the final tree, for the purposes of holding new UOWDependencyProcessors
# which process small sub-sections of dependent parent/child operations
dependencies = {}
+
def get_dependency_task(obj, depprocessor):
try:
dp = dependencies[obj]
@@ -652,6 +734,7 @@ class UOWTask(object):
# create a tree of UOWTasks corresponding to the tree of object instances
# created by the DependencySorter
+
def make_task_tree(node, parenttask, nexttasks):
originating_task = object_to_original_task[node.item]
t = nexttasks.get(originating_task, None)
@@ -686,7 +769,6 @@ class UOWTask(object):
uowdumper.UOWDumper(self, buf)
return buf.getvalue()
-
def __repr__(self):
if self.mapper is not None:
if self.mapper.__class__.__name__ == 'Mapper':
@@ -696,69 +778,104 @@ class UOWTask(object):
else:
name = '(none)'
return ("UOWTask(%d) Mapper: '%s'" % (id(self), name))
-
+
class UOWTaskElement(object):
- """an element within a UOWTask. corresponds to a single object instance
- to be saved, deleted, or just part of the transaction as a placeholder for
- further dependencies (i.e. 'listonly').
- in the case of self-referential mappers, may also store a list of childtasks,
- further UOWTasks containing objects dependent on this element's object instance."""
+ """An element within a UOWTask.
+
+ Corresponds to a single object instance to be saved, deleted, or
+ just part of the transaction as a placeholder for further
+ dependencies (i.e. 'listonly').
+
+ In the case of self-referential mappers, may also store a list of
+ childtasks, further UOWTasks containing objects dependent on this
+ element's object instance.
+ """
+
def __init__(self, obj):
self.obj = obj
self.__listonly = True
self.childtasks = []
self.__isdelete = False
self.__preprocessed = {}
+
def _get_listonly(self):
return self.__listonly
+
def _set_listonly(self, value):
- """set_listonly is a one-way setter, will only go from True to False."""
+ """Set_listonly is a one-way setter, will only go from True to False."""
+
if not value and self.__listonly:
self.__listonly = False
self.clear_preprocessed()
+
def _get_isdelete(self):
return self.__isdelete
+
def _set_isdelete(self, value):
if self.__isdelete is not value:
self.__isdelete = value
self.clear_preprocessed()
+
listonly = property(_get_listonly, _set_listonly)
isdelete = property(_get_isdelete, _set_isdelete)
-
+
def mark_preprocessed(self, processor):
- """marks this element as "preprocessed" by a particular UOWDependencyProcessor. preprocessing is the step
- which sweeps through all the relationships on all the objects in the flush transaction and adds other objects
- which are also affected, In some cases it can switch an object from "tosave" to "todelete". changes to the state
- of this UOWTaskElement will reset all "preprocessed" flags, causing it to be preprocessed again. When all UOWTaskElements
- have been fully preprocessed by all UOWDependencyProcessors, then the topological sort can be done."""
+ """Mark this element as *preprocessed* by a particular UOWDependencyProcessor.
+
+ Preprocessing is the step which sweeps through all the
+ relationships on all the objects in the flush transaction and
+ adds other objects which are also affected. In some cases it
+ can switch an object from *tosave* to *todelete*.
+
+ Changes to the state of this UOWTaskElement will reset all
+ *preprocessed* flags, causing it to be preprocessed again.
+ When all UOWTaskElements have been fully preprocessed by all
+ UOWDependencyProcessors, then the topological sort can be
+ done.
+ """
+
self.__preprocessed[processor] = True
+
def is_preprocessed(self, processor):
return self.__preprocessed.get(processor, False)
+
def clear_preprocessed(self):
self.__preprocessed.clear()
+
def __repr__(self):
return "UOWTaskElement/%d: %s/%d %s" % (id(self), self.obj.__class__.__name__, id(self.obj), (self.listonly and 'listonly' or (self.isdelete and 'delete' or 'save')) )
class UOWDependencyProcessor(object):
- """in between the saving and deleting of objects, process "dependent" data, such as filling in
- a foreign key on a child item from a new primary key, or deleting association rows before a
- delete. This object acts as a proxy to a DependencyProcessor."""
+ """In between the saving and deleting of objects, process
+ *dependent* data, such as filling in a foreign key on a child item
+ from a new primary key, or deleting association rows before a
+ delete. This object acts as a proxy to a DependencyProcessor.
+ """
+
def __init__(self, processor, targettask):
self.processor = processor
self.targettask = targettask
+
def __eq__(self, other):
return other.processor is self.processor and other.targettask is self.targettask
+
def __hash__(self):
return hash((self.processor, self.targettask))
-
+
def preexecute(self, trans):
- """traverses all objects handled by this dependency processor and locates additional objects which should be
- part of the transaction, such as those affected deletes, orphans to be deleted, etc. Returns True if any
- objects were preprocessed, or False if no objects were preprocessed."""
+ """Traverse all objects handled by this dependency processor
+ and locate additional objects which should be part of the
+ transaction, such as those affected by deletes, orphans to be
+ deleted, etc.
+
+ Return True if any objects were preprocessed, or False if no
+ objects were preprocessed.
+ """
+
def getobj(elem):
elem.mark_preprocessed(self)
return elem.obj
-
+
ret = False
elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)]
if len(elements):
@@ -770,24 +887,25 @@ class UOWDependencyProcessor(object):
ret = True
self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
return ret
-
+
def execute(self, trans, delete):
if not delete:
self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None], trans, delete=False)
- else:
+ else:
self.processor.process_dependencies(self.targettask, [elem.obj for elem in self.targettask.polymorphic_todelete_elements if elem.obj is not None], trans, delete=True)
def get_object_dependencies(self, obj, trans, passive):
return self.processor.get_object_dependencies(obj, trans, passive=passive)
- def whose_dependent_on_who(self, obj, o):
+ def whose_dependent_on_who(self, obj, o):
return self.processor.whose_dependent_on_who(obj, o)
def branch(self, task):
return UOWDependencyProcessor(self.processor, task)
class UOWExecutor(object):
- """encapsulates the execution traversal of a UOWTransaction structure."""
+ """Encapsulates the execution traversal of a UOWTransaction structure."""
+
def execute(self, trans, task, isdelete=None):
if isdelete is not True:
self.execute_save_steps(trans, task)
@@ -799,10 +917,10 @@ class UOWExecutor(object):
def delete_objects(self, trans, task):
task._delete_objects(trans)
-
+
def execute_dependency(self, trans, dep, isdelete):
dep.execute(trans, isdelete)
-
+
def execute_save_steps(self, trans, task):
if task.circular is not None:
self.execute_save_steps(trans, task.circular)
@@ -813,8 +931,8 @@ class UOWExecutor(object):
self.execute_dependencies(trans, task, False)
self.execute_dependencies(trans, task, True)
self.execute_childtasks(trans, task, False)
-
- def execute_delete_steps(self, trans, task):
+
+ def execute_delete_steps(self, trans, task):
if task.circular is not None:
self.execute_delete_steps(trans, task.circular)
else:
@@ -839,24 +957,23 @@ class UOWExecutor(object):
for polytask in task.polymorphic_tasks():
for child in polytask.childtasks:
self.execute(trans, child, isdelete)
-
+
def execute_cyclical_dependencies(self, trans, task, isdelete):
for polytask in task.polymorphic_tasks():
for dep in polytask.cyclical_dependencies:
self.execute_dependency(trans, dep, isdelete)
-
+
def execute_per_element_childtasks(self, trans, task, isdelete):
for polytask in task.polymorphic_tasks():
for element in polytask.tosave_elements + polytask.todelete_elements:
self.execute_element_childtasks(trans, element, isdelete)
-
+
def execute_element_childtasks(self, trans, element, isdelete):
for child in element.childtasks:
self.execute(trans, child, isdelete)
-
+
class DependencySorter(topological.QueueDependencySorter):
pass
attribute_manager = UOWAttributeManager()
-
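From the session's point of view, the new/dirty/deleted bookkeeping
described in the unitofwork docstrings above looks like this sketch
(0.3-era API; the User mapping is an assumption):

    sess = create_session()

    u = User(name='ed')
    sess.save(u)                    # -> UnitOfWork.register_new
    existing = sess.query(User).get(1)
    existing.name = 'edward'        # instrumentation; found by locate_dirty()
    old = sess.query(User).get(2)
    sess.delete(old)                # -> UnitOfWork.register_deleted
    sess.flush()                    # UOWTransaction sorts dependencies and
                                    # emits the INSERT/UPDATE/DELETE statements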
diff --git a/lib/sqlalchemy/orm/uowdumper.py b/lib/sqlalchemy/orm/uowdumper.py
index e569d0d9ae..166e3b8795 100644
--- a/lib/sqlalchemy/orm/uowdumper.py
+++ b/lib/sqlalchemy/orm/uowdumper.py
@@ -1,6 +1,6 @@
from sqlalchemy.orm import unitofwork
-"""dumps out a string representation of a UOWTask structure"""
+"""Dumps out a string representation of a UOWTask structure"""
class UOWDumper(unitofwork.UOWExecutor):
def __init__(self, task, buf, verbose=False):
@@ -34,15 +34,15 @@ class UOWDumper(unitofwork.UOWExecutor):
super(UOWDumper, self).execute(trans, task, isdelete)
finally:
self.indent -= 1
- if self.starttask.is_empty():
+ if self.starttask.is_empty():
self.buf.write(self._indent() + " |- (empty task)\n")
else:
self.buf.write(self._indent() + " |----\n")
- self.buf.write(self._indent() + "\n")
+ self.buf.write(self._indent() + "\n")
self.starttask = oldstarttask
self.headers = oldheaders
-
+
def save_objects(self, trans, task):
# sort elements to be inserted by insert order
def comparator(a, b):
@@ -59,7 +59,7 @@ class UOWDumper(unitofwork.UOWExecutor):
else:
y = b.obj._sa_insert_order
return cmp(x, y)
-
+
l = list(task.polymorphic_tosave_elements)
l.sort(comparator)
for rec in l:
@@ -68,7 +68,7 @@ class UOWDumper(unitofwork.UOWExecutor):
self.header("Save elements"+ self._inheritance_tag(task))
self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n")
self.closeheader()
-
+
def delete_objects(self, trans, task):
for rec in task.polymorphic_todelete_elements:
if rec.listonly:
@@ -86,7 +86,8 @@ class UOWDumper(unitofwork.UOWExecutor):
return ""
def header(self, text):
- """write a given header just once"""
+ """Write a given header just once."""
+
if not self.verbose:
return
try:
@@ -106,7 +107,7 @@ class UOWDumper(unitofwork.UOWExecutor):
def execute_save_steps(self, trans, task):
super(UOWDumper, self).execute_save_steps(trans, task)
- def execute_delete_steps(self, trans, task):
+ def execute_delete_steps(self, trans, task):
super(UOWDumper, self).execute_delete_steps(trans, task)
def execute_dependencies(self, trans, task, isdelete=None):
@@ -116,12 +117,12 @@ class UOWDumper(unitofwork.UOWExecutor):
self.header("Child tasks" + self._inheritance_tag(task))
super(UOWDumper, self).execute_childtasks(trans, task, isdelete)
self.closeheader()
-
+
def execute_cyclical_dependencies(self, trans, task, isdelete):
self.header("Cyclical %s dependencies" % (isdelete and "delete" or "save"))
super(UOWDumper, self).execute_cyclical_dependencies(trans, task, isdelete)
self.closeheader()
-
+
def execute_per_element_childtasks(self, trans, task, isdelete):
super(UOWDumper, self).execute_per_element_childtasks(trans, task, isdelete)
@@ -129,7 +130,7 @@ class UOWDumper(unitofwork.UOWExecutor):
self.header("%s subelements of UOWTaskElement(%s)" % ((isdelete and "Delete" or "Save"), hex(id(element))))
super(UOWDumper, self).execute_element_childtasks(trans, element, isdelete)
self.closeheader()
-
+
def _dump_processor(self, proc, deletes):
if deletes:
val = proc.targettask.polymorphic_todelete_elements
@@ -138,18 +139,18 @@ class UOWDumper(unitofwork.UOWExecutor):
if self.verbose:
self.buf.write(self._indent() + " |- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % (
- repr(proc.processor.key),
+ repr(proc.processor.key),
("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")),
- hex(id(proc)),
+ hex(id(proc)),
self._repr_task(proc.targettask))
)
elif False:
self.buf.write(self._indent() + " |- %s attribute on %s\n" % (
- repr(proc.processor.key),
+ repr(proc.processor.key),
("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")),
)
)
-
+
if len(val) == 0:
if self.verbose:
self.buf.write(self._indent() + " |- " + "(no objects)\n")
@@ -183,7 +184,7 @@ class UOWDumper(unitofwork.UOWExecutor):
return ("UOWTask(%s->%s, %s)" % (hex(id(task.circular_parent)), hex(id(task)), name))
else:
return ("UOWTask(%s, %s)" % (hex(id(task)), name))
-
+
def _repr_task_class(self, task):
if task.mapper is not None and task.mapper.__class__.__name__ == 'Mapper':
return task.mapper.class_.__name__
@@ -195,4 +196,3 @@ class UOWDumper(unitofwork.UOWExecutor):
def _indent(self):
return " |" * self.indent
-
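UOWDumper is driven exactly as UOWTask.__str__ does above: construct it
with a task and a buffer, then read the buffer (the StringIO usage is
an assumption):

    import StringIO

    buf = StringIO.StringIO()
    UOWDumper(task, buf, verbose=True)   # the constructor walks the task
                                         # tree, writing into buf
    print buf.getvalue()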
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 9de58ce310..d1ae9f7968 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -7,8 +7,10 @@
from sqlalchemy import sql, util, exceptions
all_cascades = util.Set(["delete", "delete-orphan", "all", "merge", "expunge", "save-update", "refresh-expire", "none"])
+
class CascadeOptions(object):
- """keeps track of the options sent to relation().cascade"""
+ """Keeps track of the options sent to relation().cascade"""
+
def __init__(self, arg=""):
values = util.Set([c.strip() for c in arg.split(',')])
self.delete_orphan = "delete-orphan" in values
@@ -48,7 +50,7 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
m[c.name] = c
types[c.name] = c.type
colnamemaps[table] = m
-
+
def col(name, table):
try:
return colnamemaps[table][name]
@@ -64,13 +66,18 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
return sql.union_all(*result).alias(aliasname)
class TranslatingDict(dict):
- """a dictionary that stores ColumnElement objects as keys. incoming ColumnElement
- keys are translated against those of an underling FromClause for all operations.
- This way the columns from any Selectable that is derived from or underlying this
- TranslatingDict's selectable can be used as keys."""
+ """A dictionary that stores ColumnElement objects as keys.
+
+ Incoming ColumnElement keys are translated against those of an
+ underlying FromClause for all operations. This way the columns
+ from any Selectable that is derived from or underlying this
+ TranslatingDict's selectable can be used as keys.
+ """
+
def __init__(self, selectable):
super(TranslatingDict, self).__init__()
self.selectable = selectable
+
def __translate_col(self, col):
ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
#if col is not ourcol:
@@ -79,27 +86,33 @@ class TranslatingDict(dict):
return col
else:
return ourcol
+
def __getitem__(self, col):
return super(TranslatingDict, self).__getitem__(self.__translate_col(col))
+
def has_key(self, col):
return super(TranslatingDict, self).has_key(self.__translate_col(col))
+
def __setitem__(self, col, value):
return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value)
+
def __contains__(self, col):
return self.has_key(col)
+
def setdefault(self, col, value):
return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)
class BinaryVisitor(sql.ClauseVisitor):
def __init__(self, func):
self.func = func
+
def visit_binary(self, binary):
self.func(binary)
def instance_str(instance):
- """return a string describing an instance"""
+ """Return a string describing an instance."""
+
return instance.__class__.__name__ + "@" + hex(id(instance))
-
+
def attribute_str(instance, attribute):
return instance_str(instance) + "." + attribute
-
\ No newline at end of file
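The TranslatingDict translation described above can be sketched as
follows (the users Table and its alias are illustrative):

    ualias = users.alias('u')
    d = TranslatingDict(ualias)       # keys translate to ualias's columns
    d[users.c.user_id] = 'marker'     # underlying table column as the key
    assert d[ualias.c.user_id] == 'marker'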
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 640277270f..787fd059f2 100644
--- a/lib/sqlalchemy/pool.py
+++ b/lib/sqlalchemy/pool.py
@@ -5,17 +5,20 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""provides a connection pool implementation, which optionally manages connections
-on a thread local basis. Also provides a DBAPI2 transparency layer so that pools can
-be managed automatically, based on module type and connect arguments,
- simply by calling regular DBAPI connect() methods."""
+"""Provide a connection pool implementation, which optionally manages
+connections on a thread local basis.
+
+Also provides a DBAPI2 transparency layer so that pools can be managed
+automatically, based on module type and connect arguments, simply by
+calling regular DBAPI connect() methods.
+"""
import weakref, string, time, sys, traceback
try:
import cPickle as pickle
except:
import pickle
-
+
from sqlalchemy import exceptions, logging
from sqlalchemy import queue as Queue
@@ -27,72 +30,101 @@ except:
proxies = {}
def manage(module, **params):
- """given a DBAPI2 module and pool management parameters, returns a proxy for the module
- that will automatically pool connections, creating new connection pools for each
- distinct set of connection arguments sent to the decorated module's connect() function.
+ """Return a proxy for module that automatically pools connections.
+
+ Given a DBAPI2 module and pool management parameters, returns a
+ proxy for the module that will automatically pool connections,
+ creating new connection pools for each distinct set of connection
+ arguments sent to the decorated module's connect() function.
Arguments:
- module : a DBAPI2 database module.
- poolclass=QueuePool : the class used by the pool module to provide pooling.
-
- Options:
- See Pool for options.
+ module
+ A DBAPI2 database module.
+
+ poolclass
+ The class used by the pool module to provide pooling.
+ Defaults to ``QueuePool``.
+
+ See the ``Pool`` class for options.
"""
try:
return proxies[module]
except KeyError:
- return proxies.setdefault(module, _DBProxy(module, **params))
+ return proxies.setdefault(module, _DBProxy(module, **params))
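Usage-wise, the proxy returned by manage() is called like the DBAPI
module itself; a sketch (the psycopg2 connect arguments are
illustrative):

    import psycopg2
    from sqlalchemy import pool

    psycopg2 = pool.manage(psycopg2)   # proxy with automatic pooling
    conn = psycopg2.connect(database='test', user='scott', password='tiger')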
def clear_managers():
- """removes all current DBAPI2 managers. all pools and connections are disposed."""
+ """Remove all current DBAPI2 managers.
+
+ All pools and connections are disposed.
+ """
+
for manager in proxies.values():
manager.close()
proxies.clear()
-
+
class Pool(object):
- """Base Pool class. This is an abstract class, which is implemented by various subclasses
- including:
-
- QueuePool - pools multiple connections using Queue.Queue
-
- SingletonThreadPool - stores a single connection per execution thread
-
- NullPool - doesnt do any pooling; opens and closes connections
-
- AssertionPool - stores only one connection, and asserts that only one connection is checked out at a time.
-
- the main argument, "creator", is a callable function that returns a newly connected DBAPI connection
- object.
-
+ """Base Pool class.
+
+ This is an abstract class, which is implemented by various
+ subclasses including:
+
+ QueuePool
+ Pools multiple connections using ``Queue.Queue``.
+
+ SingletonThreadPool
+ Stores a single connection per execution thread.
+
+ NullPool
+ Doesn't do any pooling; opens and closes connections.
+
+ AssertionPool
+ Stores only one connection, and asserts that only one connection
+ is checked out at a time.
+
+ The main argument, `creator`, is a callable function that returns
+ a newly connected DBAPI connection object.
+
Options that are understood by Pool are:
-
- echo=False : if set to True, connections being pulled and retrieved from/to the pool will
- be logged to the standard output, as well as pool sizing information. Echoing can also
- be achieved by enabling logging for the "sqlalchemy.pool" namespace.
-
- use_threadlocal=True : if set to True, repeated calls to connect() within the same
- application thread will be guaranteed to return the same connection object, if one has
- already been retrieved from the pool and has not been returned yet. This allows code to
- retrieve a connection from the pool, and then while still holding on to that connection,
- to call other functions which also ask the pool for a connection of the same arguments;
- those functions will act upon the same connection that the calling method is using.
-
- recycle=-1 : if set to non -1, a number of seconds between connection recycling, which
- means upon checkout, if this timeout is surpassed the connection will be closed and replaced
- with a newly opened connection.
-
- auto_close_cursors = True : cursors, returned by connection.cursor(), are tracked and are
- automatically closed when the connection is returned to the pool. some DBAPIs like MySQLDB
- become unstable if cursors remain open.
-
- disallow_open_cursors = False : if auto_close_cursors is False, and disallow_open_cursors is True,
- will raise an exception if an open cursor is detected upon connection checkin.
-
- If auto_close_cursors and disallow_open_cursors are both False, then no cursor processing
- occurs upon checkin.
-
+
+ echo
+ If set to True, connections being pulled and retrieved from/to
+ the pool will be logged to the standard output, as well as pool
+ sizing information. Echoing can also be achieved by enabling
+ logging for the "sqlalchemy.pool" namespace. Defaults to False.
+
+ use_threadlocal
+ If set to True, repeated calls to ``connect()`` within the same
+ application thread will be guaranteed to return the same
+ connection object, if one has already been retrieved from the
+ pool and has not been returned yet. This allows code to retrieve
+ a connection from the pool, and then while still holding on to
+ that connection, to call other functions which also ask the pool
+ for a connection of the same arguments; those functions will act
+ upon the same connection that the calling method is using.
+ Defaults to True.
+
+  recycle
+    If set to a value other than -1, the number of seconds between
+    connection recycling: upon checkout, if this timeout has been
+    surpassed the connection will be closed and replaced with a
+    newly opened connection. Defaults to -1.
+
+ auto_close_cursors
+ Cursors, returned by ``connection.cursor()``, are tracked and
+ are automatically closed when the connection is returned to the
+    pool. Some DBAPIs like MySQLdb become unstable if cursors
+ remain open. Defaults to True.
+
+ disallow_open_cursors
+ If `auto_close_cursors` is False, and `disallow_open_cursors` is
+ True, will raise an exception if an open cursor is detected upon
+ connection checkin. Defaults to False.
+
+ If `auto_close_cursors` and `disallow_open_cursors` are both
+ False, then no cursor processing occurs upon checkin.
"""
+
def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=False, auto_close_cursors=True, disallow_open_cursors=False):
self.logger = logging.instance_logger(self)
self._threadconns = weakref.WeakValueDictionary()
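
As a quick illustration of the manage() behavior documented above (a
sketch only; sqlite3 stands in for any DBAPI2 module, and the option
values are arbitrary):

    import sqlite3
    from sqlalchemy import pool

    # Wrap the DBAPI module; connect() calls are routed through a pool
    # keyed to the connect arguments.
    sqlite = pool.manage(sqlite3, poolclass=pool.QueuePool,
                         use_threadlocal=True)

    conn = sqlite.connect('some.db')   # checked out from the pool
    conn.close()                       # returned to the pool, not closed
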
@@ -103,17 +135,17 @@ class Pool(object):
self.disallow_open_cursors = disallow_open_cursors
self.echo = echo
echo = logging.echo_property()
-
+
def unique_connection(self):
return _ConnectionFairy(self).checkout()
-
+
def create_connection(self):
return _ConnectionRecord(self)
-
+
def connect(self):
if not self._use_threadlocal:
return _ConnectionFairy(self).checkout()
-
+
try:
return self._threadconns[thread.get_ident()].connfairy().checkout()
except KeyError:
@@ -126,13 +158,13 @@ class Pool(object):
def get(self):
return self.do_get()
-
+
def do_get(self):
raise NotImplementedError()
-
+
def do_return_conn(self, conn):
raise NotImplementedError()
-
+
def status(self):
raise NotImplementedError()
@@ -141,19 +173,21 @@ class Pool(object):
def dispose(self):
raise NotImplementedError()
-
class _ConnectionRecord(object):
def __init__(self, pool):
self.__pool = pool
self.connection = self.__connect()
+
def close(self):
self.__pool.log("Closing connection %s" % repr(self.connection))
self.connection.close()
+
def invalidate(self):
self.__pool.log("Invalidate connection %s" % repr(self.connection))
self.__close()
self.connection = None
+
def get_connection(self):
if self.connection is None:
self.connection = self.__connect()
@@ -162,12 +196,14 @@ class _ConnectionRecord(object):
self.__close()
self.connection = self.__connect()
return self.connection
+
def __close(self):
try:
self.__pool.log("Closing connection %s" % (repr(self.connection)))
self.connection.close()
except Exception, e:
self.__pool.log("Connection %s threw an error on close: %s" % (repr(self.connection), str(e)))
+
def __connect(self):
try:
self.starttime = time.time()
@@ -179,12 +215,14 @@ class _ConnectionRecord(object):
raise
class _ThreadFairy(object):
- """marks a thread identifier as owning a connection, for a thread local pool."""
+ """Mark a thread identifier as owning a connection, for a thread local pool."""
+
def __init__(self, connfairy):
self.connfairy = weakref.ref(connfairy)
-
+
class _ConnectionFairy(object):
- """proxies a DBAPI connection object and provides return-on-dereference support"""
+    """Proxy a DBAPI connection object and provide return-on-dereference support."""
+
def __init__(self, pool):
self._threadfairy = _ThreadFairy(self)
self._cursors = weakref.WeakKeyDictionary()
@@ -199,6 +237,7 @@ class _ConnectionFairy(object):
raise
if self.__pool.echo:
self.__pool.log("Connection %s checked out from pool" % repr(self.connection))
+
def invalidate(self):
if self.connection is None:
raise exceptions.InvalidRequestError("This connection is closed")
@@ -206,29 +245,36 @@ class _ConnectionFairy(object):
self.connection = None
self._cursors = None
self._close()
+
def cursor(self, *args, **kwargs):
try:
return _CursorFairy(self, self.connection.cursor(*args, **kwargs))
except Exception, e:
self.invalidate()
raise
+
def __getattr__(self, key):
return getattr(self.connection, key)
+
def checkout(self):
if self.connection is None:
raise exceptions.InvalidRequestError("This connection is closed")
self.__counter +=1
- return self
+ return self
+
def close_open_cursors(self):
if self._cursors is not None:
for c in list(self._cursors):
c.close()
+
def close(self):
self.__counter -=1
if self.__counter == 0:
self._close()
+
def __del__(self):
self._close()
+
def _close(self):
if self._cursors is not None:
# cursors should be closed before connection is returned to the pool. some dbapis like
@@ -252,31 +298,38 @@ class _ConnectionFairy(object):
self._connection_record = None
self._threadfairy = None
self._cursors = None
-
+
class _CursorFairy(object):
def __init__(self, parent, cursor):
self.__parent = parent
self.__parent._cursors[self]=True
self.cursor = cursor
+
def invalidate(self):
self.__parent.invalidate()
+
def close(self):
if self in self.__parent._cursors:
del self.__parent._cursors[self]
self.cursor.close()
+
def __getattr__(self, key):
return getattr(self.cursor, key)
-
+
class SingletonThreadPool(Pool):
- """Maintains one connection per each thread, never moving a connection to a thread
- other than the one which it was created in.
-
- this is used for SQLite, which both does not handle multithreading by default,
- and also requires a singleton connection if a :memory: database is being used.
-
- options are the same as those of Pool, as well as:
-
- pool_size=5 - the number of threads in which to maintain connections at once."""
+    """Maintain one connection per thread, never moving a connection
+    to a thread other than the one in which it was created.
+
+ This is used for SQLite, which both does not handle multithreading
+ by default, and also requires a singleton connection if a :memory:
+ database is being used.
+
+ Options are the same as those of Pool, as well as:
+
+    pool_size
+      The number of threads in which to maintain connections at once.
+      Defaults to 5.
+ """
+
def __init__(self, creator, pool_size=5, **params):
params['use_threadlocal'] = True
Pool.__init__(self, creator, **params)
@@ -291,13 +344,13 @@ class SingletonThreadPool(Pool):
# sqlite won't even let you close a conn from a thread that didn't create it
pass
del self._conns[key]
-
+
def dispose_local(self):
try:
del self._conns[thread.get_ident()]
except KeyError:
pass
-
+
def cleanup(self):
for key in self._conns.keys():
try:
@@ -306,13 +359,13 @@ class SingletonThreadPool(Pool):
pass
if len(self._conns) <= self.size:
return
-
+
def status(self):
return "SingletonThreadPool id:%d thread:%d size: %d" % (id(self), thread.get_ident(), len(self._conns))
def do_return_conn(self, conn):
pass
-
+
def do_get(self):
try:
return self._conns[thread.get_ident()]
@@ -322,35 +375,45 @@ class SingletonThreadPool(Pool):
if len(self._conns) > self.size:
self.cleanup()
return c
-
+
class QueuePool(Pool):
- """uses Queue.Queue to maintain a fixed-size list of connections.
-
- Arguments include all those used by the base Pool class, as well as:
-
- pool_size=5 : the size of the pool to be maintained. This is the
- largest number of connections that will be kept persistently in the pool. Note that the
- pool begins with no connections; once this number of connections is requested, that
- number of connections will remain.
-
- max_overflow=10 : the maximum overflow size of the pool. When the number of checked-out
- connections reaches the size set in pool_size, additional connections will be returned up
- to this limit. When those additional connections are returned to the pool, they are
- disconnected and discarded. It follows then that the total number of simultaneous
- connections the pool will allow is pool_size + max_overflow, and the total number of
- "sleeping" connections the pool will allow is pool_size. max_overflow can be set to -1 to
- indicate no overflow limit; no limit will be placed on the total number of concurrent
- connections.
-
- timeout=30 : the number of seconds to wait before giving up on returning a connection
+ """Use ``Queue.Queue`` to maintain a fixed-size list of connections.
+
+ Arguments include all those used by the base Pool class, as well
+ as:
+
+ pool_size
+ The size of the pool to be maintained. This is the largest
+ number of connections that will be kept persistently in the
+ pool. Note that the pool begins with no connections; once this
+ number of connections is requested, that number of connections
+ will remain. Defaults to 5.
+
+ max_overflow
+ The maximum overflow size of the pool. When the number of
+    checked-out connections reaches the size set in `pool_size`,
+    additional connections will be returned up to this limit. When
+    those additional connections are returned to the pool, they are
+    disconnected and discarded. It follows then that the total
+    number of simultaneous connections the pool will allow is
+    `pool_size` + `max_overflow`, and the total number of "sleeping"
+    connections the pool will allow is `pool_size`. `max_overflow` can
+ be set to -1 to indicate no overflow limit; no limit will be
+ placed on the total number of concurrent connections. Defaults
+ to 10.
+
+ timeout
+ The number of seconds to wait before giving up on returning a
+ connection. Defaults to 30.
"""
+
def __init__(self, creator, pool_size = 5, max_overflow = 10, timeout=30, **params):
Pool.__init__(self, creator, **params)
self._pool = Queue.Queue(pool_size)
self._overflow = 0 - pool_size
self._max_overflow = max_overflow
self._timeout = timeout
-
+
def do_return_conn(self, conn):
try:
self._pool.put(conn, False)
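
To make the sizing rules above concrete, a sketch (sqlite3 stands in
for any DBAPI connect function): with pool_size=5 and max_overflow=10,
at most 5 + 10 = 15 connections may be checked out at once, and at
most 5 idle connections are retained.

    import sqlite3
    from sqlalchemy import pool

    p = pool.QueuePool(lambda: sqlite3.connect('some.db'),
                       pool_size=5, max_overflow=10, timeout=30)

    c = p.connect()   # one of up to 15 concurrent connections
    c.close()         # checked back in; overflow connections are discarded
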
@@ -374,29 +437,33 @@ class QueuePool(Pool):
conn.close()
except Queue.Empty:
break
-
+
self._overflow = 0 - self.size()
self.log("Pool disposed. " + self.status())
def status(self):
tup = (self.size(), self.checkedin(), self.overflow(), self.checkedout())
return "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup
-
+
def size(self):
return self._pool.maxsize
-
+
def checkedin(self):
return self._pool.qsize()
-
+
def overflow(self):
return self._overflow
-
+
def checkedout(self):
return self._pool.maxsize - self._pool.qsize() + self._overflow
class NullPool(Pool):
- """a Pool implementation which does not pool connections; instead
- it literally opens and closes the underlying DBAPI connection per each connection open/close."""
+ """A Pool implementation which does not pool connections.
+
+ Instead it literally opens and closes the underlying DBAPI
+ connection per each connection open/close.
+ """
+
def status(self):
return "NullPool"
@@ -407,14 +474,19 @@ class NullPool(Pool):
pass
def do_get(self):
- return self.create_connection()
+ return self.create_connection()
class AssertionPool(Pool):
- """a Pool implementation which will raise an exception
- if more than one connection is checked out at a time. Useful for debugging
- code that is using more connections than desired.
-
- TODO: modify this to handle an arbitrary connection count."""
+ """A Pool implementation that allows at most one checked out
+ connection at a time.
+
+ This will raise an exception if more than one connection is
+ checked out at a time. Useful for debugging code that is using
+ more connections than desired.
+ """
+
+ ## TODO: modify this to handle an arbitrary connection count.
+
def __init__(self, creator, **params):
Pool.__init__(self, creator, **params)
self.connection = _ConnectionRecord(self)
@@ -438,17 +510,25 @@ class AssertionPool(Pool):
c = self.connection
self.connection = None
return c
-
+
class _DBProxy(object):
- """proxies a DBAPI2 connect() call to a pooled connection keyed to the specific connect
- parameters. other attributes are proxied through via __getattr__."""
-
+    """Proxy a DBAPI2 connect() call to a pooled connection keyed to
+    the specific connect parameters.
+
+    Other attributes are proxied through via ``__getattr__``.
+    """
+
def __init__(self, module, poolclass = QueuePool, **params):
+ """Initialize a new proxy.
+
+      module
+        A DBAPI2 module.
+
+      poolclass
+        A Pool class, defaulting to ``QueuePool``.
+
+ Other parameters are sent to the Pool object's constructor.
"""
- module is a DBAPI2 module
- poolclass is a Pool class, defaulting to QueuePool.
- other parameters are sent to the Pool object's constructor.
- """
+
self.module = module
self.params = params
self.poolclass = poolclass
@@ -460,10 +540,10 @@ class _DBProxy(object):
def __del__(self):
self.close()
-
+
def __getattr__(self, key):
return getattr(self.module, key)
-
+
def get_pool(self, *args, **params):
key = self._serialize(*args, **params)
try:
@@ -472,24 +552,32 @@ class _DBProxy(object):
pool = self.poolclass(lambda: self.module.connect(*args, **params), **self.params)
self.pools[key] = pool
return pool
-
+
def connect(self, *args, **params):
- """connects to a database using this DBProxy's module and the given connect
- arguments. if the arguments match an existing pool, the connection will be returned
- from the pool's current thread-local connection instance, or if there is no
- thread-local connection instance it will be checked out from the set of pooled
- connections. If the pool has no available connections and allows new connections to
- be created, a new database connection will be made."""
+ """Activate a connection to the database.
+
+ Connect to the database using this DBProxy's module and the
+ given connect arguments. If the arguments match an existing
+ pool, the connection will be returned from the pool's current
+ thread-local connection instance, or if there is no
+ thread-local connection instance it will be checked out from
+ the set of pooled connections.
+
+ If the pool has no available connections and allows new
+ connections to be created, a new database connection will be
+ made.
+ """
+
return self.get_pool(*args, **params).connect()
-
+
def dispose(self, *args, **params):
- """disposes the connection pool referenced by the given connect arguments."""
+ """Dispose the connection pool referenced by the given connect arguments."""
+
key = self._serialize(*args, **params)
try:
del self.pools[key]
except KeyError:
pass
-
+
def _serialize(self, *args, **params):
return pickle.dumps([args, params])
-
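
A sketch of the connect()/dispose() behavior documented above;
psycopg2 and the connect arguments are placeholders for any DBAPI2
module:

    import psycopg2
    from sqlalchemy import pool

    proxy = pool.manage(psycopg2)

    # Identical connect arguments map to the same underlying pool.
    c1 = proxy.connect(database='test', user='scott', password='tiger')
    c2 = proxy.connect(database='test', user='scott', password='tiger')
    c1.close()
    c2.close()

    # Discard the pool keyed to these arguments; pools for other
    # argument sets are unaffected.
    proxy.dispose(database='test', user='scott', password='tiger')
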
diff --git a/lib/sqlalchemy/queue.py b/lib/sqlalchemy/queue.py
index 49bb4badf3..7ef1ba61bd 100644
--- a/lib/sqlalchemy/queue.py
+++ b/lib/sqlalchemy/queue.py
@@ -1,9 +1,11 @@
-"""an adaptation of Py2.3/2.4's Queue module which supports reentrant behavior,
-using RLock instead of Lock for its mutex object.
-this is to support the connection pool's usage of __del__ to return connections
-to the underlying Queue, which can apparently in extremely rare cases be invoked
-within the get() method of the Queue itself, producing a put() inside the get()
-and therefore a reentrant condition."""
+"""An adaptation of Py2.3/2.4's Queue module which supports reentrant
+behavior, using RLock instead of Lock for its mutex object.
+
+This is to support the connection pool's usage of ``__del__`` to return
+connections to the underlying Queue, which can apparently in extremely
+rare cases be invoked within the ``get()`` method of the Queue itself,
+producing a ``put()`` inside the ``get()`` and therefore a reentrant
+condition.
+"""
from time import time as _time
@@ -20,18 +22,21 @@ __all__ = ['Empty', 'Full', 'Queue']
class Empty(Exception):
"Exception raised by Queue.get(block=0)/get_nowait()."
+
pass
class Full(Exception):
"Exception raised by Queue.put(block=0)/put_nowait()."
+
pass
class Queue:
def __init__(self, maxsize=0):
"""Initialize a queue object with a given maximum size.
- If maxsize is <= 0, the queue size is infinite.
+ If `maxsize` is <= 0, the queue size is infinite.
"""
+
try:
import threading
except ImportError:
@@ -51,6 +56,7 @@ class Queue:
def qsize(self):
"""Return the approximate size of the queue (not reliable!)."""
+
self.mutex.acquire()
n = self._qsize()
self.mutex.release()
@@ -58,6 +64,7 @@ class Queue:
def empty(self):
"""Return True if the queue is empty, False otherwise (not reliable!)."""
+
self.mutex.acquire()
n = self._empty()
self.mutex.release()
@@ -65,6 +72,7 @@ class Queue:
def full(self):
"""Return True if the queue is full, False otherwise (not reliable!)."""
+
self.mutex.acquire()
n = self._full()
self.mutex.release()
@@ -73,14 +81,16 @@ class Queue:
def put(self, item, block=True, timeout=None):
"""Put an item into the queue.
- If optional args 'block' is true and 'timeout' is None (the default),
- block if necessary until a free slot is available. If 'timeout' is
- a positive number, it blocks at most 'timeout' seconds and raises
- the Full exception if no free slot was available within that time.
- Otherwise ('block' is false), put an item on the queue if a free slot
- is immediately available, else raise the Full exception ('timeout'
- is ignored in that case).
+ If optional args `block` is True and `timeout` is None (the
+ default), block if necessary until a free slot is
+ available. If `timeout` is a positive number, it blocks at
+ most `timeout` seconds and raises the ``Full`` exception if no
+        free slot was available within that time. Otherwise (`block`
+        is False), put an item on the queue if a free slot is
+ immediately available, else raise the ``Full`` exception
+ (`timeout` is ignored in that case).
"""
+
self.not_full.acquire()
try:
if not block:
@@ -107,21 +117,22 @@ class Queue:
"""Put an item into the queue without blocking.
Only enqueue the item if a free slot is immediately available.
- Otherwise raise the Full exception.
+ Otherwise raise the ``Full`` exception.
"""
return self.put(item, False)
def get(self, block=True, timeout=None):
"""Remove and return an item from the queue.
- If optional args 'block' is true and 'timeout' is None (the default),
- block if necessary until an item is available. If 'timeout' is
- a positive number, it blocks at most 'timeout' seconds and raises
- the Empty exception if no item was available within that time.
- Otherwise ('block' is false), return an item if one is immediately
- available, else raise the Empty exception ('timeout' is ignored
- in that case).
+ If optional args `block` is True and `timeout` is None (the
+ default), block if necessary until an item is available. If
+ `timeout` is a positive number, it blocks at most `timeout`
+ seconds and raises the ``Empty`` exception if no item was
+        available within that time. Otherwise (`block` is False),
+ return an item if one is immediately available, else raise the
+ ``Empty`` exception (`timeout` is ignored in that case).
"""
+
self.not_empty.acquire()
try:
if not block:
@@ -149,8 +160,9 @@ class Queue:
"""Remove and return an item from the queue without blocking.
Only get an item if one is immediately available. Otherwise
- raise the Empty exception.
+ raise the ``Empty`` exception.
"""
+
return self.get(False)
# Override these methods to implement other queue organizations
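
The block/timeout semantics spelled out in these docstrings, sketched
against this module's Queue (its interface mirrors the standard
library's):

    from sqlalchemy import queue as Queue

    q = Queue.Queue(maxsize=2)
    q.put('a')
    q.put('b')

    try:
        q.put('c', block=True, timeout=0.5)   # queue is full; raises Full
    except Queue.Full:
        pass

    assert q.get() == 'a'
    assert q.get_nowait() == 'b'

    try:
        q.get(block=False)                    # queue is empty; raises Empty
    except Queue.Empty:
        pass
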
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index debde5754e..ae4a0b1626 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -4,16 +4,19 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""the schema module provides the building blocks for database metadata. This means
-all the entities within a SQL database that we might want to look at, modify, or create
-and delete are described by these objects, in a database-agnostic way.
+"""The schema module provides the building blocks for database metadata.
-A structure of SchemaItems also provides a "visitor" interface which is the primary
-method by which other methods operate upon the schema. The SQL package extends this
-structure with its own clause-specific objects as well as the visitor interface, so that
-the schema package "plugs in" to the SQL package.
+This means all the entities within a SQL database that we might want
+to look at, modify, or create and delete are described by these
+objects, in a database-agnostic way.
+A structure of SchemaItems also provides a *visitor* interface which is
+the primary method by which other methods operate upon the schema.
+The SQL package extends this structure with its own clause-specific
+objects as well as the visitor interface, so that the schema package
+*plugs in* to the SQL package.
"""
+
from sqlalchemy import sql, types, exceptions,util, databases
import sqlalchemy
import copy, re, string
@@ -23,50 +26,66 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', '
'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
class SchemaItem(object):
- """base class for items that define a database schema."""
+ """Base class for items that define a database schema."""
+
def _init_items(self, *args):
- """initialize the list of child items for this SchemaItem"""
+ """Initialize the list of child items for this SchemaItem."""
+
for item in args:
if item is not None:
item._set_parent(self)
+
def _get_parent(self):
raise NotImplementedError()
+
def _set_parent(self, parent):
- """associate with this SchemaItem's parent object."""
+ """Associate with this SchemaItem's parent object."""
+
raise NotImplementedError()
+
def __repr__(self):
return "%s()" % self.__class__.__name__
+
def _derived_metadata(self):
- """return the the MetaData to which this item is bound"""
+        """Return the MetaData to which this item is bound."""
+
return None
+
def _get_engine(self):
- """return the engine or None if no engine"""
+ """Return the engine or None if no engine."""
+
return self._derived_metadata().engine
+
def get_engine(self):
- """return the engine or raise an error if no engine"""
+ """Return the engine or raise an error if no engine."""
+
e = self._get_engine()
if e is not None:
return e
else:
raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine")
-
+
def _set_casing_strategy(self, name, kwargs, keyname='case_sensitive'):
- """set the "case_sensitive" argument sent via keywords to the item's constructor.
-
- for the purposes of Table's 'schema' property, the name of the variable is
- optionally configurable."""
+        """Set the `case_sensitive` argument sent via keywords to the item's constructor.
+
+        For the purposes of Table's `schema` property, the name of the
+        variable is optionally configurable.
+ """
setattr(self, '_%s_setting' % keyname, kwargs.pop(keyname, None))
+
def _determine_case_sensitive(self, name, keyname='case_sensitive'):
- """determine the "case_sensitive" value for this item.
-
- for the purposes of Table's 'schema' property, the name of the variable is
- optionally configurable.
-
- a local non-None value overrides all others. after that, the parent item
- (i.e. Column for a Sequence, Table for a Column, MetaData for a Table) is
- searched for a non-None setting, traversing each parent until none are found.
+ """Determine the `case_sensitive` value for this item.
+
+ For the purposes of Table's `schema` property, the name of the
+ variable is optionally configurable.
+
+ A local non-None value overrides all others. After that, the
+ parent item (i.e. ``Column`` for a ``Sequence``, ``Table`` for
+ a ``Column``, ``MetaData`` for a ``Table``) is searched for a
+ non-None setting, traversing each parent until none are found.
Finally, case_sensitive is set to True as a default.
"""
+
local = getattr(self, '_%s_setting' % keyname, None)
if local is not None:
return local
@@ -77,7 +96,8 @@ class SchemaItem(object):
parentval = getattr(parent, '_case_sensitive_setting', None)
if parentval is not None:
return parentval
- return True
+ return True
+
def _get_case_sensitive(self):
try:
return self.__case_sensitive
@@ -85,18 +105,19 @@ class SchemaItem(object):
self.__case_sensitive = self._determine_case_sensitive(self.name)
return self.__case_sensitive
case_sensitive = property(_get_case_sensitive)
-
+
engine = property(lambda s:s._get_engine())
metadata = property(lambda s:s._derived_metadata())
-
+
def _get_table_key(name, schema):
if schema is None:
return name
else:
return schema + "." + name
-
+
class _TableSingleton(type):
- """a metaclass used by the Table object to provide singleton behavior."""
+ """A metaclass used by the ``Table`` object to provide singleton behavior."""
+
def __call__(self, name, metadata, *args, **kwargs):
if isinstance(metadata, sql.Executor):
# backwards compatibility - get a BoundSchema associated with the engine
@@ -109,10 +130,10 @@ class _TableSingleton(type):
args = list(args)
args.insert(0, metadata)
metadata = None
-
+
if metadata is None:
metadata = default_metadata
-
+
name = str(name) # in case of incoming unicode
schema = kwargs.get('schema', None)
autoload = kwargs.pop('autoload', False)
@@ -148,54 +169,90 @@ class _TableSingleton(type):
table._init_items(*args)
return table
-
+
class Table(SchemaItem, sql.TableClause):
- """represents a relational database table. This subclasses sql.TableClause to provide
- a table that is "wired" to an engine. Whereas TableClause represents a table as its
- used in a SQL expression, Table represents a table as its created in the database.
-
- Be sure to look at sqlalchemy.sql.TableImpl for additional methods defined on a Table."""
+ """Represent a relational database table.
+
+ This subclasses ``sql.TableClause`` to provide a table that is
+ *wired* to an engine.
+
+    Whereas ``TableClause`` represents a table as it is used in a SQL
+    expression, ``Table`` represents a table as it is created in the
+    database.
+
+    Be sure to look at ``sqlalchemy.sql.TableImpl`` for additional
+    methods defined on a ``Table``.
+    """
+
__metaclass__ = _TableSingleton
-
+
def __init__(self, name, metadata, **kwargs):
"""Construct a Table.
-
- Table objects can be constructed directly. The init method is actually called via
- the TableSingleton metaclass. Arguments are:
-
- name : the name of this table, exactly as it appears, or will appear, in the database.
- This property, along with the "schema", indicates the "singleton identity" of this table.
- Further tables constructed with the same name/schema combination will return the same
- Table instance.
-
- *args : should contain a listing of the Column objects for this table.
-
- **kwargs : options include:
-
- schema=None : the "schema name" for this table, which is required if the table resides in a
- schema other than the default selected schema for the engine's database connection.
-
- autoload=False : the Columns for this table should be reflected from the database. Usually
- there will be no Column objects in the constructor if this property is set.
-
- mustexist=False : indicates that this Table must already have been defined elsewhere in the application,
- else an exception is raised.
-
- useexisting=False : indicates that if this Table was already defined elsewhere in the application, disregard
- the rest of the constructor arguments.
-
- owner=None : optional owning user of this table. useful for databases such as Oracle to aid in table
- reflection.
-
- quote=False : indicates that the Table identifier must be properly escaped and quoted before being sent
- to the database. This flag overrides all other quoting behavior.
-
- quote_schema=False : indicates that the Namespace identifier must be properly escaped and quoted before being sent
- to the database. This flag overrides all other quoting behavior.
-
- case_sensitive=True : indicates quoting should be used if the identifier contains mixed case.
-
- case_sensitive_schema=True : indicates quoting should be used if the identifier contains mixed case.
+
+ Table objects can be constructed directly. The init method is
+        actually called via the ``_TableSingleton`` metaclass. Arguments
+ are:
+
+ name
+ The name of this table, exactly as it appears, or will
+ appear, in the database.
+
+ This property, along with the *schema*, indicates the
+ *singleton identity* of this table.
+
+ Further tables constructed with the same name/schema
+ combination will return the same Table instance.
+
+ *args
+ Should contain a listing of the Column objects for this table.
+
+ **kwargs
+        Options include:
+
+ schema
+ Defaults to None: the *schema name* for this table, which is
+ required if the table resides in a schema other than the
+ default selected schema for the engine's database
+ connection.
+
+ autoload
+ Defaults to False: the Columns for this table should be
+ reflected from the database. Usually there will be no
+ Column objects in the constructor if this property is set.
+
+ mustexist
+ Defaults to False: indicates that this Table must already
+ have been defined elsewhere in the application, else an
+ exception is raised.
+
+ useexisting
+ Defaults to False: indicates that if this Table was
+ already defined elsewhere in the application, disregard
+ the rest of the constructor arguments.
+
+ owner
+        Defaults to None: optional owning user of this table.
+        Useful for databases such as Oracle to aid in table
+        reflection.
+
+ quote
+ Defaults to False: indicates that the Table identifier
+ must be properly escaped and quoted before being sent to
+ the database. This flag overrides all other quoting
+ behavior.
+
+ quote_schema
+ Defaults to False: indicates that the Namespace identifier
+ must be properly escaped and quoted before being sent to
+ the database. This flag overrides all other quoting
+ behavior.
+
+ case_sensitive
+ Defaults to True: indicates quoting should be used if the
+ identifier contains mixed case.
+
+ case_sensitive_schema
+ Defaults to True: indicates quoting should be used if the
+ identifier contains mixed case.
"""
super(Table, self).__init__(name)
self._metadata = metadata
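
A construction sketch exercising the arguments enumerated above (the
engine URL and columns are illustrative only):

    from sqlalchemy import BoundMetaData, Table, Column, Integer, String

    metadata = BoundMetaData('sqlite://')

    users = Table('users', metadata,
                  Column('user_id', Integer, primary_key=True),
                  Column('user_name', String(40)))

    # The singleton identity described above: the same name/schema
    # combination hands back the same Table instance.
    assert Table('users', metadata, useexisting=True) is users
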
@@ -213,13 +270,13 @@ class Table(SchemaItem, sql.TableClause):
self._set_casing_strategy(name, kwargs)
self._set_casing_strategy(self.schema or '', kwargs, keyname='case_sensitive_schema')
-
+
if len([k for k in kwargs if not re.match(r'^(?:%s)_' % '|'.join(databases.__all__), k)]):
raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
-
+
# store extra kwargs, which should only contain db-specific options
self.kwargs = kwargs
-
+
def _get_case_sensitive_schema(self):
try:
return getattr(self, '_case_sensitive_schema')
@@ -234,39 +291,46 @@ class Table(SchemaItem, sql.TableClause):
self._primary_key = pk
self.constraints.add(pk)
primary_key = property(lambda s:s._primary_key, _set_primary_key)
-
+
def _derived_metadata(self):
return self._metadata
+
def __repr__(self):
return "Table(%s)" % string.join(
- [repr(self.name)] + [repr(self.metadata)] +
- [repr(x) for x in self.columns] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]
- , ',')
-
+ [repr(self.name)] + [repr(self.metadata)] +
+ [repr(x) for x in self.columns] +
+ ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']]
+ , ',')
+
def __str__(self):
return _get_table_key(self.name, self.schema)
def append_column(self, column):
- """append a Column to this Table."""
+ """Append a ``Column`` to this ``Table``."""
+
column._set_parent(self)
+
def append_constraint(self, constraint):
- """append a Constraint to this Table."""
+ """Append a ``Constraint`` to this ``Table``."""
+
constraint._set_parent(self)
def _get_parent(self):
- return self._metadata
+ return self._metadata
+
def _set_parent(self, metadata):
metadata.tables[_get_table_key(self.name, self.schema)] = self
self._metadata = metadata
- def accept_schema_visitor(self, visitor, traverse=True):
+
+ def accept_schema_visitor(self, visitor, traverse=True):
if traverse:
for c in self.columns:
c.accept_schema_visitor(visitor, True)
return visitor.visit_table(self)
def exists(self, connectable=None):
- """return True if this table exists."""
+ """Return True if this table exists."""
+
if connectable is None:
connectable = self.get_engine()
@@ -276,17 +340,22 @@ class Table(SchemaItem, sql.TableClause):
return connectable.run_callable(do)
def create(self, connectable=None, checkfirst=False):
- """issue a CREATE statement for this table.
-
- see also metadata.create_all()."""
+ """Issue a ``CREATE`` statement for this table.
+
+ See also ``metadata.create_all()``."""
+
self.metadata.create_all(connectable=connectable, checkfirst=checkfirst, tables=[self])
+
def drop(self, connectable=None, checkfirst=False):
- """issue a DROP statement for this table.
-
- see also metadata.drop_all()."""
+ """Issue a ``DROP`` statement for this table.
+
+ See also ``metadata.drop_all()``."""
+
self.metadata.drop_all(connectable=connectable, checkfirst=checkfirst, tables=[self])
+
def tometadata(self, metadata, schema=None):
- """return a copy of this Table associated with a different MetaData."""
+ """Return a copy of this ``Table`` associated with a different ``MetaData``."""
+
try:
if schema is None:
schema = self.schema
@@ -301,70 +370,113 @@ class Table(SchemaItem, sql.TableClause):
return Table(self.name, metadata, schema=schema, *args)
class Column(SchemaItem, sql._ColumnClause):
- """represents a column in a database table. this is a subclass of sql.ColumnClause and
- represents an actual existing table in the database, in a similar fashion as TableClause/Table."""
+ """Represent a column in a database table.
+
+    This is a subclass of ``sql.ColumnClause`` and represents an
+    actual existing column in a database table, in a similar fashion
+    as ``TableClause``/``Table``.
+ """
+
def __init__(self, name, type, *args, **kwargs):
- """constructs a new Column object. Arguments are:
-
- name : the name of this column. this should be the identical name as it appears,
- or will appear, in the database.
-
- type: the TypeEngine for this column.
- This can be any subclass of types.AbstractType, including the database-agnostic types defined
- in the types module, database-specific types defined within specific database modules, or user-defined types.
-
- type: the TypeEngine for this column. This can be any subclass of types.AbstractType, including
- the database-agnostic types defined in the types module, database-specific types defined within
- specific database modules, or user-defined types. If the column contains a ForeignKey,
- the type can also be None, in which case the type assigned will be that of the referenced column.
-
- *args: Constraint, ForeignKey, ColumnDefault and Sequence objects should be added as list values.
-
- **kwargs : keyword arguments include:
-
- key=None : an optional "alias name" for this column. The column will then be identified everywhere
- in an application, including the column list on its Table, by this key, and not the given name.
- Generated SQL, however, will still reference the column by its actual name.
-
- primary_key=False : True if this column is a primary key column. Multiple columns can have this flag
- set to specify composite primary keys. As an alternative, the primary key of a Table can be specified
- via an explicit PrimaryKeyConstraint instance appended to the Table's list of objects.
-
- nullable=True : True if this column should allow nulls. Defaults to True unless this column is a primary
- key column.
-
- default=None : a scalar, python callable, or ClauseElement representing the "default value" for this column,
- which will be invoked upon insert if this column is not present in the insert list or is given a value
- of None. The default expression will be converted into a ColumnDefault object upon initialization.
-
- _is_oid=False : used internally to indicate that this column is used as the quasi-hidden "oid" column
-
- index=False : Indicates that this column is
- indexed. The name of the index is autogenerated.
- to specify indexes with explicit names or indexes that contain multiple
- columns, use the Index construct instead.
-
- unique=False : Indicates that this column
- contains a unique constraint, or if index=True as well, indicates
- that the Index should be created with the unique flag.
- To specify multiple columns in the constraint/index or to specify an
- explicit name, use the UniqueConstraint or Index constructs instead.
-
- autoincrement=True : Indicates that integer-based primary key columns should have autoincrementing behavior,
- if supported by the underlying database. This will affect CREATE TABLE statements such that they will
- use the databases "auto-incrementing" keyword (such as SERIAL for postgres, AUTO_INCREMENT for mysql) and will
- also affect the behavior of some dialects during INSERT statement execution such that they will assume primary
- key values are created in this manner. If a Column has an explicit ColumnDefault object (such as via the
- "default" keyword, or a Sequence or PassiveDefault), then the value of autoincrement is ignored and is assumed
- to be False. autoincrement value is only significant for a column with a type or subtype of Integer.
-
- quote=False : indicates that the Column identifier must be properly escaped and quoted before being sent
- to the database. This flag should normally not be required as dialects can auto-detect conditions where quoting
- is required.
-
- case_sensitive=True : indicates quoting should be used if the identifier contains mixed case.
+ """Construct a new ``Column`` object.
+
+ Arguments are:
+
+ name
+ The name of this column. This should be the identical name
+ as it appears, or will appear, in the database.
+
+ type
+ The ``TypeEngine`` for this column. This can be any
+ subclass of ``types.AbstractType``, including the
+ database-agnostic types defined in the types module,
+ database-specific types defined within specific database
+ modules, or user-defined types. If the column contains a
+ ForeignKey, the type can also be None, in which case the
+ type assigned will be that of the referenced column.
+
+ *args
+ Constraint, ForeignKey, ColumnDefault and Sequence objects
+ should be added as list values.
+
+ **kwargs
+ Keyword arguments include:
+
+ key
+ Defaults to None: an optional *alias name* for this column.
+ The column will then be identified everywhere in an
+ application, including the column list on its Table, by
+ this key, and not the given name. Generated SQL, however,
+ will still reference the column by its actual name.
+
+ primary_key
+ Defaults to False: True if this column is a primary key
+ column. Multiple columns can have this flag set to
+ specify composite primary keys. As an alternative, the
+ primary key of a Table can be specified via an explicit
+ ``PrimaryKeyConstraint`` instance appended to the Table's
+ list of objects.
+
+          nullable
+            True if this column should allow nulls. Defaults to True
+            unless this column is a primary key column.
+
+ default
+ Defaults to None: a scalar, Python callable, or ``ClauseElement``
+ representing the *default value* for this column, which will
+ be invoked upon insert if this column is not present in
+ the insert list or is given a value of None. The default
+ expression will be converted into a ``ColumnDefault`` object
+ upon initialization.
+
+ _is_oid
+ Defaults to False: used internally to indicate that this
+            column is used as the quasi-hidden "oid" column.
+
+ index
+ Defaults to False: indicates that this column is
+            indexed. The name of the index is autogenerated. To
+            specify indexes with explicit names or indexes that
+ contain multiple columns, use the ``Index`` construct instead.
+
+ unique
+ Defaults to False: indicates that this column contains a
+ unique constraint, or if `index` is True as well,
+ indicates that the Index should be created with the unique
+ flag. To specify multiple columns in the constraint/index
+ or to specify an explicit name, use the
+ ``UniqueConstraint`` or ``Index`` constructs instead.
+
+ autoincrement
+ Defaults to True: indicates that integer-based primary key
+ columns should have autoincrementing behavior, if
+ supported by the underlying database. This will affect
+ ``CREATE TABLE`` statements such that they will use the
+ databases *auto-incrementing* keyword (such as ``SERIAL``
+            for Postgres, ``AUTO_INCREMENT`` for MySQL) and will also
+ affect the behavior of some dialects during ``INSERT``
+ statement execution such that they will assume primary key
+ values are created in this manner. If a ``Column`` has an
+ explicit ``ColumnDefault`` object (such as via the `default`
+ keyword, or a ``Sequence`` or ``PassiveDefault``), then
+ the value of `autoincrement` is ignored and is assumed to be
+ False. `autoincrement` value is only significant for a
+ column with a type or subtype of Integer.
+
+ quote
+ Defaults to False: indicates that the Column identifier
+ must be properly escaped and quoted before being sent to
+ the database. This flag should normally not be required
+ as dialects can auto-detect conditions where quoting is
+ required.
+
+ case_sensitive
+ Defaults to True: indicates quoting should be used if the
+ identifier contains mixed case.
"""
- name = str(name) # in case of incoming unicode
+
+ name = str(name) # in case of incoming unicode
super(Column, self).__init__(name, None, type)
self.args = args
self.key = kwargs.pop('key', name)
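
A sketch of several of the keyword arguments described above; the
table and values are illustrative only:

    import datetime
    from sqlalchemy import Table, Column, Integer, String, MetaData

    metadata = MetaData()

    addresses = Table('addresses', metadata,
        Column('address_id', Integer, primary_key=True),  # autoincrement applies
        Column('email', String(100), key='email_addr',    # addressed by key
               nullable=False, unique=True),
        Column('created', String(30), index=True,
               default=lambda: datetime.datetime.now().isoformat()))

    # The column is reachable by its key rather than its name.
    assert addresses.c.email_addr.name == 'email'
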
@@ -396,15 +508,16 @@ class Column(SchemaItem, sql._ColumnClause):
return self.name
else:
return self.name
-
+
def _derived_metadata(self):
return self.table.metadata
+
def _get_engine(self):
return self.table.engine
-
+
def append_foreign_key(self, fk):
fk._set_parent(self)
-
+
def __repr__(self):
kwarg = []
if self.key != self.name:
@@ -418,14 +531,14 @@ class Column(SchemaItem, sql._ColumnClause):
if self.default:
kwarg.append('default')
return "Column(%s)" % string.join(
- [repr(self.name)] + [repr(self.type)] +
- [repr(x) for x in self.foreign_keys if x is not None] +
- [repr(x) for x in self.constraints] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
- , ',')
-
+ [repr(self.name)] + [repr(self.type)] +
+ [repr(x) for x in self.foreign_keys if x is not None] +
+ [repr(x) for x in self.constraints] +
+ ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
+ , ',')
+
def _get_parent(self):
- return self.table
+ return self.table
def _set_parent(self, table):
if getattr(self, 'table', None) is not None:
@@ -448,7 +561,7 @@ class Column(SchemaItem, sql._ColumnClause):
if isinstance(self.unique, str):
raise exceptions.ArgumentError("The 'unique' keyword argument on Column is boolean only. To create unique constraints or indexes with a specific name, append an explicit UniqueConstraint or Index object to the Table's list of elements.")
table.append_constraint(UniqueConstraint(self.key))
-
+
toinit = list(self.args)
if self.default is not None:
toinit.append(ColumnDefault(self.default))
@@ -457,15 +570,21 @@ class Column(SchemaItem, sql._ColumnClause):
self._init_items(*toinit)
self.args = None
- def copy(self):
- """creates a copy of this Column, unitialized. this is used in Table.tometadata."""
+ def copy(self):
+        """Create a copy of this ``Column``, uninitialized.
+
+ This is used in ``Table.tometadata``.
+ """
+
return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, case_sensitive=self._case_sensitive_setting, quote=self.quote, *[c.copy() for c in self.constraints])
-
+
def _make_proxy(self, selectable, name = None):
- """create a "proxy" for this column.
-
- This is a copy of this Column referenced
- by a different parent (such as an alias or select statement)"""
+ """Create a *proxy* for this column.
+
+ This is a copy of this ``Column`` referenced by a different parent
+ (such as an alias or select statement).
+ """
+
fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
c.table = selectable
@@ -479,14 +598,16 @@ class Column(SchemaItem, sql._ColumnClause):
return c
def _case_sens(self):
- """redirect the 'case_sensitive' accessor to use the ultimate parent column which created
- this one."""
+ """Redirect the `case_sensitive` accessor to use the ultimate
+ parent column which created this one."""
+
return self.__originating_column._get_case_sensitive()
case_sensitive = property(_case_sens, lambda s,v:None)
-
+
def accept_schema_visitor(self, visitor, traverse=True):
- """traverses the given visitor to this Column's default and foreign key object,
- then calls visit_column on the visitor."""
+ """Traverse the given visitor to this ``Column``'s default and foreign key object,
+ then call `visit_column` on the visitor."""
+
if traverse:
if self.default is not None:
self.default.accept_schema_visitor(visitor, traverse=True)
@@ -500,22 +621,29 @@ class Column(SchemaItem, sql._ColumnClause):
class ForeignKey(SchemaItem):
- """defines a column-level ForeignKey constraint between two columns.
-
- ForeignKey is specified as an argument to a Column object.
-
- One or more ForeignKey objects are used within a ForeignKeyConstraint
- object which represents the table-level constraint definition."""
+    """Define a column-level ``ForeignKey`` constraint between two columns.
+
+ ``ForeignKey`` is specified as an argument to a Column object.
+
+ One or more ``ForeignKey`` objects are used within a
+ ``ForeignKeyConstraint`` object which represents the table-level
+ constraint definition.
+ """
+
def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None):
- """Construct a new ForeignKey object.
-
- "column" can be a schema.Column object representing the relationship,
- or just its string name given as "tablename.columnname". schema can be
- specified as "schema.tablename.columnname"
-
- "constraint" is the owning ForeignKeyConstraint object, if any. if not given,
- then a ForeignKeyConstraint will be automatically created and added to the parent table.
+ """Construct a new ``ForeignKey`` object.
+
+ column
+ Can be a ``schema.Column`` object representing the relationship,
+ or just its string name given as ``tablename.columnname``.
+      A schema can be specified as ``schema.tablename.columnname``.
+
+ constraint
+      Is the owning ``ForeignKeyConstraint`` object, if any. If not
+ given, then a ``ForeignKeyConstraint`` will be automatically
+ created and added to the parent table.
"""
+
if isinstance(column, unicode):
column = str(column)
self._colspec = column
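
Both spellings described above, as a sketch: the column-level string
form, and the table-level ForeignKeyConstraint holding the equivalent
ForeignKey objects:

    from sqlalchemy import MetaData, Table, Column, Integer
    from sqlalchemy import ForeignKey, ForeignKeyConstraint

    metadata = MetaData()

    users = Table('users', metadata,
                  Column('user_id', Integer, primary_key=True))

    # Column-level, using the 'tablename.columnname' string form; the
    # type may be None here and is then copied from the referenced column.
    orders = Table('orders', metadata,
                   Column('order_id', Integer, primary_key=True),
                   Column('user_id', None, ForeignKey('users.user_id')))

    # Table-level equivalent.
    invoices = Table('invoices', metadata,
                     Column('invoice_id', Integer, primary_key=True),
                     Column('user_id', Integer),
                     ForeignKeyConstraint(['user_id'], ['users.user_id']))
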
@@ -525,14 +653,15 @@ class ForeignKey(SchemaItem):
self.name = name
self.onupdate = onupdate
self.ondelete = ondelete
-
+
def __repr__(self):
return "ForeignKey(%s)" % repr(self._get_colspec())
-
+
def copy(self):
- """produce a copy of this ForeignKey object."""
+        """Produce a copy of this ``ForeignKey`` object."""
+
return ForeignKey(self._get_colspec())
-
+
def _get_colspec(self):
if isinstance(self._colspec, str):
return self._colspec
@@ -540,17 +669,18 @@ class ForeignKey(SchemaItem):
return "%s.%s.%s" % (self._colspec.table.schema, self._colspec.table.name, self._colspec.key)
else:
return "%s.%s" % (self._colspec.table.name, self._colspec.key)
-
+
def references(self, table):
- """returns True if the given table is referenced by this ForeignKey."""
+ """Return True if the given table is referenced by this ``ForeignKey``."""
+
return table.corresponding_column(self.column, False) is not None
-
+
def _init_column(self):
# ForeignKey inits its remote column as late as possible, so tables can
# be defined without dependencies
if self._column is None:
if isinstance(self._colspec, str):
- # locate the parent table this foreign key is attached to.
+ # locate the parent table this foreign key is attached to.
# we use the "original" column which our parent column represents
# (its a list of columns/other ColumnElements if the parent table is a UNION)
for c in self.parent.orig_set:
@@ -582,15 +712,17 @@ class ForeignKey(SchemaItem):
if self.parent.type is types.NULLTYPE:
self.parent.type = self._column.type
return self._column
-
+
column = property(lambda s: s._init_column())
def accept_schema_visitor(self, visitor, traverse=True):
- """calls the visit_foreign_key method on the given visitor."""
+ """Call the `visit_foreign_key` method on the given visitor."""
+
visitor.visit_foreign_key(self)
-
+
def _get_parent(self):
return self.parent
+
def _set_parent(self, column):
self.parent = column
@@ -603,17 +735,21 @@ class ForeignKey(SchemaItem):
self.parent.table.foreign_keys.add(self)
class DefaultGenerator(SchemaItem):
- """Base class for column "default" values."""
+ """Base class for column *default* values."""
+
def __init__(self, for_update=False, metadata=None):
self.for_update = for_update
self._metadata = metadata
+
def _derived_metadata(self):
try:
return self.column.table.metadata
except AttributeError:
return self._metadata
+
def _get_parent(self):
return getattr(self, 'column', None)
+
def _set_parent(self, column):
self.column = column
self._metadata = self.column.table.metadata
@@ -621,38 +757,51 @@ class DefaultGenerator(SchemaItem):
self.column.onupdate = self
else:
self.column.default = self
+
def execute(self, **kwargs):
return self.get_engine().execute_default(self, **kwargs)
+
def __repr__(self):
return "DefaultGenerator()"
class PassiveDefault(DefaultGenerator):
- """a default that takes effect on the database side"""
+ """A default that takes effect on the database side."""
+
def __init__(self, arg, **kwargs):
super(PassiveDefault, self).__init__(**kwargs)
self.arg = arg
+
def accept_schema_visitor(self, visitor, traverse=True):
return visitor.visit_passive_default(self)
+
def __repr__(self):
return "PassiveDefault(%s)" % repr(self.arg)
-
+
class ColumnDefault(DefaultGenerator):
- """A plain default value on a column. this could correspond to a constant,
- a callable function, or a SQL clause."""
+ """A plain default value on a column.
+
+ This could correspond to a constant, a callable function, or a SQL
+ clause.
+ """
+
def __init__(self, arg, **kwargs):
super(ColumnDefault, self).__init__(**kwargs)
self.arg = arg
+
def accept_schema_visitor(self, visitor, traverse=True):
- """calls the visit_column_default method on the given visitor."""
+        """Call the `visit_column_default` method on the given visitor."""
+
if self.for_update:
return visitor.visit_column_onupdate(self)
else:
return visitor.visit_column_default(self)
+
def __repr__(self):
return "ColumnDefault(%s)" % repr(self.arg)
-
+
class Sequence(DefaultGenerator):
- """represents a sequence, which applies to Oracle and Postgres databases."""
+ """Represent a sequence, which applies to Oracle and Postgres databases."""
+
def __init__(self, name, start = None, increment = None, optional=False, quote=False, **kwargs):
super(Sequence, self).__init__(**kwargs)
self.name = name
@@ -661,42 +810,59 @@ class Sequence(DefaultGenerator):
self.optional=optional
self.quote = quote
self._set_casing_strategy(name, kwargs)
+
def __repr__(self):
return "Sequence(%s)" % string.join(
- [repr(self.name)] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']]
+ [repr(self.name)] +
+ ["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']]
, ',')
+
def _set_parent(self, column):
super(Sequence, self)._set_parent(column)
column.sequence = self
+
def create(self):
self.get_engine().create(self)
return self
+
def drop(self):
self.get_engine().drop(self)
+
def accept_schema_visitor(self, visitor, traverse=True):
- """calls the visit_seauence method on the given visitor."""
+        """Call the `visit_sequence` method on the given visitor."""
+
return visitor.visit_sequence(self)
class Constraint(SchemaItem):
- """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint.
-
- Implements a hybrid of dict/setlike behavior with regards to the list of underying columns"""
+ """Represent a table-level ``Constraint`` such as a composite primary
+ key, foreign key, or unique constraint.
+
+ Implements a hybrid of dict/setlike behavior with regards to the
+    list of underlying columns.
+ """
+
def __init__(self, name=None):
self.name = name
self.columns = sql.ColumnCollection()
+
def __contains__(self, x):
return x in self.columns
+
def keys(self):
return self.columns.keys()
+
def __add__(self, other):
return self.columns + other
+
def __iter__(self):
return iter(self.columns)
+
def __len__(self):
return len(self.columns)
+
def copy(self):
raise NotImplementedError()
+
def _get_parent(self):
return getattr(self, 'table', None)
@@ -704,19 +870,23 @@ class CheckConstraint(Constraint):
def __init__(self, sqltext, name=None):
super(CheckConstraint, self).__init__(name)
self.sqltext = sqltext
+
def accept_schema_visitor(self, visitor, traverse=True):
if isinstance(self.parent, Table):
visitor.visit_check_constraint(self)
else:
visitor.visit_column_check_constraint(self)
+
def _set_parent(self, parent):
self.parent = parent
parent.constraints.add(self)
+
def copy(self):
return CheckConstraint(self.sqltext, name=self.name)
-
+
class ForeignKeyConstraint(Constraint):
- """table-level foreign key constraint, represents a colleciton of ForeignKey objects."""
+    """Table-level foreign key constraint, representing a collection of ``ForeignKey`` objects."""
+
def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False):
super(ForeignKeyConstraint, self).__init__(name)
self.__colnames = columns
@@ -727,78 +897,101 @@ class ForeignKeyConstraint(Constraint):
if self.name is None and use_alter:
raise exceptions.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
self.use_alter = use_alter
+
def _set_parent(self, table):
self.table = table
table.constraints.add(self)
for (c, r) in zip(self.__colnames, self.__refcolnames):
self.append_element(c,r)
+
def accept_schema_visitor(self, visitor, traverse=True):
visitor.visit_foreign_key_constraint(self)
+
def append_element(self, col, refcol):
fk = ForeignKey(refcol, constraint=self)
fk._set_parent(self.table.c[col])
self._append_fk(fk)
+
def _append_fk(self, fk):
self.columns.add(self.table.c[fk.parent.key])
self.elements.add(fk)
+
def copy(self):
return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete)
-
+
class PrimaryKeyConstraint(Constraint):
def __init__(self, *columns, **kwargs):
super(PrimaryKeyConstraint, self).__init__(name=kwargs.pop('name', None))
self.__colnames = list(columns)
+
def _set_parent(self, table):
self.table = table
table.primary_key = self
for c in self.__colnames:
self.append_column(table.c[c])
+
def accept_schema_visitor(self, visitor, traverse=True):
visitor.visit_primary_key_constraint(self)
+
def add(self, col):
self.append_column(col)
+
def remove(self, col):
col.primary_key=False
del self.columns[col.key]
+
def append_column(self, col):
self.columns.add(col)
col.primary_key=True
+
def copy(self):
return PrimaryKeyConstraint(name=self.name, *[c.key for c in self])
+
def __eq__(self, other):
return self.columns == other
-
+
class UniqueConstraint(Constraint):
def __init__(self, *columns, **kwargs):
super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None))
self.__colnames = list(columns)
+
def _set_parent(self, table):
self.table = table
table.constraints.add(self)
for c in self.__colnames:
self.append_column(table.c[c])
+
def append_column(self, col):
self.columns.add(col)
+
def accept_schema_visitor(self, visitor, traverse=True):
visitor.visit_unique_constraint(self)
+
def copy(self):
return UniqueConstraint(name=self.name, *self.__colnames)
-
+
class Index(SchemaItem):
- """Represents an index of columns from a database table
- """
+ """Represent an index of columns from a database table."""
+
def __init__(self, name, *columns, **kwargs):
- """Constructs an index object. Arguments are:
+ """Construct an index object.
+
+ Arguments are:
- name : the name of the index
+ name
+      The name of the index.
- *columns : columns to include in the index. All columns must belong to
- the same table, and no column may appear more than once.
+ *columns
+ Columns to include in the index. All columns must belong to
+ the same table, and no column may appear more than once.
- **kw : keyword arguments include:
+ **kwargs
+ Keyword arguments include:
- unique=True : create a unique index
+ unique
+        If set to True, creates a unique index. Defaults to False.
"""
+
self.name = name
self.columns = []
self.table = None
@@ -807,11 +1000,14 @@ class Index(SchemaItem):
def _derived_metadata(self):
return self.table.metadata
+
def _init_items(self, *args):
for column in args:
self.append_column(column)
+
def _get_parent(self):
- return self.table
+ return self.table
+
def _set_parent(self, table):
self.table = table
table.indexes.add(self)
@@ -832,36 +1028,43 @@ class Index(SchemaItem):
"same index (%s already has column %s)"
% (self.name, column))
self.columns.append(column)
-
+
def create(self, connectable=None):
if connectable is not None:
connectable.create(self)
else:
self.get_engine().create(self)
return self
+
def drop(self, connectable=None):
if connectable is not None:
connectable.drop(self)
else:
self.get_engine().drop(self)
+
def accept_schema_visitor(self, visitor, traverse=True):
visitor.visit_index(self)
+
def __str__(self):
return repr(self)
+
def __repr__(self):
return 'Index("%s", %s%s)' % (self.name,
', '.join([repr(c)
for c in self.columns]),
(self.unique and ', unique=True') or '')
-
+
class MetaData(SchemaItem):
- """represents a collection of Tables and their associated schema constructs."""
+ """Represent a collection of Tables and their associated schema constructs."""
+
def __init__(self, name=None, **kwargs):
self.tables = {}
self.name = name
self._set_casing_strategy(name, kwargs)
+
def is_bound(self):
return False
+
def clear(self):
self.tables.clear()
@@ -873,64 +1076,80 @@ class MetaData(SchemaItem):
tables = util.Set(tables).intersection(self.tables.values())
sorter = sqlalchemy.sql_util.TableCollection(list(tables))
return iter(sorter.sort(reverse=reverse))
+
def _get_parent(self):
- return None
+ return None
+
def create_all(self, connectable=None, tables=None, checkfirst=True):
- """create all tables stored in this metadata.
-
- This will conditionally create tables depending on if they do not yet
- exist in the database.
-
- connectable - a Connectable used to access the database; or use the engine
- bound to this MetaData.
-
- tables - optional list of tables, which is a subset of the total
- tables in the MetaData (others are ignored)"""
+ """Create all tables stored in this metadata.
+
+ This will conditionally create tables depending on if they do
+ not yet exist in the database.
+
+ connectable
+ A ``Connectable`` used to access the database; or use the engine
+ bound to this ``MetaData``.
+
+ tables
+ Optional list of tables, which is a subset of the total
+ tables in the ``MetaData`` (others are ignored).
+ """
+
if connectable is None:
connectable = self.get_engine()
connectable.create(self, checkfirst=checkfirst, tables=tables)
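A minimal sketch of the conditional creation described above, assuming a hypothetical ``users`` table bound to an in-memory SQLite engine::

    from sqlalchemy import *

    meta = BoundMetaData('sqlite://')        # hypothetical in-memory engine
    users = Table('users', meta,
        Column('id', Integer, primary_key=True),
        Column('name', String(40)))
    meta.create_all()    # emits CREATE TABLE only for tables not yet present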
-
+
def drop_all(self, connectable=None, tables=None, checkfirst=True):
- """drop all tables stored in this metadata.
-
- This will conditionally drop tables depending on if they currently
- exist in the database.
-
- connectable - a Connectable used to access the database; or use the engine
- bound to this MetaData.
-
- tables - optional list of tables, which is a subset of the total
- tables in the MetaData (others are ignored)
+ """Drop all tables stored in this metadata.
+
+ This will conditionally drop tables depending on if they
+ currently exist in the database.
+
+ connectable
+ A ``Connectable`` used to access the database; or use the engine
+ bound to this ``MetaData``.
+
+ tables
+ Optional list of tables, which is a subset of the total
+ tables in the ``MetaData`` (others are ignored).
"""
+
if connectable is None:
connectable = self.get_engine()
connectable.drop(self, checkfirst=checkfirst, tables=tables)
-
-
+
def accept_schema_visitor(self, visitor, traverse=True):
visitor.visit_metadata(self)
-
+
def _derived_metadata(self):
return self
+
def _get_engine(self):
if not self.is_bound():
return None
return self._engine
-
+
class BoundMetaData(MetaData):
- """builds upon MetaData to provide the capability to bind to an Engine implementation."""
+ """Build upon ``MetaData`` to provide the capability to bind to an
+ ``Engine`` implementation.
+ """
+
def __init__(self, engine_or_url, name=None, **kwargs):
super(BoundMetaData, self).__init__(name, **kwargs)
if isinstance(engine_or_url, basestring):
self._engine = sqlalchemy.create_engine(engine_or_url, **kwargs)
else:
self._engine = engine_or_url
+
def is_bound(self):
return True
class DynamicMetaData(MetaData):
- """builds upon MetaData to provide the capability to bind to multiple Engine implementations
- on a dynamically alterable, thread-local basis."""
+ """Build upon ``MetaData`` to provide the capability to bind to
+ multiple ``Engine`` implementations on a dynamically alterable,
+ thread-local basis.
+ """
+
def __init__(self, name=None, threadlocal=True, **kwargs):
super(DynamicMetaData, self).__init__(name, **kwargs)
if threadlocal:
@@ -938,6 +1157,7 @@ class DynamicMetaData(MetaData):
else:
self.context = self
self.__engines = {}
+
def connect(self, engine_or_url, **kwargs):
if isinstance(engine_or_url, str):
try:
@@ -950,59 +1170,80 @@ class DynamicMetaData(MetaData):
if not self.__engines.has_key(engine_or_url):
self.__engines[engine_or_url] = engine_or_url
self.context._engine = engine_or_url
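A short sketch of the thread-local binding described above (the engine URL is hypothetical)::

    meta = DynamicMetaData()
    meta.connect('sqlite://')    # binds an engine for the current thread only
    assert meta.is_bound()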
+
def is_bound(self):
return hasattr(self.context, '_engine') and self.context._engine is not None
+
def dispose(self):
- """disposes all Engines to which this DynamicMetaData has been connected."""
+ """Dispose all ``Engines`` to which this ``DynamicMetaData`` has been connected."""
+
for e in self.__engines.values():
e.dispose()
+
def _get_engine(self):
if hasattr(self.context, '_engine'):
return self.context._engine
else:
return None
engine=property(_get_engine)
-
+
class SchemaVisitor(sql.ClauseVisitor):
- """defines the visiting for SchemaItem objects"""
+ """Define the visiting for ``SchemaItem`` objects."""
+
def visit_schema(self, schema):
- """visit a generic SchemaItem"""
+ """Visit a generic ``SchemaItem``."""
pass
+
def visit_table(self, table):
- """visit a Table."""
+ """Visit a ``Table``."""
pass
+
def visit_column(self, column):
- """visit a Column."""
+ """Visit a ``Column``."""
pass
+
def visit_foreign_key(self, join):
- """visit a ForeignKey."""
+ """Visit a ``ForeignKey``."""
pass
+
def visit_index(self, index):
- """visit an Index."""
+ """Visit an ``Index``."""
pass
+
def visit_passive_default(self, default):
- """visit a passive default"""
+ """Visit a passive default."""
pass
+
def visit_column_default(self, default):
- """visit a ColumnDefault."""
+ """Visit a ``ColumnDefault``."""
pass
+
def visit_column_onupdate(self, onupdate):
- """visit a ColumnDefault with the "for_update" flag set."""
+ """Visit a ``ColumnDefault`` with the `for_update` flag set."""
pass
+
def visit_sequence(self, sequence):
- """visit a Sequence."""
+ """Visit a ``Sequence``."""
pass
+
def visit_primary_key_constraint(self, constraint):
+ """Visit a ``PrimaryKeyConstraint``."""
pass
+
def visit_foreign_key_constraint(self, constraint):
+ """Visit a ``ForeignKeyConstraint``."""
pass
+
def visit_unique_constraint(self, constraint):
+ """Visit a ``UniqueConstraint``."""
pass
+
def visit_check_constraint(self, constraint):
+ """Visit a ``CheckConstraint``."""
pass
+
def visit_column_check_constraint(self, constraint):
+ """Visit a ``CheckConstraint`` on a ``Column``."""
pass
-
-default_metadata = DynamicMetaData('default')
-
+default_metadata = DynamicMetaData('default')
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index ffb4dc7510..d41e16bcbf 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -3,152 +3,245 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""defines the base components of SQL expression trees."""
+"""Define the base components of SQL expression trees."""
from sqlalchemy import util, exceptions
from sqlalchemy import types as sqltypes
import string, re, random, sets
-__all__ = ['text', 'table', 'column', 'literal_column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'except_', 'except_all', 'intersect', 'intersect_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists', 'extract','AbstractDialect', 'ClauseParameters', 'ClauseVisitor', 'Executor', 'Compiled', 'ClauseElement', 'ColumnElement', 'ColumnCollection', 'FromClause', 'TableClause', 'Select', 'Alias', 'CompoundSelect','Join', 'Selectable']
+__all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
+ 'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
+ 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join',
+ 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc',
+ 'between_', 'bindparam', 'case', 'cast', 'column', 'delete',
+ 'desc', 'except_', 'except_all', 'exists', 'extract', 'func',
+ 'insert', 'intersect', 'intersect_all', 'join', 'literal',
+ 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
+ 'subquery', 'table', 'text', 'union', 'union_all', 'update',]
def desc(column):
- """return a descending ORDER BY clause element, e.g.:
-
- order_by = [desc(table1.mycol)]
- """
+ """Return a descending ``ORDER BY`` clause element.
+
+ E.g.::
+
+ order_by = [desc(table1.mycol)]
+ """
return _CompoundClause(None, column, "DESC")
def asc(column):
- """return an ascending ORDER BY clause element, e.g.:
-
- order_by = [asc(table1.mycol)]
+ """Return an ascending ``ORDER BY`` clause element.
+
+ E.g.::
+
+ order_by = [asc(table1.mycol)]
"""
return _CompoundClause(None, column, "ASC")
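A quick sketch of both helpers, built on the lightweight ``table()``/``column()`` primitives defined later in this module (the ``users`` table is hypothetical)::

    from sqlalchemy.sql import table, column, select, desc, asc

    users = table('users', column('id'), column('name'))
    s = select([users], order_by=[desc(users.c.name), asc(users.c.id)])
    print s    # ... ORDER BY users.name DESC, users.id ASC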
def outerjoin(left, right, onclause=None, **kwargs):
- """return an OUTER JOIN clause element.
-
- left - the left side of the join
- right - the right side of the join
- onclause - optional criterion for the ON clause,
- is derived from foreign key relationships otherwise
-
+ """Return an ``OUTER JOIN`` clause element.
+
+ left
+ The left side of the join.
+
+ right
+ The right side of the join.
+
+ onclause
+      Optional criterion for the ``ON`` clause; if omitted, it is
+      derived from foreign key relationships.
+
To chain joins together, use the resulting
- Join object's "join()" or "outerjoin()" methods."""
+ ``Join`` object's ``join()`` or ``outerjoin()`` methods.
+ """
+
return Join(left, right, onclause, isouter = True, **kwargs)
def join(left, right, onclause=None, **kwargs):
- """return a JOIN clause element (regular inner join).
-
- left - the left side of the join
- right - the right side of the join
- onclause - optional criterion for the ON clause,
- is derived from foreign key relationships otherwise
+ """Return a ``JOIN`` clause element (regular inner join).
+
+ left
+ The left side of the join.
+
+ right
+ The right side of the join.
+
+ onclause
+      Optional criterion for the ``ON`` clause; if omitted, it is
+      derived from foreign key relationships.
+
+ To chain joins together, use the resulting ``Join`` object's
+ ``join()`` or ``outerjoin()`` methods.
+ """
- To chain joins together, use the resulting Join object's
- "join()" or "outerjoin()" methods."""
return Join(left, right, onclause, **kwargs)
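A minimal sketch of an inner join (hypothetical table clauses; the explicit onclause is required here because these lightweight tables carry no foreign key metadata)::

    from sqlalchemy.sql import table, column, join, select

    users = table('users', column('id'), column('name'))
    addresses = table('addresses', column('id'), column('user_id'))

    j = join(users, addresses, users.c.id == addresses.c.user_id)
    print select([users.c.name, addresses.c.id], from_obj=[j])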
def select(columns=None, whereclause = None, from_obj = [], **kwargs):
- """returns a SELECT clause element.
-
- this can also be called via the table's select() method.
-
- 'columns' is a list of columns and/or selectable items to select columns from
- 'whereclause' is a text or ClauseElement expression which will form the WHERE clause
- 'from_obj' is an list of additional "FROM" objects, such as Join objects, which will
- extend or override the default "from" objects created from the column list and the
- whereclause.
- **kwargs - additional parameters for the Select object.
+ """Returns a ``SELECT`` clause element.
+
+ This can also be called via the table's ``select()`` method.
+
+    columns
+      A list of columns and/or selectable items to select columns from.
+
+    whereclause
+      A text or ``ClauseElement`` expression which will form the
+      ``WHERE`` clause.
+
+ from_obj
+ A list of additional ``FROM`` objects, such as ``Join`` objects,
+ which will extend or override the default ``FROM`` objects
+ created from the column list and the whereclause.
+
+ **kwargs
+ Additional parameters for the ``Select`` object.
"""
+
return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs)
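For illustration, the equivalence with the table's own method, reusing the hypothetical ``users`` table clause from the sketches above::

    s1 = select([users.c.name], users.c.id == 5)
    s2 = users.select(users.c.id == 5)   # the same statement via the table's select()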
def subquery(alias, *args, **kwargs):
return Select(*args, **kwargs).alias(alias)
-
def insert(table, values = None, **kwargs):
- """returns an INSERT clause element.
-
- This can also be called from a table directly via the table's insert() method.
-
- 'table' is the table to be inserted into.
-
- 'values' is a dictionary which specifies the column specifications of the INSERT,
- and is optional. If left as None, the column specifications are determined from the
- bind parameters used during the compile phase of the INSERT statement. If the
- bind parameters also are None during the compile phase, then the column
- specifications will be generated from the full list of table columns.
-
- If both 'values' and compile-time bind parameters are present, the compile-time
- bind parameters override the information specified within 'values' on a per-key basis.
-
- The keys within 'values' can be either Column objects or their string identifiers.
- Each key may reference one of: a literal data value (i.e. string, number, etc.), a Column object,
- or a SELECT statement. If a SELECT statement is specified which references this INSERT
- statement's table, the statement will be correlated against the INSERT statement.
+ """Return an ``INSERT`` clause element.
+
+ This can also be called from a table directly via the table's
+ ``insert()`` method.
+
+ table
+ The table to be inserted into.
+
+ values
+ A dictionary which specifies the column specifications of the
+ ``INSERT``, and is optional. If left as None, the column
+ specifications are determined from the bind parameters used
+ during the compile phase of the ``INSERT`` statement. If the
+ bind parameters also are None during the compile phase, then the
+ column specifications will be generated from the full list of
+ table columns.
+
+ If both `values` and compile-time bind parameters are present, the
+ compile-time bind parameters override the information specified
+ within `values` on a per-key basis.
+
+ The keys within `values` can be either ``Column`` objects or their
+ string identifiers. Each key may reference one of:
+
+ * a literal data value (i.e. string, number, etc.);
+    * a ``Column`` object;
+    * a ``SELECT`` statement.
+
+ If a ``SELECT`` statement is specified which references this
+ ``INSERT`` statement's table, the statement will be correlated
+ against the ``INSERT`` statement.
"""
+
return _Insert(table, values, **kwargs)
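A sketch reusing the hypothetical ``users`` table clause from above; compiling without an engine falls back to the default ANSI compiler::

    i = insert(users, values={'name': 'jack'})
    print i    # roughly: INSERT INTO users (name) VALUES (:name)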
def update(table, whereclause = None, values = None, **kwargs):
- """returns an UPDATE clause element.
-
- This can also be called from a table directly via the table's update() method.
-
- 'table' is the table to be updated.
- 'whereclause' is a ClauseElement describing the WHERE condition of the UPDATE statement.
- 'values' is a dictionary which specifies the SET conditions of the UPDATE, and is
- optional. If left as None, the SET conditions are determined from the bind parameters
- used during the compile phase of the UPDATE statement. If the bind parameters also are
- None during the compile phase, then the SET conditions will be generated from the full
- list of table columns.
-
- If both 'values' and compile-time bind parameters are present, the compile-time bind
- parameters override the information specified within 'values' on a per-key basis.
-
- The keys within 'values' can be either Column objects or their string identifiers. Each
- key may reference one of: a literal data value (i.e. string, number, etc.), a Column
- object, or a SELECT statement. If a SELECT statement is specified which references this
- UPDATE statement's table, the statement will be correlated against the UPDATE statement.
+ """Return an ``UPDATE`` clause element.
+
+ This can also be called from a table directly via the table's
+ ``update()`` method.
+
+ table
+ The table to be updated.
+
+ whereclause
+ A ``ClauseElement`` describing the ``WHERE`` condition of the
+ ``UPDATE`` statement.
+
+ values
+ A dictionary which specifies the ``SET`` conditions of the
+ ``UPDATE``, and is optional. If left as None, the ``SET``
+ conditions are determined from the bind parameters used during
+ the compile phase of the ``UPDATE`` statement. If the bind
+ parameters also are None during the compile phase, then the
+ ``SET`` conditions will be generated from the full list of table
+ columns.
+
+ If both `values` and compile-time bind parameters are present, the
+ compile-time bind parameters override the information specified
+ within `values` on a per-key basis.
+
+ The keys within `values` can be either ``Column`` objects or their
+ string identifiers. Each key may reference one of:
+
+ * a literal data value (i.e. string, number, etc.);
+    * a ``Column`` object;
+    * a ``SELECT`` statement.
+
+ If a ``SELECT`` statement is specified which references this
+ ``UPDATE`` statement's table, the statement will be correlated
+ against the ``UPDATE`` statement.
"""
+
return _Update(table, whereclause, values, **kwargs)
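A sketch with the hypothetical ``users`` table clause from above::

    u = update(users, users.c.id == 5, values={'name': 'ed'})
    print u    # roughly: UPDATE users SET name=:name WHERE users.id = :users_id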
def delete(table, whereclause = None, **kwargs):
- """returns a DELETE clause element.
-
- This can also be called from a table directly via the table's delete() method.
-
- 'table' is the table to be updated.
- 'whereclause' is a ClauseElement describing the WHERE condition of the UPDATE statement.
+ """Return a ``DELETE`` clause element.
+
+ This can also be called from a table directly via the table's
+ ``delete()`` method.
+
+ table
+ The table to be updated.
+
+ whereclause
+ A ``ClauseElement`` describing the ``WHERE`` condition of the
+ ``UPDATE`` statement.
"""
+
return _Delete(table, whereclause, **kwargs)
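Again with the hypothetical ``users`` table clause::

    d = delete(users, users.c.name == 'ed')
    print d    # roughly: DELETE FROM users WHERE users.name = :users_name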
def and_(*clauses):
- """joins a list of clauses together by the AND operator. the & operator can be used as well."""
+ """Join a list of clauses together by the ``AND`` operator.
+
+ The ``&`` operator can be used as well.
+ """
+
return _compound_clause('AND', *clauses)
def or_(*clauses):
- """joins a list of clauses together by the OR operator. the | operator can be used as well."""
+ """Join a list of clauses together by the ``OR`` operator.
+
+ The ``|`` operator can be used as well.
+ """
+
return _compound_clause('OR', *clauses)
def not_(clause):
- """returns a negation of the given clause, i.e. NOT(clause). the ~ operator can be used as well."""
+ """Return a negation of the given clause, i.e. ``NOT(clause)``.
+
+ The ``~`` operator can be used as well.
+ """
+
return clause._negate()
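A sketch contrasting the three functions above with their operator forms (hypothetical ``users`` table clause as before)::

    crit = and_(users.c.name.like('j%'),
                or_(users.c.id == 5, not_(users.c.id == 9)))
    # the same criterion via the overloaded operators:
    crit2 = users.c.name.like('j%') & ((users.c.id == 5) | ~(users.c.id == 9))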
def between(ctest, cleft, cright):
- """ returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright).
-
- this is better called off a ColumnElement directly, i.e.
-
- column.between(value1, value2).
+ """Return ``BETWEEN`` predicate clause.
+
+ Equivalent of SQL ``clausetest BETWEEN clauseleft AND clauseright``.
+
+ This is better called off a ``ColumnElement`` directly, i.e.::
+
+ column.between(value1, value2)
"""
+
return _BooleanExpression(ctest, and_(_check_literal(cleft, ctest.type), _check_literal(cright, ctest.type)), 'BETWEEN')
between_ = between
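For example, with the hypothetical ``users`` table clause::

    print select([users], users.c.id.between(3, 7))
    # module-level spelling: between_(users.c.id, 3, 7)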
def case(whens, value=None, else_=None):
- """ SQL CASE statement -- whens are a sequence of pairs to be translated into "when / then" clauses;
- optional [value] for simple case statements, and [else_] for case defaults """
+ """``SQL CASE`` statement.
+
+ whens
+      A sequence of pairs to be translated into ``WHEN``/``THEN`` clauses.
+
+ value
+ Optional for simple case statements.
+
+ else_
+ Optional as well, for case defaults.
+ """
+
whenlist = [_CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
if else_:
whenlist.append(_CompoundClause(None, 'ELSE', else_))
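A sketch of a searched ``CASE`` placed in a column list, using ``literal()`` (documented below) to bind the result strings (hypothetical ``users`` table clause as above)::

    initial = case([(users.c.name == 'jack', literal('J')),
                    (users.c.name == 'ed', literal('E'))],
                   else_=literal('?'))
    s = select([users.c.id, initial])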
@@ -203,79 +296,114 @@ def _check_literal(value, type):
return value
def literal(value, type=None):
- """returns a literal clause, bound to a bind parameter.
-
- literal clauses are created automatically when used as the right-hand
- side of a boolean or math operation against a column object. use this
- function when a literal is needed on the left-hand side (and optionally on the right as well).
-
- the optional type parameter is a sqlalchemy.types.TypeEngine object which indicates bind-parameter
- and result-set translation for this literal.
+ """Return a literal clause, bound to a bind parameter.
+
+ Literal clauses are created automatically when used as the
+ right-hand side of a boolean or math operation against a column
+ object. Use this function when a literal is needed on the
+ left-hand side (and optionally on the right as well).
+
+ The optional type parameter is a ``sqlalchemy.types.TypeEngine``
+ object which indicates bind-parameter and result-set translation
+ for this literal.
"""
+
return _BindParamClause('literal', value, type=type)
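For example, forcing a literal onto the left-hand side of an expression (hypothetical ``users`` table clause as above)::

    expr = literal('name: ') + users.c.name   # the string becomes a bind parameter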
def label(name, obj):
- """returns a _Label object for the given selectable, used in the column list for a select statement."""
+ """Return a ``_Label`` object for the given selectable, used in
+ the column list for a select statement.
+ """
+
return _Label(name, obj)
-
+
def column(text, table=None, type=None, **kwargs):
- """return a textual column clause, relative to a table. this is also the primitive version of
- a schema.Column which is a subclass. """
+ """Return a textual column clause, relative to a table.
+
+ This is also the primitive version of a ``schema.Column`` which is
+ a subclass.
+ """
+
return _ColumnClause(text, table, type, **kwargs)
def literal_column(text, table=None, type=None, **kwargs):
- """return a textual column clause with the 'literal' flag set. this column will not be quoted"""
+ """Return a textual column clause with the `literal` flag set.
+
+ This column will not be quoted.
+ """
+
return _ColumnClause(text, table, type, is_literal=True, **kwargs)
-
+
def table(name, *columns):
- """returns a table clause. this is a primitive version of the schema.Table object, which is a subclass
- of this object."""
+ """Return a table clause.
+
+ This is a primitive version of the ``schema.Table`` object, which
+ is a subclass of this object.
+ """
+
return TableClause(name, *columns)
-
+
def bindparam(key, value=None, type=None, shortname=None):
- """creates a bind parameter clause with the given key.
-
- An optional default value can be specified by the value parameter, and the optional type parameter
- is a sqlalchemy.types.TypeEngine object which indicates bind-parameter and result-set translation for
- this bind parameter."""
+ """Create a bind parameter clause with the given key.
+
+ An optional default value can be specified by the value parameter,
+ and the optional type parameter is a
+ ``sqlalchemy.types.TypeEngine`` object which indicates
+ bind-parameter and result-set translation for this bind parameter.
+ """
+
if isinstance(key, _ColumnClause):
return _BindParamClause(key.name, value, type=key.type, shortname=shortname)
else:
return _BindParamClause(key, value, type=type, shortname=shortname)
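A sketch of deferring a value until execution time (hypothetical ``users`` table clause; execution assumes a bound engine)::

    s = select([users], users.c.id == bindparam('userid'))
    # later, given a bound engine:  s.execute(userid=7)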
def text(text, engine=None, *args, **kwargs):
- """creates literal text to be inserted into a query.
-
- When constructing a query from a select(), update(), insert() or delete(), using
- plain strings for argument values will usually result in text objects being created
- automatically. Use this function when creating textual clauses outside of other
- ClauseElement objects, or optionally wherever plain text is to be used.
-
- Arguments include:
+ """Create literal text to be inserted into a query.
- text - the text of the SQL statement to be created. use : to specify
- bind parameters; they will be compiled to their engine-specific format.
+ When constructing a query from a ``select()``, ``update()``,
+ ``insert()`` or ``delete()``, using plain strings for argument
+ values will usually result in text objects being created
+ automatically. Use this function when creating textual clauses
+ outside of other ``ClauseElement`` objects, or optionally wherever
+ plain text is to be used.
- engine - an optional engine to be used for this text query.
+ Arguments include:
- bindparams - a list of bindparam() instances which can be used to define the
- types and/or initial values for the bind parameters within the textual statement;
- the keynames of the bindparams must match those within the text of the statement.
- The types will be used for pre-processing on bind values.
+ text
+      The text of the SQL statement to be created. Use ``:``
+ to specify bind parameters; they will be compiled to their
+ engine-specific format.
+
+ engine
+ An optional engine to be used for this text query.
+
+ bindparams
+ A list of ``bindparam()`` instances which can be used to define
+ the types and/or initial values for the bind parameters within
+ the textual statement; the keynames of the bindparams must match
+ those within the text of the statement. The types will be used
+ for pre-processing on bind values.
+
+ typemap
+ A dictionary mapping the names of columns represented in the
+ ``SELECT`` clause of the textual statement to type objects,
+ which will be used to perform post-processing on columns within
+ the result set (for textual statements that produce result
+ sets).
+ """
- typemap - a dictionary mapping the names of columns represented in the SELECT
- clause of the textual statement to type objects, which will be used to perform
- post-processing on columns within the result set (for textual statements that
- produce result sets)."""
return _TextClause(text, engine=engine, *args, **kwargs)
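A sketch of a textual statement with a typed bind parameter (the statement, table, and engine are hypothetical)::

    from sqlalchemy import create_engine, Integer

    e = create_engine('sqlite://')           # hypothetical engine
    t = text("SELECT name FROM users WHERE id=:userid", engine=e,
             bindparams=[bindparam('userid', type=Integer)])
    result = t.execute(userid=7)             # assumes a users table exists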
def null():
- """returns a Null object, which compiles to NULL in a sql statement."""
+ """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement."""
+
return _Null()
class _FunctionGateway(object):
- """returns a callable based on an attribute name, which then returns a _Function
- object with that name."""
+ """Return a callable based on an attribute name, which then
+ returns a ``_Function`` object with that name.
+ """
+
def __getattr__(self, name):
if name[-1] == '_':
name = name[0:-1]
@@ -296,45 +424,58 @@ def is_column(col):
class AbstractDialect(object):
- """represents the behavior of a particular database. Used by Compiled objects."""
+ """Represent the behavior of a particular database.
+
+ Used by ``Compiled`` objects."""
pass
-
+
class ClauseParameters(dict):
- """represents a dictionary/iterator of bind parameter key names/values.
-
- Tracks the original BindParam objects as well as the keys/position of each
- parameter, and can return parameters as a dictionary or a list.
- Will process parameter values according to the TypeEngine objects present in
- the BindParams.
+ """Represent a dictionary/iterator of bind parameter key names/values.
+
+ Tracks the original ``BindParam`` objects as well as the
+ keys/position of each parameter, and can return parameters as a
+ dictionary or a list. Will process parameter values according to
+ the ``TypeEngine`` objects present in the ``BindParams``.
"""
+
def __init__(self, dialect, positional=None):
super(ClauseParameters, self).__init__(self)
self.dialect=dialect
self.binds = {}
self.positional = positional or []
+
def set_parameter(self, bindparam, value):
self[bindparam.key] = value
self.binds[bindparam.key] = bindparam
+
def get_original(self, key):
- """returns the given parameter as it was originally placed in this ClauseParameters object, without any Type conversion"""
+ """Return the given parameter as it was originally placed in
+ this ``ClauseParameters`` object, without any ``Type``
+ conversion."""
+
return super(ClauseParameters, self).__getitem__(key)
+
def __getitem__(self, key):
v = super(ClauseParameters, self).__getitem__(key)
if self.binds.has_key(key):
v = self.binds[key].typeprocess(v, self.dialect)
return v
+
def get_original_dict(self):
return self.copy()
+
def get_raw_list(self):
return [self[key] for key in self.positional]
+
def get_raw_dict(self):
d = {}
for k in self:
d[k] = self[k]
return d
-
+
class ClauseVisitor(object):
- """Defines the visiting of ClauseElements."""
+ """Define the visiting of ``ClauseElements``."""
+
def visit_column(self, column):pass
def visit_table(self, column):pass
def visit_fromclause(self, fromclause):pass
@@ -355,110 +496,161 @@ class ClauseVisitor(object):
def visit_typeclause(self, typeclause):pass
class Executor(object):
- """represents a 'thing that can produce Compiled objects and execute them'."""
+ """Represent a *thing that can produce Compiled objects and execute them*."""
+
def execute_compiled(self, compiled, parameters, echo=None, **kwargs):
- """execute a Compiled object."""
+ """Execute a Compiled object."""
+
raise NotImplementedError()
+
def compiler(self, statement, parameters, **kwargs):
- """return a Compiled object for the given statement and parameters."""
+ """Return a Compiled object for the given statement and parameters."""
+
raise NotImplementedError()
-
+
class Compiled(ClauseVisitor):
- """represents a compiled SQL expression. the __str__ method of the Compiled object
- should produce the actual text of the statement. Compiled objects are specific to the
- database library that created them, and also may or may not be specific to the columns
- referenced within a particular set of bind parameters. In no case should the Compiled
- object be dependent on the actual values of those bind parameters, even though it may
- reference those values as defaults."""
+ """Represent a compiled SQL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to the database library that created them, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
def __init__(self, dialect, statement, parameters, engine=None):
- """construct a new Compiled object.
-
- statement - ClauseElement to be compiled
-
- parameters - optional dictionary indicating a set of bind parameters
- specified with this Compiled object. These parameters are the "default"
- values corresponding to the ClauseElement's _BindParamClauses when the Compiled
- is executed. In the case of an INSERT or UPDATE statement, these parameters
- will also result in the creation of new _BindParamClause objects for each key
- and will also affect the generated column list in an INSERT statement and the SET
- clauses of an UPDATE statement. The keys of the parameter dictionary can
- either be the string names of columns or _ColumnClause objects.
-
- engine - optional Engine to compile this statement against"""
+ """Construct a new Compiled object.
+
+ statement
+ ``ClauseElement`` to be compiled.
+
+ parameters
+ Optional dictionary indicating a set of bind parameters
+ specified with this ``Compiled`` object. These parameters
+ are the *default* values corresponding to the
+ ``ClauseElement``'s ``_BindParamClauses`` when the
+ ``Compiled`` is executed. In the case of an ``INSERT`` or
+ ``UPDATE`` statement, these parameters will also result in
+ the creation of new ``_BindParamClause`` objects for each
+ key and will also affect the generated column list in an
+ ``INSERT`` statement and the ``SET`` clauses of an
+ ``UPDATE`` statement. The keys of the parameter dictionary
+ can either be the string names of columns or
+ ``_ColumnClause`` objects.
+
+ engine
+ Optional Engine to compile this statement against.
+ """
+
self.dialect = dialect
self.statement = statement
self.parameters = parameters
self.engine = engine
self.can_execute = statement.supports_execution()
-
+
def compile(self):
self.statement.accept_visitor(self)
self.after_compile()
-
+
def __str__(self):
- """returns the string text of the generated SQL statement."""
+ """Return the string text of the generated SQL statement."""
+
raise NotImplementedError()
+
def get_params(self, **params):
- """returns the bind params for this compiled object.
-
- Will start with the default parameters specified when this Compiled object
- was first constructed, and will override those values with those sent via
- **params, which are key/value pairs. Each key should match one of the
- _BindParamClause objects compiled into this object; either the "key" or
- "shortname" property of the _BindParamClause.
+ """Return the bind params for this compiled object.
+
+ Will start with the default parameters specified when this
+ ``Compiled`` object was first constructed, and will override
+ those values with those sent via `**params`, which are
+ key/value pairs. Each key should match one of the
+ ``_BindParamClause`` objects compiled into this object; either
+ the `key` or `shortname` property of the ``_BindParamClause``.
"""
+
raise NotImplementedError()
def execute(self, *multiparams, **params):
- """execute this compiled object."""
+ """Execute this compiled object."""
+
e = self.engine
if e is None:
raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.")
return e.execute_compiled(self, *multiparams, **params)
def scalar(self, *multiparams, **params):
- """execute this compiled object and return the result's scalar value."""
+ """Execute this compiled object and return the result's scalar value."""
+
return self.execute(*multiparams, **params).scalar()
-
+
class ClauseElement(object):
- """base class for elements of a programmatically constructed SQL expression."""
+ """Base class for elements of a programmatically constructed SQL
+ expression.
+ """
+
def _get_from_objects(self):
- """returns objects represented in this ClauseElement that should be added to the
- FROM list of a query, when this ClauseElement is placed in the column clause of a Select
- statement."""
+ """Return objects represented in this ``ClauseElement`` that
+ should be added to the ``FROM`` list of a query, when this
+ ``ClauseElement`` is placed in the column clause of a
+ ``Select`` statement.
+ """
+
raise NotImplementedError(repr(self))
+
def _hide_froms(self):
- """return a list of FROM clause elements which this ClauseElement replaces."""
+ """Return a list of ``FROM`` clause elements which this
+ ``ClauseElement`` replaces.
+ """
+
return []
+
def compare(self, other):
- """compare this ClauseElement to the given ClauseElement.
-
- Subclasses should override the default behavior, which is a straight
- identity comparison."""
+ """Compare this ClauseElement to the given ClauseElement.
+
+ Subclasses should override the default behavior, which is a
+ straight identity comparison.
+ """
+
return self is other
-
+
def accept_visitor(self, visitor):
- """accept a ClauseVisitor and call the appropriate visit_xxx method."""
+ """Accept a ``ClauseVisitor`` and call the appropriate
+ ``visit_xxx`` method.
+ """
+
raise NotImplementedError(repr(self))
def supports_execution(self):
- """return True if this clause element represents a complete executable statement"""
+ """Return True if this clause element represents a complete
+ executable statement.
+ """
+
return False
-
+
def copy_container(self):
- """return a copy of this ClauseElement, iff this ClauseElement contains other ClauseElements.
-
- If this ClauseElement is not a container, it should return self. This is used to
- create copies of expression trees that still reference the same "leaf nodes". The
- new structure can then be restructured without affecting the original."""
+ """Return a copy of this ``ClauseElement``, if this
+ ``ClauseElement`` contains other ``ClauseElements`.
+
+ If this ``ClauseElement`` is not a container, it should return
+ self. This is used to create copies of expression trees that
+ still reference the same *leaf nodes*. The new structure can
+ then be restructured without affecting the original.
+ """
+
return self
def _find_engine(self):
- """default strategy for locating an engine within the clause element.
- relies upon a local engine property, or looks in the "from" objects which
- ultimately have to contain Tables or TableClauses. """
+ """Default strategy for locating an engine within the clause element.
+
+ Relies upon a local engine property, or looks in the *from*
+ objects which ultimately have to contain Tables or
+ TableClauses.
+ """
+
try:
if self._engine is not None:
return self._engine
@@ -468,15 +660,16 @@ class ClauseElement(object):
if f is self:
continue
engine = f.engine
- if engine is not None:
+ if engine is not None:
return engine
else:
return None
-
- engine = property(lambda s: s._find_engine(), doc="attempts to locate a Engine within this ClauseElement structure, or returns None if none found.")
+
+    engine = property(lambda s: s._find_engine(), doc="Attempts to locate an Engine within this ClauseElement structure, or returns None if none found.")
def execute(self, *multiparams, **params):
- """compile and execute this ClauseElement."""
+ """Compile and execute this ``ClauseElement``."""
+
if len(multiparams):
compile_params = multiparams[0]
else:
@@ -484,27 +677,38 @@ class ClauseElement(object):
return self.compile(engine=self.engine, parameters=compile_params).execute(*multiparams, **params)
def scalar(self, *multiparams, **params):
- """compile and execute this ClauseElement, returning the result's scalar representation."""
+ """Compile and execute this ``ClauseElement``, returning the
+ result's scalar representation.
+ """
+
return self.execute(*multiparams, **params).scalar()
def compile(self, engine=None, parameters=None, compiler=None, dialect=None):
- """compile this SQL expression.
-
- Uses the given Compiler, or the given AbstractDialect or Engine to create a Compiler. If no compiler
- arguments are given, tries to use the underlying Engine this ClauseElement is bound
- to to create a Compiler, if any. Finally, if there is no bound Engine, uses an ANSIDialect
- to create a default Compiler.
-
- bindparams is a dictionary representing the default bind parameters to be used with
- the statement. if the bindparams is a list, it is assumed to be a list of dictionaries
- and the first dictionary in the list is used with which to compile against.
- The bind parameters can in some cases determine the output of the compilation, such as for UPDATE
- and INSERT statements the bind parameters that are present determine the SET and VALUES clause of
- those statements.
+ """Compile this SQL expression.
+
+ Uses the given ``Compiler``, or the given ``AbstractDialect``
+ or ``Engine`` to create a ``Compiler``. If no `compiler`
+        arguments are given, tries to use the underlying ``Engine`` this
+        ``ClauseElement`` is bound to in order to create a ``Compiler``,
+        if any.
+
+ Finally, if there is no bound ``Engine``, uses an
+ ``ANSIDialect`` to create a default ``Compiler``.
+
+ `parameters` is a dictionary representing the default bind
+ parameters to be used with the statement. If `parameters` is
+ a list, it is assumed to be a list of dictionaries and the
+ first dictionary in the list is used with which to compile
+ against.
+
+        The bind parameters can in some cases determine the output of
+        the compilation: for ``UPDATE`` and ``INSERT`` statements, the
+        bind parameters that are present determine the ``SET`` and
+        ``VALUES`` clauses of those statements.
"""
+
if (isinstance(parameters, list) or isinstance(parameters, tuple)):
parameters = parameters[0]
-
+
if compiler is None:
if dialect is not None:
compiler = dialect.compiler(self, parameters)
@@ -512,7 +716,7 @@ class ClauseElement(object):
compiler = engine.compiler(self, parameters)
elif self.engine is not None:
compiler = self.engine.compiler(self, parameters)
-
+
if compiler is None:
import sqlalchemy.ansisql as ansisql
compiler = ansisql.ANSIDialect().compiler(self, parameters=parameters)
@@ -521,32 +725,44 @@ class ClauseElement(object):
def __str__(self):
return str(self.compile())
+
def __and__(self, other):
return and_(self, other)
+
def __or__(self, other):
return or_(self, other)
+
def __invert__(self):
return self._negate()
+
def _negate(self):
self.parens=True
return _BooleanExpression(_TextClause("NOT"), self, None)
class _CompareMixin(object):
- """defines comparison operations for ClauseElements."""
+ """Define comparison operations for ClauseElements."""
+
def __lt__(self, other):
return self._compare('<', other)
+
def __le__(self, other):
return self._compare('<=', other)
+
def __eq__(self, other):
return self._compare('=', other)
+
def __ne__(self, other):
return self._compare('!=', other)
+
def __gt__(self, other):
return self._compare('>', other)
+
def __ge__(self, other):
return self._compare('>=', other)
+
def like(self, other):
return self._compare('LIKE', other)
+
def in_(self, *other):
if len(other) == 0:
return self.__eq__(None)
@@ -556,43 +772,59 @@ class _CompareMixin(object):
return self._compare('IN', ClauseList(parens=True, *[self._bind_param(o) for o in other]), negate='NOT IN')
else:
# assume *other is a single select.
- # originally, this assumed possibly multiple selects and created a UNION,
+ # originally, this assumed possibly multiple selects and created a UNION,
            # but we are now forcing explicitness if a UNION is desired.
if len(other) > 1:
                raise exceptions.InvalidRequestError("in() function accepts only multiple literal values, or a single selectable as an argument")
return self._compare('IN', other[0], negate='NOT IN')
+
def startswith(self, other):
return self._compare('LIKE', other + "%")
+
def endswith(self, other):
return self._compare('LIKE', "%" + other)
+
def label(self, name):
return _Label(name, self, self.type)
+
def distinct(self):
return _CompoundClause(None,"DISTINCT", self)
+
def between(self, cleft, cright):
return _BooleanExpression(self, and_(self._check_literal(cleft), self._check_literal(cright)), 'BETWEEN')
+
def op(self, operator):
return lambda other: self._operate(operator, other)
+
# and here come the math operators:
+
def __add__(self, other):
return self._operate('+', other)
+
def __sub__(self, other):
return self._operate('-', other)
+
def __mul__(self, other):
return self._operate('*', other)
+
def __div__(self, other):
return self._operate('/', other)
+
def __mod__(self, other):
- return self._operate('%', other)
+ return self._operate('%', other)
+
def __truediv__(self, other):
return self._operate('/', other)
+
def _bind_param(self, obj):
return _BindParamClause('literal', obj, shortname=None, type=self.type)
+
def _check_literal(self, other):
if _is_literal(other):
return self._bind_param(other)
else:
return other
+
def _compare(self, operator, obj, negate=None):
if obj is None or isinstance(obj, _Null):
if operator == '=':
@@ -605,44 +837,66 @@ class _CompareMixin(object):
obj = self._check_literal(obj)
return _BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate)
+
def _operate(self, operator, obj):
if _is_literal(obj):
obj = self._bind_param(obj)
return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
+
def _compare_self(self):
- """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to
- just return self"""
+ """Allow ``ColumnImpl`` to return its ``Column`` object for
+ usage in ``ClauseElements``, all others to just return self.
+ """
+
return self
+
def _compare_type(self, obj):
- """allows subclasses to override the type used in constructing _BinaryClause objects. Default return
- value is the type of the given object."""
+ """Allow subclasses to override the type used in constructing
+ ``_BinaryClause`` objects.
+
+ Default return value is the type of the given object.
+ """
+
return obj.type
-
+
class Selectable(ClauseElement):
- """represents a column list-holding object."""
+ """Represent a column list-holding object."""
def _selectable(self):
return self
+
def accept_visitor(self, visitor):
raise NotImplementedError(repr(self))
+
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
+
def _group_parenthesized(self):
- """indicates if this Selectable requires parenthesis when grouped into a compound
- statement"""
+ """Indicate if this ``Selectable`` requires parenthesis when
+ grouped into a compound statement.
+ """
+
return True
class ColumnElement(Selectable, _CompareMixin):
- """represents a column element within the list of a Selectable's columns.
- A ColumnElement can either be directly associated with a TableClause, or
- a free-standing textual column with no table, or is a "proxy" column, indicating
- it is placed on a Selectable such as an Alias or Select statement and ultimately corresponds
- to a TableClause-attached column (or in the case of a CompositeSelect, a proxy ColumnElement
- may correspond to several TableClause-attached columns)."""
-
- primary_key = property(lambda self:getattr(self, '_primary_key', False), doc="primary key flag. indicates if this Column represents part or whole of a primary key.")
- foreign_keys = property(lambda self:getattr(self, '_foreign_keys', []), doc="foreign key accessor. points to a list of ForeignKey objects which represents a Foreign Key placed on this column's ultimate ancestor.")
- columns = property(lambda self:[self], doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.")
+ """Represent a column element within the list of a Selectable's columns.
+
+ A ``ColumnElement`` can either be directly associated with a
+ ``TableClause``, or a free-standing textual column with no table,
+ or is a *proxy* column, indicating it is placed on a
+ ``Selectable`` such as an ``Alias`` or ``Select`` statement and
+ ultimately corresponds to a ``TableClause``-attached column (or in
+ the case of a ``CompositeSelect``, a proxy ``ColumnElement`` may
+ correspond to several ``TableClause``-attached columns).
+ """
+
+ primary_key = property(lambda self:getattr(self, '_primary_key', False),
+ doc="Primary key flag. Indicates if this Column represents part or whole of a primary key.")
+ foreign_keys = property(lambda self:getattr(self, '_foreign_keys', []),
+ doc="Foreign key accessor. Points to a list of ForeignKey objects which represents a Foreign Key placed on this column's ultimate ancestor.")
+ columns = property(lambda self:[self],
+ doc="Columns accessor which just returns self, to provide compatibility with Selectable objects.")
+
def _one_fkey(self):
if len(self._foreign_keys):
return list(self._foreign_keys)[0]
@@ -656,23 +910,32 @@ class ColumnElement(Selectable, _CompareMixin):
except AttributeError:
self.__orig_set = util.Set([self])
return self.__orig_set
+
def _set_orig_set(self, s):
if len(s) == 0:
s.add(self)
self.__orig_set = s
- orig_set = property(_get_orig_set, _set_orig_set,doc="""a Set containing TableClause-bound, non-proxied ColumnElements for which this ColumnElement is a proxy. In all cases except for a column proxied from a Union (i.e. CompoundSelect), this set will be just one element.""")
+ orig_set = property(_get_orig_set, _set_orig_set,
+ doc="A Set containing TableClause-bound, non-proxied ColumnElements for which this ColumnElement is a proxy. In all cases except for a column proxied from a Union (i.e. CompoundSelect), this set will be just one element.")
def shares_lineage(self, othercolumn):
- """returns True if the given ColumnElement has a common ancestor to this ColumnElement."""
+ """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``."""
+
for c in self.orig_set:
if c in othercolumn.orig_set:
return True
else:
return False
+
def _make_proxy(self, selectable, name=None):
- """creates a new ColumnElement representing this ColumnElement as it appears in the select list
- of a descending selectable. The default implementation returns a _ColumnClause if a name is given,
- else just returns self."""
+ """Create a new ``ColumnElement`` representing this
+ ``ColumnElement`` as it appears in the select list of a
+ descending selectable.
+
+ The default implementation returns a ``_ColumnClause`` if a
+ name is given, else just returns self.
+ """
+
if name is not None:
co = _ColumnClause(name, selectable)
co.orig_set = self.orig_set
@@ -682,19 +945,26 @@ class ColumnElement(Selectable, _CompareMixin):
return self
class ColumnCollection(util.OrderedProperties):
- """an ordered dictionary that stores a list of ColumnElement instances.
-
- overrides the __eq__() method to produce SQL clauses between sets of
- correlated columns."""
+ """An ordered dictionary that stores a list of ColumnElement
+ instances.
+
+ Overrides the ``__eq__()`` method to produce SQL clauses between
+ sets of correlated columns.
+ """
+
def __init__(self, *cols):
super(ColumnCollection, self).__init__()
[self.add(c) for c in cols]
+
def add(self, column):
- """add a column to this collection.
-
- the key attribute of the column will be used as the hash key for this
- dictionary."""
+ """Add a column to this collection.
+
+ The key attribute of the column will be used as the hash key
+ for this dictionary.
+ """
+
self[column.key] = column
+
def __eq__(self, other):
l = []
for c in other:
@@ -702,48 +972,70 @@ class ColumnCollection(util.OrderedProperties):
if c.shares_lineage(local):
l.append(c==local)
return and_(*l)
+
def contains_column(self, col):
- # have to use a Set here, because it will compare the identity
+ # have to use a Set here, because it will compare the identity
# of the column, not just using "==" for comparison which will always return a
# "True" value (i.e. a BinaryClause...)
return col in util.Set(self)
-
+
class FromClause(Selectable):
- """represents an element that can be used within the FROM clause of a SELECT statement."""
+ """Represent an element that can be used within the ``FROM``
+ clause of a ``SELECT`` statement.
+ """
+
def __init__(self, name=None):
self.name = name
+
def _get_from_objects(self):
# this could also be [self], at the moment it doesnt matter to the Select object
return []
+
def default_order_by(self):
return [self.oid_column]
- def accept_visitor(self, visitor):
+
+ def accept_visitor(self, visitor):
visitor.visit_fromclause(self)
+
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
else:
col = list(self.columns)[0]
return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
+
def join(self, right, *args, **kwargs):
return Join(self, right, *args, **kwargs)
+
def outerjoin(self, right, *args, **kwargs):
return Join(self, right, isouter=True, *args, **kwargs)
+
def alias(self, name=None):
return Alias(self, name)
+
def named_with_column(self):
- """True if the name of this FromClause may be prepended to a column in a generated SQL statement"""
+ """True if the name of this FromClause may be prepended to a
+ column in a generated SQL statement.
+ """
+
return False
+
def _locate_oid_column(self):
- """subclasses override this to return an appropriate OID column"""
+ """Subclasses should override this to return an appropriate OID column."""
+
return None
+
def _get_oid_column(self):
if not hasattr(self, '_oid_column'):
self._oid_column = self._locate_oid_column()
return self._oid_column
+
def corresponding_column(self, column, raiseerr=True, keys_ok=False, require_exact=False):
- """given a ColumnElement, return the ColumnElement object from this
- Selectable which corresponds to that original Column via a proxy relationship."""
+ """Given a ``ColumnElement``, return the ``ColumnElement``
+ object from this ``Selectable`` which corresponds to that
+ original ``Column`` via a proxy relationship.
+ """
+
if require_exact:
if self.columns.get(column.name) is column:
return column
@@ -767,29 +1059,34 @@ class FromClause(Selectable):
return None
else:
raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(column.table), self.name))
-
+
def _get_exported_attribute(self, name):
try:
return getattr(self, name)
except AttributeError:
self._export_columns()
return getattr(self, name)
+
columns = property(lambda s:s._get_exported_attribute('_columns'))
c = property(lambda s:s._get_exported_attribute('_columns'))
primary_key = property(lambda s:s._get_exported_attribute('_primary_key'))
foreign_keys = property(lambda s:s._get_exported_attribute('_foreign_keys'))
- original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'), doc="a dictionary mapping an original Table-bound column to a proxied column in this FromClause.")
+ original_columns = property(lambda s:s._get_exported_attribute('_orig_cols'), doc="A dictionary mapping an original Table-bound column to a proxied column in this FromClause.")
oid_column = property(_get_oid_column)
-
+
def _export_columns(self):
- """initialize column collections.
-
- the collections include the primary key, foreign keys, list of all columns, as well as
- the "_orig_cols" collection which is a dictionary used to match Table-bound columns
- to proxied columns in this FromClause. The columns in each collection are "proxied" from
- the columns returned by the _exportable_columns method, where a "proxied" column maintains
- most or all of the properties of its original column, except its parent Selectable is this FromClause.
+ """Initialize column collections.
+
+ The collections include the primary key, foreign keys, list of
+ all columns, as well as the *_orig_cols* collection which is a
+ dictionary used to match Table-bound columns to proxied
+ columns in this ``FromClause``. The columns in each
+ collection are *proxied* from the columns returned by the
+ _exportable_columns method, where a *proxied* column maintains
+ most or all of the properties of its original column, except
+ its parent ``Selectable`` is this ``FromClause``.
"""
+
if hasattr(self, '_columns'):
# TODO: put a mutex here ? this is a key place for threading probs
return
@@ -810,68 +1107,103 @@ class FromClause(Selectable):
if self.oid_column is not None:
for ci in self.oid_column.orig_set:
self._orig_cols[ci] = self.oid_column
+
def _exportable_columns(self):
return []
+
def _proxy_column(self, column):
return column._make_proxy(self)
-
+
class _BindParamClause(ClauseElement, _CompareMixin):
- """represents a bind parameter. public constructor is the bindparam() function."""
+ """Represent a bind parameter.
+
+ Public constructor is the ``bindparam()`` function.
+ """
+
def __init__(self, key, value, shortname=None, type=None):
- """construct a _BindParamClause.
-
- key - the key for this bind param. will be used in the generated SQL statement
- for dialects that use named parameters. this value may be modified when part of a
- compilation operation, if other _BindParamClause objects exist with the same key, or if
- its length is too long and truncation is required.
-
- value - initial value for this bind param. This value may be overridden by the
- dictionary of parameters sent to statement compilation/execution.
-
- shortname - defaults to the key, a 'short name' that will also identify this
- bind parameter, similar to an alias. the bind parameter keys sent to a statement
- compilation or compiled execution may match either the key or the shortname of the
- corresponding _BindParamClause objects.
-
- type - a TypeEngine object that will be used to pre-process the value corresponding
- to this _BindParamClause at execution time."""
+ """Construct a _BindParamClause.
+
+ key
+ the key for this bind param. Will be used in the generated
+ SQL statement for dialects that use named parameters. This
+ value may be modified when part of a compilation operation,
+ if other ``_BindParamClause`` objects exist with the same
+ key, or if its length is too long and truncation is
+ required.
+
+ value
+ Initial value for this bind param. This value may be
+ overridden by the dictionary of parameters sent to statement
+ compilation/execution.
+
+ shortname
+ Defaults to the key, a *short name* that will also identify
+        this bind parameter, similar to an alias. The bind
+ parameter keys sent to a statement compilation or compiled
+ execution may match either the key or the shortname of the
+ corresponding ``_BindParamClause`` objects.
+
+      type
+        A ``TypeEngine`` object that will be used to pre-process the
+ value corresponding to this ``_BindParamClause`` at
+ execution time.
+ """
+
self.key = key
self.value = value
self.shortname = shortname or key
self.type = sqltypes.to_instance(type)
+
def accept_visitor(self, visitor):
visitor.visit_bindparam(self)
+
def _get_from_objects(self):
return []
+
def copy_container(self):
return _BindParamClause(self.key, self.value, self.shortname, self.type)
+
def typeprocess(self, value, dialect):
return self.type.dialect_impl(dialect).convert_bind_param(value, dialect)
+
def compare(self, other):
- """compares this _BindParamClause to the given clause.
-
- Since compare() is meant to compare statement syntax, this method
- returns True if the two _BindParamClauses have just the same type."""
+ """Compare this ``_BindParamClause`` to the given clause.
+
+        Since ``compare()`` is meant to compare statement syntax, this
+        method returns True as long as the two ``_BindParamClauses``
+        have the same type.
+ """
+
return isinstance(other, _BindParamClause) and other.type.__class__ == self.type.__class__
+
def _make_proxy(self, selectable, name = None):
return self
+
def __repr__(self):
return "_BindParamClause(%s, %s, type=%s)" % (repr(self.key), repr(self.value), repr(self.type))
-
+
class _TypeClause(ClauseElement):
- """handles a type keyword in a SQL statement. used by the Case statement."""
+ """Handle a type keyword in a SQL statement.
+
+ Used by the ``Case`` statement.
+ """
+
def __init__(self, type):
self.type = type
+
def accept_visitor(self, visitor):
visitor.visit_typeclause(self)
- def _get_from_objects(self):
+
+ def _get_from_objects(self):
return []
class _TextClause(ClauseElement):
- """represents literal a SQL text fragment. public constructor is the
- text() function.
-
+ """Represent a literal SQL text fragment.
+
+ Public constructor is the ``text()`` function.
"""
+
def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
self.parens = False
self._engine = engine
@@ -949,34 +1281,45 @@ class ClauseList(ClauseElement):
return False
class _CompoundClause(ClauseList):
- """represents a list of clauses joined by an operator, such as AND or OR.
- extends ClauseList to add the operator as well as a from_objects accessor to
- help determine FROM objects in a SELECT statement."""
+ """Represent a list of clauses joined by an operator, such as ``AND`` or ``OR``.
+
+ Extends ``ClauseList`` to add the operator as well as a
+ `from_objects` accessor to help determine ``FROM`` objects in a
+ ``SELECT`` statement.
+ """
+
def __init__(self, operator, *clauses, **kwargs):
ClauseList.__init__(self, *clauses, **kwargs)
self.operator = operator
+
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return _CompoundClause(self.operator, *clauses)
+
def append(self, clause):
if isinstance(clause, _CompoundClause):
clause.parens = True
ClauseList.append(self, clause)
+
def accept_visitor(self, visitor):
for c in self.clauses:
c.accept_visitor(visitor)
visitor.visit_compound(self)
+
def _get_from_objects(self):
f = []
for c in self.clauses:
f += c._get_from_objects()
return f
+
def compare(self, other):
- """compares this _CompoundClause to the given item.
-
- In addition to the regular comparison, has the special case that it
- returns True if this _CompoundClause has only one item, and that
- item matches the given item."""
+ """Compare this ``_CompoundClause`` to the given item.
+
+ In addition to the regular comparison, has the special case
+ that it returns True if this ``_CompoundClause`` has only one
+ item, and that item matches the given item.
+ """
+
if not isinstance(other, _CompoundClause):
if len(self.clauses) == 1:
return self.clauses[0].compare(other)
@@ -986,43 +1329,60 @@ class _CompoundClause(ClauseList):
return False
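# A hedged sketch of the single-item special case described in compare():
b = _BindParamClause('x', 1)
c = _CompoundClause('AND', b)
assert c.compare(b)     # delegates to the lone contained clause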
class _CalculatedClause(ClauseList, ColumnElement):
- """ describes a calculated SQL expression that has a type, like CASE. extends ColumnElement to
- provide column-level comparison operators. """
+ """Describe a calculated SQL expression that has a type, like ``CASE``.
+
+ Extends ``ColumnElement`` to provide column-level comparison
+ operators.
+ """
+
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type', None))
self._engine = kwargs.get('engine', None)
ClauseList.__init__(self, *clauses)
+
key = property(lambda self:self.name or "_calc_")
+
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return _CalculatedClause(type=self.type, engine=self._engine, *clauses)
+
def accept_visitor(self, visitor):
for c in self.clauses:
c.accept_visitor(visitor)
visitor.visit_calculatedclause(self)
+
def _bind_param(self, obj):
return _BindParamClause(self.name, obj, type=self.type)
+
def select(self):
return select([self])
+
def scalar(self):
return select([self]).scalar()
+
def execute(self):
return select([self]).execute()
+
def _compare_type(self, obj):
return self.type
-
class _Function(_CalculatedClause, FromClause):
- """describes a SQL function. extends _CalculatedClause turn the "clauselist" into function
- arguments, also adds a "packagenames" argument"""
+ """Describe a SQL function.
+
+    Extends ``_CalculatedClause`` to turn the *clauselist* into function
+    arguments, and also adds a `packagenames` argument.
+ """
+
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
self._engine = kwargs.get('engine', None)
ClauseList.__init__(self, parens=True, *clauses)
+
key = property(lambda self:self.name)
+
def append(self, clause):
if _is_literal(clause):
if clause is None:
@@ -1030,9 +1390,11 @@ class _Function(_CalculatedClause, FromClause):
else:
clause = _BindParamClause(self.name, clause, shortname=self.name, type=None)
self.clauses.append(clause)
+
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
+
def accept_visitor(self, visitor):
for c in self.clauses:
c.accept_visitor(visitor)
@@ -1045,12 +1407,15 @@ class _Cast(ColumnElement):
self.type = sqltypes.to_instance(totype)
self.clause = clause
self.typeclause = _TypeClause(self.type)
+
def accept_visitor(self, visitor):
self.clause.accept_visitor(visitor)
self.typeclause.accept_visitor(visitor)
visitor.visit_cast(self)
+
def _get_from_objects(self):
return self.clause._get_from_objects()
+
def _make_proxy(self, selectable, name=None):
if name is not None:
co = _ColumnClause(name, selectable, type=self.type)
@@ -1059,21 +1424,25 @@ class _Cast(ColumnElement):
return co
else:
return self
-
+
class _FunctionGenerator(object):
- """generates _Function objects based on getattr calls"""
+ """Generate ``_Function`` objects based on getattr calls."""
+
def __init__(self, engine=None):
self.__engine = engine
self.__names = []
+
def __getattr__(self, name):
self.__names.append(name)
return self
+
def __call__(self, *c, **kwargs):
kwargs.setdefault('engine', self.__engine)
- return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **kwargs)
-
+ return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **kwargs)
+
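# A hedged sketch of the getattr-driven generation described above;
# dotted access accumulates names, the trailing call builds the function:
fgen = _FunctionGenerator()
fn = fgen.mypkg.myfunc(5)            # names are purely illustrative
assert fn.name == 'myfunc' and fn.packagenames == ['mypkg']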
class _BinaryClause(ClauseElement):
- """represents two clauses with an operator in between"""
+ """Represent two clauses with an operator in between."""
+
def __init__(self, left, right, operator, type=None):
self.left = left
self.right = right
@@ -1139,11 +1508,13 @@ class Join(FromClause):
self.isouter = isouter
name = property(lambda s: "Join object on " + s.left.name + " " + s.right.name)
+
def _locate_oid_column(self):
return self.left.oid_column
-
+
def _exportable_columns(self):
return [c for c in self.left.columns] + [c for c in self.right.columns]
+
def _proxy_column(self, column):
self._columns[column._label] = column
if column.primary_key:
@@ -1151,6 +1522,7 @@ class Join(FromClause):
for f in column.foreign_keys:
self._foreign_keys.add(f)
return column
+
def _match_primaries(self, primary, secondary):
crit = []
constraints = util.Set()
@@ -1173,11 +1545,13 @@ class Join(FromClause):
return (crit[0])
else:
return and_(*crit)
-
+
def _group_parenthesized(self):
return True
+
def select(self, whereclauses = None, **params):
return select([self.left, self.right], whereclauses, from_obj=[self], **params)
+
def accept_visitor(self, visitor):
self.left.accept_visitor(visitor)
self.right.accept_visitor(visitor)
@@ -1187,15 +1561,19 @@ class Join(FromClause):
engine = property(lambda s:s.left.engine or s.right.engine)
def alias(self, name=None):
- """creates a Select out of this Join clause and returns an Alias of it. The Select is not correlating."""
- return self.select(use_labels=True, correlate=False).alias(name)
+ """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it.
+
+ The ``Select`` is not correlating.
+ """
+
+ return self.select(use_labels=True, correlate=False).alias(name)
def _hide_froms(self):
return self.left._get_from_objects() + self.right._get_from_objects()
-
+
def _get_from_objects(self):
return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
-
+
class Alias(FromClause):
def __init__(self, selectable, alias = None):
baseselectable = selectable
@@ -1213,8 +1591,10 @@ class Alias(FromClause):
alias = alias + "_" + hex(random.randint(0, 65535))[2:]
self.name = alias
self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
+
def supports_execution(self):
- return self.original.supports_execution()
+ return self.original.supports_execution()
+
def _locate_oid_column(self):
if self.selectable.oid_column is not None:
return self.selectable.oid_column._make_proxy(self)
@@ -1223,6 +1603,7 @@ class Alias(FromClause):
def named_with_column(self):
return True
+
def _exportable_columns(self):
#return self.selectable._exportable_columns()
return self.selectable.columns
@@ -1236,10 +1617,9 @@ class Alias(FromClause):
def _group_parenthesized(self):
return False
-
+
engine = property(lambda s: s.selectable.engine)
-
class _Label(ColumnElement):
def __init__(self, name, obj, type=None):
self.name = name
@@ -1249,21 +1629,29 @@ class _Label(ColumnElement):
self.case_sensitive = getattr(obj, "case_sensitive", True)
self.type = sqltypes.to_instance(type)
obj.parens=True
+
key = property(lambda s: s.name)
_label = property(lambda s: s.name)
orig_set = property(lambda s:s.obj.orig_set)
+
def accept_visitor(self, visitor):
self.obj.accept_visitor(visitor)
visitor.visit_label(self)
+
def _get_from_objects(self):
return self.obj._get_from_objects()
+
def _make_proxy(self, selectable, name = None):
return self.obj._make_proxy(selectable, name=self.name)
-legal_characters = util.Set(string.ascii_letters + string.digits + '_')
+legal_characters = util.Set(string.ascii_letters + string.digits + '_')
+
class _ColumnClause(ColumnElement):
- """represents a textual column clause in a SQL statement. May or may not
- be bound to an underlying Selectable."""
+ """Represent a textual column clause in a SQL statement.
+
+ May or may not be bound to an underlying ``Selectable``.
+ """
+
def __init__(self, text, selectable=None, type=None, _is_oid=False, case_sensitive=True, is_literal=False):
self.key = self.name = text
self.table = selectable
@@ -1272,6 +1660,7 @@ class _ColumnClause(ColumnElement):
self.__label = None
self.case_sensitive = case_sensitive
self.is_literal = is_literal
+
def _get_label(self):
if self.__label is None:
if self.table is not None and self.table.named_with_column():
@@ -1282,30 +1671,41 @@ class _ColumnClause(ColumnElement):
self.__label = self.name
self.__label = "".join([x for x in self.__label if x in legal_characters])
return self.__label
+
_label = property(_get_label)
- def accept_visitor(self, visitor):
+
+ def accept_visitor(self, visitor):
visitor.visit_column(self)
+
def to_selectable(self, selectable):
- """given a Selectable, returns this column's equivalent in that Selectable, if any.
-
- for example, this could translate the column "name" from a Table object
- to an Alias of a Select off of that Table object."""
+ """Given a ``Selectable``, return this column's equivalent in
+ that ``Selectable``, if any.
+
+        For example, this could translate the column *name* from a
+        ``Table`` object to an ``Alias`` of a ``Select`` off of that
+        ``Table`` object.
+        """
+
return selectable.corresponding_column(self.original, False)
+
def _get_from_objects(self):
if self.table is not None:
return [self.table]
else:
return []
+
def _bind_param(self, obj):
return _BindParamClause(self._label, obj, shortname = self.name, type=self.type)
+
def _make_proxy(self, selectable, name = None):
c = _ColumnClause(name or self.name, selectable, _is_oid=self._is_oid, type=self.type)
c.orig_set = self.orig_set
if not self._is_oid:
selectable.columns[c.name] = c
return c
+
def _compare_type(self, obj):
return self.type
+
def _group_parenthesized(self):
return False
@@ -1322,11 +1722,14 @@ class TableClause(FromClause):
def named_with_column(self):
return True
+
def append_column(self, c):
self._columns[c.name] = c
c.table = self
+
def _locate_oid_column(self):
return self._oid_column
+
def _orig_columns(self):
try:
return self._orig_cols
@@ -1336,41 +1739,55 @@ class TableClause(FromClause):
for ci in c.orig_set:
self._orig_cols[ci] = c
return self._orig_cols
+
original_columns = property(_orig_columns)
def accept_visitor(self, visitor):
visitor.visit_table(self)
+
def _exportable_columns(self):
raise NotImplementedError()
+
def _group_parenthesized(self):
return False
+
def count(self, whereclause=None, **params):
if len(self.primary_key):
col = list(self.primary_key)[0]
else:
col = list(self.columns)[0]
return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
+
def join(self, right, *args, **kwargs):
return Join(self, right, *args, **kwargs)
+
def outerjoin(self, right, *args, **kwargs):
return Join(self, right, isouter = True, *args, **kwargs)
+
def alias(self, name=None):
return Alias(self, name)
+
def select(self, whereclause = None, **params):
return select([self], whereclause, **params)
+
def insert(self, values = None):
return insert(self, values=values)
+
def update(self, whereclause = None, values = None):
return update(self, whereclause, values)
+
def delete(self, whereclause = None):
return delete(self, whereclause)
+
def _get_from_objects(self):
return [self]
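# A hedged usage sketch for the convenience methods above, assuming a
# hypothetical ``users`` table with columns ``id`` and ``name``:
#   users.count()                  -> SELECT count(users.id) AS tbl_row_count
#   users.select(users.c.id == 7)  -> SELECT ... FROM users WHERE users.id = :users_id
#   users.update(users.c.id == 7, values={'name': 'ed'})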
class _SelectBaseMixin(object):
- """base class for Select and CompoundSelects"""
+    """Base class for ``Select`` and ``CompoundSelect``."""
+
def supports_execution(self):
return True
+
def order_by(self, *clauses):
if len(clauses) == 1 and clauses[0] is None:
self.order_by_clause = ClauseList()
@@ -1378,6 +1795,7 @@ class _SelectBaseMixin(object):
self.order_by_clause = ClauseList(*(list(self.order_by_clause.clauses) + list(clauses)))
else:
self.order_by_clause = ClauseList(*clauses)
+
def group_by(self, *clauses):
if len(clauses) == 1 and clauses[0] is None:
self.group_by_clause = ClauseList()
@@ -1385,14 +1803,16 @@ class _SelectBaseMixin(object):
self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses)))
else:
self.group_by_clause = ClauseList(*clauses)
+
def select(self, whereclauses = None, **params):
return select([self], whereclauses, **params)
+
def _get_from_objects(self):
if self.is_where or self.is_scalar:
return []
else:
return [self]
-
+
class CompoundSelect(_SelectBaseMixin, FromClause):
def __init__(self, keyword, *selects, **kwargs):
_SelectBaseMixin.__init__(self)
@@ -1414,7 +1834,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
for s in selects:
s.group_by(None)
s.order_by(None)
-
+
self.group_by(*kwargs.pop('group_by', [None]))
self.order_by(*kwargs.pop('order_by', [None]))
if len(kwargs):
@@ -1422,13 +1842,15 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
self._col_map = {}
name = property(lambda s:s.keyword + " statement")
-
+
def _locate_oid_column(self):
return self.selects[0].oid_column
+
def _exportable_columns(self):
for s in self.selects:
for c in s.c:
yield c
+
def _proxy_column(self, column):
if self.use_labels:
col = column._make_proxy(self, name=column._label)
@@ -1442,13 +1864,14 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
[colset.add(c) for c in col.orig_set]
col.orig_set = colset
return col
-
+
def accept_visitor(self, visitor):
self.order_by_clause.accept_visitor(visitor)
self.group_by_clause.accept_visitor(visitor)
for s in self.selects:
s.accept_visitor(visitor)
visitor.visit_compound_select(self)
+
def _find_engine(self):
for s in self.selects:
e = s._find_engine()
@@ -1456,10 +1879,12 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
return e
else:
return None
-
+
class Select(_SelectBaseMixin, FromClause):
- """represents a SELECT statement, with appendable clauses, as well as
- the ability to execute itself and return a result set."""
+ """Represent a ``SELECT`` statement, with appendable clauses, as
+ well as the ability to execute itself and return a result set.
+ """
+
def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, for_update=False, engine=None, limit=None, offset=None, scalar=False, correlate=True):
_SelectBaseMixin.__init__(self)
self.__froms = util.OrderedSet()
@@ -1482,14 +1907,14 @@ class Select(_SelectBaseMixin, FromClause):
# its FROM clause to that of an enclosing select statement.
# note that the "correlate" method can be used to explicitly add a value to be correlated.
self.should_correlate = correlate
-
+
# indicates if this select statement is a subquery inside another query
self.is_subquery = False
-
+
# indicates if this select statement is a subquery as a criterion
# inside of a WHERE clause
self.is_where = False
-
+
self.distinct = distinct
self._raw_columns = []
self.__correlated = {}
@@ -1518,20 +1943,27 @@ class Select(_SelectBaseMixin, FromClause):
self.append_whereclause(whereclause)
if having is not None:
self.append_having(having)
-
-
+
+
class _CorrelatedVisitor(ClauseVisitor):
- """visits a clause, locates any Select clauses, and tells them that they should
- correlate their FROM list to that of their parent."""
+ """Visit a clause, locate any ``Select`` clauses, and tell
+ them that they should correlate their ``FROM`` list to that of
+ their parent.
+ """
+
def __init__(self, select, is_where):
self.select = select
self.is_where = is_where
+
def visit_compound_select(self, cs):
self.visit_select(cs)
for s in cs.selects:
s.parens = False
+
def visit_column(self, c):pass
+
def visit_table(self, c):pass
+
def visit_select(self, select):
if select is self.select:
return
@@ -1541,7 +1973,7 @@ class Select(_SelectBaseMixin, FromClause):
if not select.should_correlate:
return
[select.correlate(x) for x in self.select._Select__froms]
-
+
def append_column(self, column):
if _is_literal(column):
column = literal_column(str(column), table=self)
@@ -1586,11 +2018,13 @@ class Select(_SelectBaseMixin, FromClause):
self.__froms.add(elem)
for f in elem._hide_froms():
self.__hide_froms.add(f)
-
+
def append_whereclause(self, whereclause):
self._append_condition('whereclause', whereclause)
+
def append_having(self, having):
self._append_condition('having', having)
+
def _append_condition(self, attribute, condition):
if type(condition) == str:
condition = _TextClause(condition)
@@ -1600,25 +2034,27 @@ class Select(_SelectBaseMixin, FromClause):
setattr(self, attribute, and_(getattr(self, attribute), condition))
else:
setattr(self, attribute, condition)
-
+
def correlate(self, from_obj):
- """given a FROM object, correlate this SELECT statement to it.
-
- this basically means the given from object will not come out in this select statement's FROM
- clause when printed."""
+ """Given a ``FROM`` object, correlate this ``SELECT`` statement to it.
+
+        This basically means the given from object will not be rendered
+        in this select statement's ``FROM`` clause when printed.
+ """
+
self.__correlated[from_obj] = from_obj
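# A hedged sketch of correlation, assuming Table objects t1 and t2:
#   inner = select([t2.c.id], t2.c.a == t1.c.a)   # used as a subquery
#   inner.correlate(t1)    # t1 no longer appears in inner's FROM clause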
-
+
def append_from(self, fromclause):
if type(fromclause) == str:
fromclause = FromClause(fromclause)
fromclause.accept_visitor(self.__correlator)
self._process_froms(fromclause, True)
-
+
def _locate_oid_column(self):
for f in self.__froms:
if f is self:
# we might be in our own _froms list if a column with us as the parent is attached,
- # which includes textual columns.
+ # which includes textual columns.
continue
oid = f.oid_column
if oid is not None:
@@ -1632,7 +2068,8 @@ class Select(_SelectBaseMixin, FromClause):
return f.difference(self.__correlated)
else:
return f
- froms = property(_calc_froms, doc="""a collection containing all elements of the FROM clause""")
+
+    froms = property(_calc_froms, doc="""A collection containing all elements of the FROM clause.""")
def accept_visitor(self, visitor):
for f in self.froms:
@@ -1644,22 +2081,25 @@ class Select(_SelectBaseMixin, FromClause):
self.order_by_clause.accept_visitor(visitor)
self.group_by_clause.accept_visitor(visitor)
visitor.visit_select(self)
-
+
def union(self, other, **kwargs):
return union(self, other, **kwargs)
+
def union_all(self, other, **kwargs):
return union_all(self, other, **kwargs)
+
def _find_engine(self):
- """tries to return a Engine, either explicitly set in this object, or searched
- within the from clauses for one"""
-
+        """Try to return an ``Engine``, either explicitly set in this
+        object, or searched within the from clauses for one.
+ """
+
if self._engine is not None:
return self._engine
for f in self.__froms:
if f is self:
continue
e = f.engine
- if e is not None:
+ if e is not None:
self._engine = e
return e
        # look through the columns (largely synonymous with looking
@@ -1675,12 +2115,16 @@ class Select(_SelectBaseMixin, FromClause):
return None
class _UpdateBase(ClauseElement):
- """forms the base for INSERT, UPDATE, and DELETE statements."""
+ """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
+
def supports_execution(self):
return True
+
def _process_colparams(self, parameters):
- """receives the "values" of an INSERT or UPDATE statement and constructs
- appropriate bind parameters."""
+ """Receive the *values* of an ``INSERT`` or ``UPDATE``
+ statement and construct appropriate bind parameters.
+ """
+
if parameters is None:
return None
@@ -1691,7 +2135,7 @@ class _UpdateBase(ClauseElement):
pp[c.key] = parameters[i]
i +=1
parameters = pp
-
+
for key in parameters.keys():
value = parameters[key]
if isinstance(value, Select):
@@ -1706,6 +2150,7 @@ class _UpdateBase(ClauseElement):
except KeyError:
del parameters[key]
return parameters
+
def _find_engine(self):
return self.table.engine
@@ -1714,7 +2159,7 @@ class _Insert(_UpdateBase):
self.table = table
self.select = None
self.parameters = self._process_colparams(values)
-
+
def accept_visitor(self, visitor):
if self.select is not None:
self.select.accept_visitor(visitor)
@@ -1741,4 +2186,3 @@ class _Delete(_UpdateBase):
if self.whereclause is not None:
self.whereclause.accept_visitor(visitor)
visitor.visit_delete(self)
-
diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py
index e672feb1a9..db3590cd8b 100644
--- a/lib/sqlalchemy/sql_util.py
+++ b/lib/sqlalchemy/sql_util.py
@@ -1,25 +1,31 @@
from sqlalchemy import sql, util, schema, topological
-"""utility functions that build upon SQL and Schema constructs"""
-
+"""Utility functions that build upon SQL and Schema constructs."""
class TableCollection(object):
def __init__(self, tables=None):
self.tables = tables or []
+
def __len__(self):
return len(self.tables)
+
def __getitem__(self, i):
return self.tables[i]
+
def __iter__(self):
return iter(self.tables)
+
def __contains__(self, obj):
return obj in self.tables
+
def __add__(self, obj):
return self.tables + list(obj)
+
def add(self, table):
self.tables.append(table)
if hasattr(self, '_sorted'):
del self._sorted
+
def sort(self, reverse=False):
try:
sorted = self._sorted
@@ -32,7 +38,7 @@ class TableCollection(object):
return x
else:
return sorted
-
+
def _do_sort(self):
tuples = []
class TVisitor(schema.SchemaVisitor):
@@ -43,7 +49,7 @@ class TableCollection(object):
if parent_table in self:
child_table = fkey.parent.table
tuples.append( ( parent_table, child_table ) )
- vis = TVisitor()
+ vis = TVisitor()
for table in self.tables:
table.accept_schema_visitor(vis)
sorter = topological.QueueDependencySorter( tuples, self.tables )
@@ -56,17 +62,20 @@ class TableCollection(object):
if head is not None:
to_sequence( head )
return sequence
-
+
class TableFinder(TableCollection, sql.ClauseVisitor):
- """given a Clause, locates all the Tables within it into a list."""
+    """Given a ``Clause``, locate all the ``Tables`` within it and collect them into a list."""
+
def __init__(self, table, check_columns=False):
TableCollection.__init__(self)
self.check_columns = check_columns
if table is not None:
table.accept_visitor(self)
+
def visit_table(self, table):
self.tables.append(table)
+
def visit_column(self, column):
if self.check_columns:
column.table.accept_visitor(self)
@@ -74,48 +83,66 @@ class TableFinder(TableCollection, sql.ClauseVisitor):
class ColumnFinder(sql.ClauseVisitor):
def __init__(self):
self.columns = util.Set()
+
def visit_column(self, c):
self.columns.add(c)
+
def __iter__(self):
return iter(self.columns)
class ColumnsInClause(sql.ClauseVisitor):
- """given a selectable, visits clauses and determines if any columns from the clause are in the selectable"""
+ """Given a selectable, visit clauses and determine if any columns
+ from the clause are in the selectable.
+ """
+
def __init__(self, selectable):
self.selectable = selectable
self.result = False
+
def visit_column(self, column):
if self.selectable.c.get(column.key) is column:
self.result = True
class AbstractClauseProcessor(sql.ClauseVisitor):
- """traverses a clause and attempts to convert the contents of container elements
- to a converted element. the conversion operation is defined by subclasses."""
+    """Traverse a clause and attempt to replace the contents of container
+    elements with converted elements.
+
+ The conversion operation is defined by subclasses.
+ """
+
def convert_element(self, elem):
- """define the 'conversion' method for this AbstractClauseProcessor"""
+ """Define the *conversion* method for this ``AbstractClauseProcessor``."""
+
raise NotImplementedError()
+
def copy_and_process(self, list_):
- """copy the container elements in the given list to a new list and
- process the new list."""
+ """Copy the container elements in the given list to a new list and
+ process the new list.
+ """
+
list_ = [o.copy_container() for o in list_]
self.process_list(list_)
return list_
def process_list(self, list_):
- """process all elements of the given list in-place"""
+ """Process all elements of the given list in-place."""
+
for i in range(0, len(list_)):
elem = self.convert_element(list_[i])
if elem is not None:
list_[i] = elem
else:
list_[i].accept_visitor(self)
+
def visit_compound(self, compound):
self.visit_clauselist(compound)
+
def visit_clauselist(self, clist):
for i in range(0, len(clist.clauses)):
n = self.convert_element(clist.clauses[i])
if n is not None:
clist.clauses[i] = n
+
def visit_binary(self, binary):
elem = self.convert_element(binary.left)
if elem is not None:
@@ -123,9 +150,10 @@ class AbstractClauseProcessor(sql.ClauseVisitor):
elem = self.convert_element(binary.right)
if elem is not None:
binary.right = elem
-
+
class Aliasizer(AbstractClauseProcessor):
- """converts a table instance within an expression to be an alias of that table."""
+ """Convert a table instance within an expression to be an alias of that table."""
+
def __init__(self, *tables, **kwargs):
self.tables = {}
self.aliases = kwargs.get('aliases', {})
@@ -138,8 +166,10 @@ class Aliasizer(AbstractClauseProcessor):
self.tables[t2.table] = t2
self.aliases[t2.table] = self.aliases[t]
self.binary = None
+
def get_alias(self, table):
return self.aliases[table]
+
def convert_element(self, elem):
if isinstance(elem, sql.ColumnElement) and hasattr(elem, 'table') and self.tables.has_key(elem.table):
return self.get_alias(elem.table).corresponding_column(elem)
@@ -147,37 +177,39 @@ class Aliasizer(AbstractClauseProcessor):
return None
class ClauseAdapter(AbstractClauseProcessor):
- """given a clause (like as in a WHERE criterion), locates columns which 'correspond' to a given selectable,
- and changes those columns to be that of the selectable.
-
- such as:
-
- table1 = Table('sometable', metadata,
- Column('col1', Integer),
- Column('col2', Integer)
- )
- table2 = Table('someothertable', metadata,
- Column('col1', Integer),
- Column('col2', Integer)
- )
-
- condition = table1.c.col1 == table2.c.col1
-
- and make an alias of table1:
-
- s = table1.alias('foo')
-
- calling condition.accept_visitor(ClauseAdapter(s)) converts condition to read:
-
- s.c.col1 == table2.c.col1
-
+    """Given a clause (such as a ``WHERE`` criterion), locate columns
+    which *correspond* to a given selectable, and change those
+    columns to be those of the selectable.
+
+ E.g.::
+
+ table1 = Table('sometable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+ table2 = Table('someothertable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+
+ condition = table1.c.col1 == table2.c.col1
+
+ and make an alias of table1::
+
+ s = table1.alias('foo')
+
+    then calling ``condition.accept_visitor(ClauseAdapter(s))`` converts
+    the condition to read::
+
+ s.c.col1 == table2.c.col1
"""
+
def __init__(self, selectable, include=None, exclude=None, equivalents=None):
self.selectable = selectable
self.include = include
self.exclude = exclude
self.equivalents = equivalents
-
+
def convert_element(self, col):
if not isinstance(col, sql.ColumnElement):
return None
@@ -191,4 +223,3 @@ class ClauseAdapter(AbstractClauseProcessor):
if newcol is None and self.equivalents is not None and col in self.equivalents:
newcol = self.selectable.corresponding_column(self.equivalents[col], raiseerr=False, keys_ok=False)
return newcol
-
diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py
index a1b03b89a2..d9f68ac011 100644
--- a/lib/sqlalchemy/topological.py
+++ b/lib/sqlalchemy/topological.py
@@ -4,57 +4,78 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""topological sorting algorithms. the key to the unit of work is to assemble a list
-of dependencies amongst all the different mappers that have been defined for classes.
-Related tables with foreign key constraints have a definite insert order, deletion order,
-objects need dependent properties from parent objects set up before saved, etc.
-These are all encoded as dependencies, in the form "mapper X is dependent on mapper Y",
-meaning mapper Y's objects must be saved before those of mapper X, and mapper X's objects
-must be deleted before those of mapper Y.
-
-The topological sort is an algorithm that receives this list of dependencies as a "partial
-ordering", that is a list of pairs which might say, "X is dependent on Y", "Q is dependent
-on Z", but does not necessarily tell you anything about Q being dependent on X. Therefore,
-its not a straight sort where every element can be compared to another...only some of the
-elements have any sorting preference, and then only towards just some of the other elements.
-For a particular partial ordering, there can be many possible sorts that satisfy the
+"""Topological sorting algorithms.
+
+The key to the unit of work is to assemble a list of dependencies
+amongst all the different mappers that have been defined for classes.
+
+Related tables with foreign key constraints have a definite insert
+order, deletion order, objects need dependent properties from parent
+objects set up before saved, etc.
+
+These are all encoded as dependencies, in the form *mapper X is
+dependent on mapper Y*, meaning mapper Y's objects must be saved
+before those of mapper X, and mapper X's objects must be deleted
+before those of mapper Y.
+
+The topological sort is an algorithm that receives this list of
+dependencies as a *partial ordering*, that is a list of pairs which
+might say, *X is dependent on Y*, *Q is dependent on Z*, but does not
+necessarily tell you anything about Q being dependent on X. Therefore,
+it's not a straight sort where every element can be compared to
+another... only some of the elements have any sorting preference, and
+then only towards just some of the other elements. For a particular
+partial ordering, there can be many possible sorts that satisfy the
conditions.
-An intrinsic "gotcha" to this algorithm is that since there are many possible outcomes
-to sorting a partial ordering, the algorithm can return any number of different results for the
-same input; just running it on a different machine architecture, or just random differences
-in the ordering of dictionaries, can change the result that is returned. While this result
-is guaranteed to be true to the incoming partial ordering, if the partial ordering itself
-does not properly represent the dependencies, code that works fine will suddenly break, then
-work again, then break, etc. Most of the bugs I've chased down while developing the "unit of work"
-have been of this nature - very tricky to reproduce and track down, particularly before I
-realized this characteristic of the algorithm.
+An intrinsic *gotcha* to this algorithm is that since there are many
+possible outcomes to sorting a partial ordering, the algorithm can
+return any number of different results for the same input; just
+running it on a different machine architecture, or just random
+differences in the ordering of dictionaries, can change the result
+that is returned. While this result is guaranteed to be true to the
+incoming partial ordering, if the partial ordering itself does not
+properly represent the dependencies, code that works fine will
+suddenly break, then work again, then break, etc. Most of the bugs
+I've chased down while developing the *unit of work* have been of this
+nature - very tricky to reproduce and track down, particularly before
+I realized this characteristic of the algorithm.
"""
+
import string, StringIO
from sqlalchemy import util
from sqlalchemy.exceptions import *
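# A hedged sketch of the partial ordering described above, using the
# QueueDependencySorter defined later in this module; 'a' must precede
# 'b' and 'b' must precede 'c', while 'd' has no ordering preference:
#   tuples   = [('a', 'b'), ('b', 'c')]
#   allitems = ['a', 'b', 'c', 'd']
#   head = QueueDependencySorter(tuples, allitems).sort()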
class _Node(object):
- """represents each item in the sort. While the topological sort
- produces a straight ordered list of items, _Node ultimately stores a tree-structure
- of those items which are organized so that non-dependent nodes are siblings."""
+ """Represent each item in the sort.
+
+ While the topological sort produces a straight ordered list of
+ items, ``_Node`` ultimately stores a tree-structure of those items
+ which are organized so that non-dependent nodes are siblings.
+ """
+
def __init__(self, item):
self.item = item
self.dependencies = util.Set()
self.children = []
self.cycles = None
+
def __str__(self):
return self.safestr()
+
def safestr(self, indent=0):
return (' ' * indent * 2) + \
str(self.item) + \
(self.cycles is not None and (" (cycles: " + repr([x for x in self.cycles]) + ")") or "") + \
"\n" + \
string.join([n.safestr(indent + 1) for n in self.children], '')
+
def __repr__(self):
return "%s" % (str(self.item))
+
def all_deps(self):
- """Returns a set of dependencies for this node and all its cycles"""
+ """Return a set of dependencies for this node and all its cycles."""
+
deps = util.Set(self.dependencies)
if self.cycles is not None:
for c in self.cycles:
@@ -62,12 +83,15 @@ class _Node(object):
return deps
class _EdgeCollection(object):
- """a collection of directed edges."""
+ """A collection of directed edges."""
+
def __init__(self):
self.parent_to_children = {}
self.child_to_parents = {}
+
def add(self, edge):
- """add an edge to this collection."""
+ """Add an edge to this collection."""
+
(parentnode, childnode) = edge
if not self.parent_to_children.has_key(parentnode):
self.parent_to_children[parentnode] = util.Set()
@@ -76,8 +100,13 @@ class _EdgeCollection(object):
self.child_to_parents[childnode] = util.Set()
self.child_to_parents[childnode].add(parentnode)
parentnode.dependencies.add(childnode)
+
def remove(self, edge):
- """remove an edge from this collection. return the childnode if it has no other parents"""
+ """Remove an edge from this collection.
+
+ Return the childnode if it has no other parents.
+ """
+
(parentnode, childnode) = edge
self.parent_to_children[parentnode].remove(childnode)
self.child_to_parents[childnode].remove(parentnode)
@@ -85,52 +114,65 @@ class _EdgeCollection(object):
return childnode
else:
return None
+
def has_parents(self, node):
return self.child_to_parents.has_key(node) and len(self.child_to_parents[node]) > 0
+
def edges_by_parent(self, node):
if self.parent_to_children.has_key(node):
return [(node, child) for child in self.parent_to_children[node]]
else:
return []
+
def get_parents(self):
return self.parent_to_children.keys()
+
def pop_node(self, node):
- """remove all edges where the given node is a parent.
-
- returns the collection of all nodes which were children of the given node, and have
- no further parents."""
+ """Remove all edges where the given node is a parent.
+
+ Return the collection of all nodes which were children of the
+ given node, and have no further parents.
+ """
+
children = self.parent_to_children.pop(node, None)
if children is not None:
for child in children:
self.child_to_parents[child].remove(node)
if not len(self.child_to_parents[child]):
yield child
+
def __len__(self):
return sum([len(x) for x in self.parent_to_children.values()])
+
def __iter__(self):
for parent, children in self.parent_to_children.iteritems():
for child in children:
yield (parent, child)
+
def __str__(self):
return repr(list(self))
+
def __repr__(self):
return repr(list(self))
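# A hedged sketch of the edge bookkeeping above:
p, c = _Node('parent'), _Node('child')
edges = _EdgeCollection()
edges.add((p, c))
assert edges.has_parents(c)
# removing the only edge returns the now parent-less child:
assert edges.remove((p, c)) is c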
-
+
class QueueDependencySorter(object):
- """topological sort adapted from wikipedia's article on the subject. it creates a straight-line
- list of elements, then a second pass groups non-dependent actions together to build
- more of a tree structure with siblings."""
-
+    """Topological sort adapted from Wikipedia's article on the subject.
+
+ It creates a straight-line list of elements, then a second pass
+ groups non-dependent actions together to build more of a tree
+ structure with siblings.
+ """
+
def __init__(self, tuples, allitems):
self.tuples = tuples
self.allitems = allitems
def sort(self, allow_self_cycles=True, allow_all_cycles=False):
(tuples, allitems) = (self.tuples, self.allitems)
- #print "\n---------------------------------\n"
+ #print "\n---------------------------------\n"
#print repr([t for t in tuples])
#print repr([a for a in allitems])
- #print "\n---------------------------------\n"
+ #print "\n---------------------------------\n"
nodes = {}
edges = _EdgeCollection()
@@ -138,7 +180,7 @@ class QueueDependencySorter(object):
if not nodes.has_key(item):
node = _Node(item)
nodes[item] = node
-
+
for t in tuples:
if t[0] is t[1]:
if allow_self_cycles:
@@ -188,16 +230,19 @@ class QueueDependencySorter(object):
for childnode in edges.pop_node(node):
queue.append(childnode)
return self._create_batched_tree(output)
-
+
def _create_batched_tree(self, nodes):
- """given a list of nodes from a topological sort, organizes the nodes into a tree structure,
- with as many non-dependent nodes set as silbings to each other as possible."""
+ """Given a list of nodes from a topological sort, organize the
+ nodes into a tree structure, with as many non-dependent nodes
+ set as siblings to each other as possible.
+ """
+
if not len(nodes):
return None
# a list of all currently independent subtrees as a tuple of
# (root_node, set_of_all_tree_nodes, set_of_all_cycle_nodes_in_tree)
- # order of the list has no semantics for the algorithmic
+        # order of the list has no semantics for the algorithm
independents = []
# in reverse topological order
for node in reversed(nodes):
@@ -212,7 +257,7 @@ class QueueDependencySorter(object):
if nodealldeps:
# iterate over independent node indexes in reverse order so we can efficiently remove them
for index in xrange(len(independents)-1,-1,-1):
- child, childsubtree, childcycles = independents[index]
+ child, childsubtree, childcycles = independents[index]
# if there is a dependency between this node and an independent node
if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)):
# prepend child to nodes children
@@ -231,7 +276,7 @@ class QueueDependencySorter(object):
# used prepend [0:0] instead of extend to maintain exact behaviour of previous implementation
head.children[0:0] = [i[0] for i in independents]
return head
-
+
def _find_cycles(self, edges):
involved_in_cycles = util.Set()
cycles = {}
@@ -241,7 +286,7 @@ class QueueDependencySorter(object):
cycle = []
elif node is goal:
return True
-
+
for (n, key) in edges.edges_by_parent(node):
if key in cycle:
continue
@@ -259,7 +304,7 @@ class QueueDependencySorter(object):
else:
cycles[x] = cycset
cycle.pop()
-
+
for parent in edges.get_parents():
traverse(parent)
@@ -269,4 +314,3 @@ class QueueDependencySorter(object):
if edge[0] in cycle and edge[1] in cycle:
edgecollection.append(edge)
yield edgecollection
-
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 7260442700..bf5b359575 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -5,7 +5,7 @@
# the MIT License: http://www.opensource.org/licenses/mit-license.php
__all__ = [ 'TypeEngine', 'TypeDecorator', 'NullTypeEngine',
- 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'FLOAT', 'DECIMAL',
+ 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'FLOAT', 'DECIMAL',
'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN', 'String', 'Integer', 'SmallInteger','Smallinteger',
'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'Binary', 'Boolean', 'Unicode', 'PickleType', 'NULLTYPE',
'SMALLINT', 'DATE', 'TIME'
@@ -26,58 +26,76 @@ class AbstractType(object):
return _impl_cache[self]
except KeyError:
return _impl_cache.setdefault(self, {})
+
impl_dict = property(_get_impl_dict)
def copy_value(self, value):
return value
+
def compare_values(self, x, y):
return x is y
+
def is_mutable(self):
return False
+
def get_dbapi_type(self, dbapi):
- """return the corresponding type object from the underlying DBAPI, if any.
-
- this can be useful for calling setinputsizes(), for example."""
+ """Return the corresponding type object from the underlying DBAPI, if any.
+
+ This can be useful for calling ``setinputsizes()``, for example.
+ """
+
return None
+
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, ",".join(["%s=%s" % (k, getattr(self, k)) for k in inspect.getargspec(self.__init__)[0][1:]]))
-
+
class TypeEngine(AbstractType):
def __init__(self, *args, **params):
pass
+
def engine_impl(self, engine):
- """deprecated; call dialect_impl with a dialect directly."""
+ """Deprecated; call dialect_impl with a dialect directly."""
+
return self.dialect_impl(engine.dialect)
+
def dialect_impl(self, dialect):
try:
return self.impl_dict[dialect]
except KeyError:
return self.impl_dict.setdefault(dialect, dialect.type_descriptor(self))
+
def _get_impl(self):
if hasattr(self, '_impl'):
return self._impl
else:
return NULLTYPE
+
def _set_impl(self, impl):
self._impl = impl
+
impl = property(_get_impl, _set_impl)
+
def get_col_spec(self):
raise NotImplementedError()
+
def convert_bind_param(self, value, dialect):
return value
+
def convert_result_value(self, value, dialect):
return value
+
def adapt(self, cls):
return cls()
-
class TypeDecorator(AbstractType):
def __init__(self, *args, **kwargs):
if not hasattr(self.__class__, 'impl'):
raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
self.impl = self.__class__.impl(*args, **kwargs)
+
def engine_impl(self, engine):
return self.dialect_impl(engine.dialect)
+
def dialect_impl(self, dialect):
try:
return self.impl_dict[dialect]
@@ -95,37 +113,50 @@ class TypeDecorator(AbstractType):
tt.impl = typedesc
self.impl_dict[dialect] = tt
return tt
+
def __getattr__(self, key):
- """proxy all other undefined accessors to the underlying implementation."""
+ """Proxy all other undefined accessors to the underlying implementation."""
+
return getattr(self.impl, key)
+
def get_col_spec(self):
return self.impl.get_col_spec()
+
def convert_bind_param(self, value, dialect):
return self.impl.convert_bind_param(value, dialect)
+
def convert_result_value(self, value, dialect):
return self.impl.convert_result_value(value, dialect)
+
def copy(self):
instance = self.__class__.__new__(self.__class__)
instance.__dict__.update(self.__dict__)
return instance
+
def get_dbapi_type(self, dbapi):
return self.impl.get_dbapi_type(dbapi)
+
def copy_value(self, value):
return self.impl.copy_value(value)
+
def compare_values(self, x, y):
return self.impl.compare_values(x,y)
+
def is_mutable(self):
return self.impl.is_mutable()
class MutableType(object):
- """a mixin that marks a Type as holding a mutable object"""
+ """A mixin that marks a Type as holding a mutable object."""
+
def is_mutable(self):
return True
+
def copy_value(self, value):
raise NotImplementedError()
+
def compare_values(self, x, y):
return x == y
-
+
def to_instance(typeobj):
if typeobj is None:
return NULLTYPE
@@ -133,10 +164,11 @@ def to_instance(typeobj):
return typeobj()
else:
return typeobj
+
def adapt_type(typeobj, colspecs):
if isinstance(typeobj, type):
typeobj = typeobj()
-
+
for t in typeobj.__class__.__mro__[0:-1]:
try:
impltype = colspecs[t]
@@ -147,19 +179,21 @@ def adapt_type(typeobj, colspecs):
    # couldn't adapt - so just return the type itself
# (it may be a user-defined type)
return typeobj
- # if we adapted the given generic type to a database-specific type,
+ # if we adapted the given generic type to a database-specific type,
# but it turns out the originally given "generic" type
    # is actually a subclass of our resulting type, then we were
    # already given a more specific type than that required; so use that.
if (issubclass(typeobj.__class__, impltype)):
return typeobj
return typeobj.adapt(impltype)
-
+
class NullTypeEngine(TypeEngine):
def get_col_spec(self):
raise NotImplementedError()
+
def convert_bind_param(self, value, dialect):
return value
+
def convert_result_value(self, value, dialect):
return value
@@ -169,130 +203,161 @@ class String(TypeEngine):
return super(String, cls).__new__(cls, *args, **kwargs)
else:
return super(String, TEXT).__new__(TEXT, *args, **kwargs)
+
def __init__(self, length = None):
self.length = length
+
def adapt(self, impltype):
return impltype(length=self.length)
+
def convert_bind_param(self, value, dialect):
if not dialect.convert_unicode or value is None or not isinstance(value, unicode):
return value
else:
return value.encode(dialect.encoding)
+
def convert_result_value(self, value, dialect):
if not dialect.convert_unicode or value is None or isinstance(value, unicode):
return value
else:
return value.decode(dialect.encoding)
+
def get_dbapi_type(self, dbapi):
return dbapi.STRING
+
def compare_values(self, x, y):
return x == y
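# A hedged sketch of the conditional encoding above, assuming a dialect
# configured with convert_unicode=True and encoding='utf-8':
#   String().convert_bind_param(u'caf\xe9', dialect)      -> 'caf\xc3\xa9'
#   String().convert_result_value('caf\xc3\xa9', dialect) -> u'caf\xe9'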
-
+
class Unicode(TypeDecorator):
impl = String
+
def convert_bind_param(self, value, dialect):
if value is not None and isinstance(value, unicode):
return value.encode(dialect.encoding)
else:
return value
+
def convert_result_value(self, value, dialect):
if value is not None and not isinstance(value, unicode):
return value.decode(dialect.encoding)
else:
return value
-
+
class Integer(TypeEngine):
- """integer datatype"""
+ """Integer datatype."""
+
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
-
+
class SmallInteger(Integer):
- """ smallint datatype """
+ """Smallint datatype."""
+
pass
+
Smallinteger = SmallInteger
-
+
class Numeric(TypeEngine):
def __init__(self, precision = 10, length = 2):
self.precision = precision
self.length = length
+
def adapt(self, impltype):
return impltype(precision=self.precision, length=self.length)
+
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
class Float(Numeric):
def __init__(self, precision = 10):
self.precision = precision
+
def adapt(self, impltype):
return impltype(precision=self.precision)
class DateTime(TypeEngine):
- """implements a type for datetime.datetime() objects"""
+ """Implement a type for ``datetime.datetime()`` objects."""
+
def __init__(self, timezone=False):
self.timezone = timezone
+
def adapt(self, impltype):
return impltype(timezone=self.timezone)
+
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
-
+
class Date(TypeEngine):
- """implements a type for datetime.date() objects"""
+ """Implement a type for ``datetime.date()`` objects."""
+
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
class Time(TypeEngine):
- """implements a type for datetime.time() objects"""
+ """Implement a type for ``datetime.time()`` objects."""
+
def __init__(self, timezone=False):
self.timezone = timezone
+
def adapt(self, impltype):
return impltype(timezone=self.timezone)
+
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
class Binary(TypeEngine):
def __init__(self, length=None):
self.length = length
+
def convert_bind_param(self, value, dialect):
if value is not None:
return dialect.dbapi().Binary(value)
else:
return None
+
def convert_result_value(self, value, dialect):
return value
+
def adapt(self, impltype):
return impltype(length=self.length)
+
def get_dbapi_type(self, dbapi):
return dbapi.BINARY
class PickleType(MutableType, TypeDecorator):
impl = Binary
+
def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, mutable=True):
self.protocol = protocol
self.pickler = pickler or pickle
self.mutable = mutable
super(PickleType, self).__init__()
+
def convert_result_value(self, value, dialect):
if value is None:
return None
buf = self.impl.convert_result_value(value, dialect)
return self.pickler.loads(str(buf))
+
def convert_bind_param(self, value, dialect):
if value is None:
return None
return self.impl.convert_bind_param(self.pickler.dumps(value, self.protocol), dialect)
+
def copy_value(self, value):
if self.mutable:
return self.pickler.loads(self.pickler.dumps(value, self.protocol))
else:
return value
+
def compare_values(self, x, y):
if self.mutable:
return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol)
else:
return x is y
+
def is_mutable(self):
return self.mutable
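# A hedged sketch of the mutable-value handling above; copies and
# comparisons go through pickle when mutable=True:
pt = PickleType()
v = {'a': [1, 2]}
copy = pt.copy_value(v)              # deep copy via a pickle round trip
assert copy == v and copy is not v
assert pt.compare_values(v, copy)    # compares pickled representations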
-
+
class Boolean(TypeEngine):
pass
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 09abfae55b..6f7304bea3 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -35,15 +35,17 @@ def to_set(x):
return x
def flatten_iterator(x):
- """given an iterator of which further sub-elements may also be iterators,
- flatten the sub-elements into a single iterator."""
+ """Given an iterator of which further sub-elements may also be
+ iterators, flatten the sub-elements into a single iterator.
+ """
+
for elem in x:
if hasattr(elem, '__iter__'):
for y in flatten_iterator(elem):
yield y
else:
yield elem
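# a quick hedged illustration of the flattening behavior:
assert list(flatten_iterator([1, [2, [3, 4]], 5])) == [1, 2, 3, 4, 5]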
-
+
def reversed(seq):
try:
return __builtin__.reversed(seq)
@@ -58,10 +60,12 @@ def reversed(seq):
class ArgSingleton(type):
instances = {}
+
def dispose_static(self, *args):
hashkey = (self, args)
#if hashkey in ArgSingleton.instances:
del ArgSingleton.instances[hashkey]
+
def __call__(self, *args):
hashkey = (self, args)
try:
@@ -72,7 +76,8 @@ class ArgSingleton(type):
return instance
def get_cls_kwargs(cls):
- """return the full set of legal kwargs for the given cls"""
+ """Return the full set of legal kwargs for the given `cls`."""
+
kw = []
for c in cls.__mro__:
cons = c.__init__
@@ -81,15 +86,19 @@ def get_cls_kwargs(cls):
if vn != 'self':
kw.append(vn)
return kw
-
+
class SimpleProperty(object):
- """a "default" property accessor."""
+ """A *default* property accessor."""
+
def __init__(self, key):
self.key = key
+
def __set__(self, obj, value):
setattr(obj, self.key, value)
+
def __delete__(self, obj):
delattr(obj, self.key)
+
def __get__(self, obj, owner):
if obj is None:
return self
@@ -97,60 +106,80 @@ class SimpleProperty(object):
return getattr(obj, self.key)
class OrderedProperties(object):
- """
- An object that maintains the order in which attributes are set upon it.
- also provides an iterator and a very basic getitem/setitem interface to those attributes.
-
+ """An object that maintains the order in which attributes are set upon it.
+
+ Also provides an iterator and a very basic getitem/setitem
+ interface to those attributes.
+
(Not really a dict, since it iterates over values, not keys. Not really
a list, either, since each value must have a key associated; hence there is
no append or extend.)
"""
+
def __init__(self):
self.__dict__['_data'] = OrderedDict()
+
def __len__(self):
return len(self._data)
+
def __iter__(self):
return self._data.itervalues()
+
def __add__(self, other):
return list(self) + list(other)
+
def __setitem__(self, key, object):
self._data[key] = object
+
def __getitem__(self, key):
return self._data[key]
+
def __delitem__(self, key):
del self._data[key]
+
def __setattr__(self, key, object):
self._data[key] = object
+
_data = property(lambda s:s.__dict__['_data'])
+
def __getattr__(self, key):
try:
return self._data[key]
except KeyError:
raise AttributeError(key)
+
def __contains__(self, key):
return key in self._data
+
def get(self, key, default=None):
if self.has_key(key):
return self[key]
else:
return default
+
def keys(self):
return self._data.keys()
+
def has_key(self, key):
return self._data.has_key(key)
+
def clear(self):
self._data.clear()
-
+
class OrderedDict(dict):
- """A Dictionary that returns keys/values/items in the order they were added"""
+ """A Dictionary that returns keys/values/items in the order they were added."""
+
def __init__(self, d=None, **kwargs):
self._list = []
self.update(d, **kwargs)
+
def keys(self):
return list(self._list)
+
def clear(self):
self._list = []
dict.clear(self)
+
def update(self, d=None, **kwargs):
# d can be a dict or sequence of keys/values
if d:
@@ -162,61 +191,77 @@ class OrderedDict(dict):
self.__setitem__(key, value)
if kwargs:
self.update(kwargs)
+
def setdefault(self, key, value):
if not self.has_key(key):
self.__setitem__(key, value)
return value
else:
return self.__getitem__(key)
+
def values(self):
return [self[key] for key in self._list]
+
def __iter__(self):
return iter(self._list)
+
def itervalues(self):
return iter([self[key] for key in self._list])
- def iterkeys(self):
+
+ def iterkeys(self):
return self.__iter__()
+
def iteritems(self):
return iter([(key, self[key]) for key in self.keys()])
+
def __delitem__(self, key):
try:
del self._list[self._list.index(key)]
except ValueError:
raise KeyError(key)
dict.__delitem__(self, key)
+
def __setitem__(self, key, object):
if not self.has_key(key):
self._list.append(key)
dict.__setitem__(self, key, object)
+
def __getitem__(self, key):
return dict.__getitem__(self, key)
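# A hedged sketch of the insertion-order guarantee described above:
d = OrderedDict()
d['b'] = 1
d['a'] = 2
assert d.keys() == ['b', 'a']    # insertion order, not hash order
del d['b']
assert d.keys() == ['a']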
class ThreadLocal(object):
- """an object in which attribute access occurs only within the context of the current thread"""
+ """An object in which attribute access occurs only within the context of the current thread."""
+
def __init__(self):
self.__dict__['_tdict'] = {}
+
def __delattr__(self, key):
try:
del self._tdict["%d_%s" % (thread.get_ident(), key)]
except KeyError:
raise AttributeError(key)
+
def __getattr__(self, key):
try:
return self._tdict["%d_%s" % (thread.get_ident(), key)]
except KeyError:
raise AttributeError(key)
+
def __setattr__(self, key, value):
self._tdict["%d_%s" % (thread.get_ident(), key)] = value
class DictDecorator(dict):
- """a Dictionary that delegates items not found to a second wrapped dictionary."""
+ """A Dictionary that delegates items not found to a second wrapped dictionary."""
+
def __init__(self, decorate):
self.decorate = decorate
+
def __getitem__(self, key):
try:
return dict.__getitem__(self, key)
except KeyError:
return self.decorate[key]
+
def __repr__(self):
return dict.__repr__(self) + repr(self.decorate)
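
Minimal sketches of the three classes completed in the hunk above (illustrative only, not part of the patch):

    from sqlalchemy import util

    # OrderedDict: plain dict semantics, but insertion-ordered views
    d = util.OrderedDict()
    d['b'] = 1
    d['a'] = 2
    assert d.keys() == ['b', 'a'] and d.values() == [1, 2]

    # DictDecorator: lookups that miss fall through to the wrapped dict
    deco = util.DictDecorator({'x': 10})
    deco['y'] = 20
    assert deco['y'] == 20 and deco['x'] == 10

    # ThreadLocal: each thread sees only its own attribute values,
    # keyed internally on thread.get_ident()
    shared = util.ThreadLocal()
    shared.count = 1
    assert shared.count == 1
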
@@ -246,7 +291,8 @@ class OrderedSet(Set):
super(OrderedSet, self).clear()
self._list=[]
- def __iter__(self): return iter(self._list)
+ def __iter__(self):
+ return iter(self._list)
def update(self, iterable):
add = self.add
@@ -255,24 +301,31 @@ class OrderedSet(Set):
def __repr__(self):
return '%s(%r)' % (self.__class__.__name__, self._list)
+
__str__ = __repr__
def union(self, other):
result = self.__class__(self)
result.update(other)
return result
+
__or__ = union
+
def intersection(self, other):
return self.__class__([a for a in self if a in other])
+
__and__ = intersection
+
def symmetric_difference(self, other):
result = self.__class__([a for a in self if a not in other])
result.update([a for a in other if a not in self])
return result
+
__xor__ = symmetric_difference
def difference(self, other):
return self.__class__([a for a in self if a not in other])
+
__sub__ = difference
__ior__ = update
@@ -281,6 +334,7 @@ class OrderedSet(Set):
super(OrderedSet, self).intersection_update(other)
self._list = [ a for a in self._list if a in other]
return self
+
__iand__ = intersection_update
def symmetric_difference_update(self, other):
@@ -288,12 +342,14 @@ class OrderedSet(Set):
self._list = [ a for a in self._list if a in self]
self._list += [ a for a in other._list if a in self]
return self
+
__ixor__ = symmetric_difference_update
def difference_update(self, other):
super(OrderedSet, self).difference_update(other)
self._list = [ a for a in self._list if a in self]
return self
+
__isub__ = difference_update
class UniqueAppender(object):
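
A minimal sketch of the OrderedSet behavior assembled across the hunks above (illustrative only, not part of the patch; it assumes the elided constructor and add() append unseen items to the internal list, consistent with the update logic shown):

    from sqlalchemy.util import OrderedSet

    s = OrderedSet([3, 1, 2, 1])
    assert list(s) == [3, 1, 2]        # duplicate dropped, order kept

    # the operators return OrderedSets and preserve insertion order
    assert list(s.union([4, 1])) == [3, 1, 2, 4]
    assert list(s.intersection(OrderedSet([2, 3]))) == [3, 2]
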
@@ -304,20 +360,25 @@ class UniqueAppender(object):
elif hasattr(data, 'add'):
self._data_appender = data.add
self.set = Set()
+
def append(self, item):
if item not in self.set:
self.set.add(item)
self._data_appender(item)
-
+
class ScopedRegistry(object):
- """a Registry that can store one or multiple instances of a single class
- on a per-thread scoped basis, or on a customized scope
-
- createfunc - a callable that returns a new object to be placed in the registry
- scopefunc - a callable that will return a key to store/retrieve an object,
- defaults to thread.get_ident for thread-local objects. use a value like
- lambda: True for application scope.
+ """A Registry that can store one or multiple instances of a single
+ class on a per-thread scoped basis, or on a customized scope.
+
+ createfunc
+ a callable that returns a new object to be placed in the registry
+
+ scopefunc
+ a callable that will return a key to store/retrieve an object,
+ defaults to ``thread.get_ident`` for thread-local objects. Use
+ a value like ``lambda: True`` for application scope.
"""
+
def __init__(self, createfunc, scopefunc=None):
self.createfunc = createfunc
if scopefunc is None:
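
Stepping back to UniqueAppender: a minimal sketch of its duplicate filtering (illustrative only, not part of the patch; it assumes the elided first lines of __init__ select list.append for list-like targets, as the hasattr(data, 'add') branch implies):

    from sqlalchemy.util import UniqueAppender

    seen = []
    ua = UniqueAppender(seen)          # appends via seen.append
    for item in ['a', 'b', 'a']:
        ua.append(item)
    assert seen == ['a', 'b']          # the repeated 'a' was filtered out
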
@@ -325,20 +386,22 @@ class ScopedRegistry(object):
else:
self.scopefunc = scopefunc
self.registry = {}
+
def __call__(self):
key = self._get_key()
try:
return self.registry[key]
except KeyError:
return self.registry.setdefault(key, self.createfunc())
+
def set(self, obj):
self.registry[self._get_key()] = obj
+
def clear(self):
try:
del self.registry[self._get_key()]
except KeyError:
pass
+
def _get_key(self):
return self.scopefunc()
-
-
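
Finally, a minimal usage sketch matching the ScopedRegistry docstring above (illustrative only, not part of the patch; the Session class is invented):

    from sqlalchemy.util import ScopedRegistry

    class Session(object):
        pass

    # default scopefunc is thread.get_ident, so one Session per thread;
    # ScopedRegistry(Session, lambda: True) would yield one per application
    registry = ScopedRegistry(createfunc=Session)
    a = registry()          # created on first call in this thread
    b = registry()          # same thread, so the same object
    assert a is b

    registry.clear()        # discard this thread's instance
    assert registry() is not a
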