From: Mike Bayer Date: Sat, 14 Jul 2007 23:36:17 +0000 (+0000) Subject: - merged trunk r2880-r2901 (slightly manually for 2900-2901) X-Git-Tag: rel_0_4_6~111 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=60f60b9a603f8c2053f35f0855ae237f4d6b9a44;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - merged trunk r2880-r2901 (slightly manually for 2900-2901) - merges "bind" argument change - merges join fixes for [ticket:185] - removed all "engine"/"connectable"/"bind_to"/"engine_or_url" arguments/attributes --- diff --git a/CHANGES b/CHANGES index 194fefc5b0..229239fa42 100644 --- a/CHANGES +++ b/CHANGES @@ -118,17 +118,32 @@ - better error message for NoSuchColumnError [ticket:607] - finally figured out how to get setuptools version in, available as sqlalchemy.__version__ [ticket:428] + - the various "engine" arguments, such as "engine", "connectable", + "engine_or_url", "bind_to", etc. are all present, but deprecated. + they all get replaced by the single term "bind". you also + set the "bind" of MetaData using + metadata.bind = - ext - iteration over dict association proxies is now dict-like, not InstrumentedList-like (e.g. over keys instead of values) - association proxies no longer bind tightly to source collections [ticket:597], and are constructed with a thunk instead + - added selectone_by() to assignmapper - orm - forwards-compatibility with 0.4: added one(), first(), and - all() to Query + all() to Query. almost all Query functionality from 0.4 is + present in 0.3.9 for forwards-compat purposes. + - reset_joinpoint() really really works this time, promise ! lets + you re-join from the root: + query.join(['a', 'b']).filter().reset_joinpoint().\ + join(['a', 'c']).filter().all() + in 0.4 all join() calls start from the "root" - added synchronization to the mapper() construction step, to avoid thread collections when pre-existing mappers are compiling in a different thread [ticket:613] + - a warning is issued by Mapper when two primary key columns of the + same name are munged into a single attribute. this happens frequently + when mapping to joins (or inheritance). - synonym() properties are fully supported by all Query joining/ with_parent operations [ticket:598] - fixed very stupid bug when deleting items with many-to-many @@ -151,6 +166,14 @@ - DynamicMetaData has been renamed to ThreadLocalMetaData. the DynamicMetaData name is deprecated and is an alias for ThreadLocalMetaData or a regular MetaData if threadlocal=False + - composite primary key is represented as a non-keyed set to allow for + composite keys consisting of cols with the same name; occurs within a + Join. helps inheritance scenarios formulate correct PK. + - improved ability to get the "correct" and most minimal set of primary key + columns from a join, equating foreign keys and otherwise equated columns. + this is also mostly to help inheritance scenarios formulate the best + choice of primary key columns. [ticket:185] + - added 'bind' argument to Sequence.create()/drop(), ColumnDefault.execute() - some enhancements to "column targeting", the ability to match a column to a "corresponding" column in another selectable. this affects mostly ORM ability to map to complex joins @@ -202,6 +225,8 @@ - the fix in "schema" above fixes reflection of foreign keys from an alt-schema table to a public schema table - sqlite + - rearranged dialect initialization so it has time to warn about pysqlite1 + being too old. - sqlite better handles datetime/date/time objects mixed and matched with various Date/Time/DateTime columns - string PK column inserts dont get overwritten with OID [ticket:603] @@ -210,9 +235,6 @@ - fix port option handling for pyodbc [ticket:634] - now able to reflect start and increment values for identity columns - preliminary support for using scope_identity() with pyodbc - -- extensions - - added selectone_by() to assignmapper 0.3.8 - engines diff --git a/doc/build/content/tutorial.txt b/doc/build/content/tutorial.txt index 464d3044bc..615f275f9a 100644 --- a/doc/build/content/tutorial.txt +++ b/doc/build/content/tutorial.txt @@ -105,7 +105,7 @@ With `metadata` as our established home for tables, lets make a Table for it: >>> users_table = Table('users', metadata, ... Column('user_id', Integer, primary_key=True), ... Column('user_name', String(40)), - ... Column('password', String(10)) + ... Column('password', String(15)) ... ) As you might have guessed, we have just defined a table named `users` which has three columns: `user_id` (which is a primary key column), `user_name` and `password`. Currently it is just an object that doesn't necessarily correspond to an existing table in our database. To actually create the table, we use the `create()` method. To make it interesting, we will have SQLAlchemy echo the SQL statements it sends to the database, by setting the `echo` flag on the `Engine` associated with our `MetaData`: @@ -116,7 +116,7 @@ As you might have guessed, we have just defined a table named `users` which has CREATE TABLE users ( user_id INTEGER NOT NULL, user_name VARCHAR(40), - password VARCHAR(10), + password VARCHAR(15), PRIMARY KEY (user_id) ) ... diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 7e59564445..233fe050ac 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -1068,25 +1068,11 @@ class MySQLDialect(ansisql.ANSIDialect): return self._default_schema_name def has_table(self, connection, table_name, schema=None): - # TODO: this does not work for table names that contain multibyte characters. - - # http://dev.mysql.com/doc/refman/5.0/en/error-messages-server.html - - # Error: 1146 SQLSTATE: 42S02 (ER_NO_SUCH_TABLE) - # Message: Table '%s.%s' doesn't exist - - # Error: 1046 SQLSTATE: 3D000 (ER_NO_DB_ERROR) - # Message: No database selected - - try: - name = schema and ("%s.%s" % (schema, table_name)) or table_name - connection.execute("DESCRIBE `%s`" % name) - return True - except exceptions.SQLError, e: - if e.orig.args[0] in (1146, 1046): - return False - else: - raise + if schema is not None: + st = 'SHOW TABLE STATUS FROM `%s` LIKE %%s' % schema + else: + st = 'SHOW TABLE STATUS LIKE %s' + return connection.execute(st, table_name).rowcount != 0 def get_version_info(self, connectable): if hasattr(connectable, 'connect'): @@ -1102,34 +1088,36 @@ class MySQLDialect(ansisql.ANSIDialect): return tuple(version) def reflecttable(self, connection, table): - # reference: http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html - cs = connection.execute("show variables like 'lower_case_table_names'").fetchone()[1] - if isinstance(cs, array): - cs = cs.tostring() - case_sensitive = int(cs) == 0 + """Load column definitions from the server.""" - decode_from = connection.execute("show variables like 'character_set_results'").fetchone()[1] + decode_from = self._detect_charset(connection) + + # reference: + # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html + row = _compat_fetch(connection.execute( + "SHOW VARIABLES LIKE 'lower_case_table_names'"), + one=True, charset=decode_from) + if not row: + case_sensitive = True + else: + case_sensitive = row[1] in ('0', 'OFF' 'off') if not case_sensitive: table.name = table.name.lower() table.metadata.tables[table.name]= table + try: - c = connection.execute("describe " + table.fullname, {}) + rp = connection.execute("describe " + self._escape_table_name(table), + {}) except: - raise exceptions.NoSuchTableError(table.name) - found_table = False - while True: - row = c.fetchone() - if row is None: - break - #print "row! " + repr(row) - if not found_table: - found_table = True - - # these can come back as unicode if use_unicode=1 in the mysql connection - (name, type, nullable, primary_key, default) = (row[0], str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4]) - if not isinstance(name, unicode): - name = name.decode(decode_from) + raise exceptions.NoSuchTableError(table.fullname) + + for row in _compat_fetch(rp, charset=decode_from): + (name, type, nullable, primary_key, default) = \ + (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4]) + + # leave column names as unicode + name = name.decode(decode_from) match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type) col_type = match.group(1) @@ -1137,7 +1125,6 @@ class MySQLDialect(ansisql.ANSIDialect): extra_1 = match.group(3) extra_2 = match.group(4) - #print "coltype: " + repr(col_type) + " args: " + repr(args) + "extras:" + repr(extra_1) + ' ' + repr(extra_2) try: coltype = ischema_names[col_type] except KeyError: @@ -1162,32 +1149,24 @@ class MySQLDialect(ansisql.ANSIDialect): colargs= [] if default: if col_type == 'timestamp' and default == 'CURRENT_TIMESTAMP': - arg = sql.text(default) - else: - arg = default - colargs.append(schema.PassiveDefault(arg)) + default = sql.text(default) + colargs.append(schema.PassiveDefault(default)) table.append_column(schema.Column(name, coltype, *colargs, **dict(primary_key=primary_key, nullable=nullable, ))) - tabletype = self.moretableinfo(connection, table=table) + tabletype = self.moretableinfo(connection, table, decode_from) table.kwargs['mysql_engine'] = tabletype - if not found_table: - raise exceptions.NoSuchTableError(table.name) + def moretableinfo(self, connection, table, charset=None): + """SHOW CREATE TABLE to get foreign key/table options.""" - def moretableinfo(self, connection, table): - """runs SHOW CREATE TABLE to get foreign key/options information about the table. - - """ - c = connection.execute("SHOW CREATE TABLE " + table.fullname, {}) - desc_fetched = c.fetchone()[1] - - if not isinstance(desc_fetched, basestring): - # may get array.array object here, depending on version (such as mysql 4.1.14 vs. 4.1.11) - desc_fetched = desc_fetched.tostring() - desc = desc_fetched.strip() + rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {}) + row = _compat_fetch(rp, one=True, charset=charset) + if not row: + raise exceptions.NoSuchTableError(table.fullname) + desc = row[1].strip() tabletype = '' lastparen = re.search(r'\)[^\)]*\Z', desc) @@ -1207,9 +1186,68 @@ class MySQLDialect(ansisql.ANSIDialect): return tabletype + def _escape_table_name(self, table): + if table.schema is not None: + return '`%s`.`%s`' % (table.schema. table.name) + else: + return '`%s`' % table.name + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Note: MySQL-python 1.2.1c7 seems to ignore changes made + # on a connection via set_character_set() + + rs = connection.execute("show variables like 'character_set%%'") + opts = dict([(row[0], row[1]) for row in _compat_fetch(rs)]) + + if 'character_set_results' in opts: + return opts['character_set_results'] + try: + return connection.connection.character_set_name() + except AttributeError: + # < 1.2.1 final MySQL-python drivers have no charset support + if 'character_set' in opts: + return opts['character_set'] + else: + warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python. MySQL-python >= 1.2.2 is recommended. Assuming latin1.")) + return 'latin1' + +def _compat_fetch(rp, one=False, charset=None): + """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" + + if one: + return _MySQLPythonRowProxy(rp.fetchone(), charset) + else: + return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()] + + +class _MySQLPythonRowProxy(object): + """Return consistent column values for all versions of MySQL-python (esp. alphas) and unicode settings.""" + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = charset + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, array): + item = item.tostring() + if self.charset and isinstance(item, unicode): + return item.encode(self.charset) + else: + return item + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, array): + item = item.tostring() + if self.charset and isinstance(item, unicode): + return item.encode(self.charset) + else: + return item + + class MySQLCompiler(ansisql.ANSICompiler): def visit_cast(self, cast): - """hey ho MySQL supports almost no types at all for CAST""" if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)): return super(MySQLCompiler, self).visit_cast(cast) else: diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 0fd31928af..7ccf38c4bb 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -161,13 +161,13 @@ class SQLiteDialect(ansisql.ANSIDialect): ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs) def vers(num): return tuple([int(x) for x in num.split('.')]) - self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) if self.dbapi is not None: sqlite_ver = self.dbapi.version_info if sqlite_ver < (2,1,'3'): warnings.warn(RuntimeWarning("The installed version of pysqlite2 (%s) is out-dated, and will cause errors in some cases. Version 2.1.3 or greater is recommended." % '.'.join([str(subver) for subver in sqlite_ver]))) if vers(self.dbapi.sqlite_version) < vers("3.3.13"): warnings.warn(RuntimeWarning("The installed version of sqlite (%s) is out-dated, and will cause errors in some cases. Version 3.3.13 or greater is recommended." % self.dbapi.sqlite_version)) + self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) def dbapi(cls): try: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 1143fbf59a..6c23422a1f 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -406,7 +406,7 @@ class Compiled(sql.ClauseVisitor): defaults. """ - def __init__(self, dialect, statement, parameters, engine=None): + def __init__(self, dialect, statement, parameters, bind=None): """Construct a new ``Compiled`` object. statement @@ -426,13 +426,13 @@ class Compiled(sql.ClauseVisitor): can either be the string names of columns or ``_ColumnClause`` objects. - engine - Optional Engine to compile this statement against. + bind + Optional Engine or Connection to compile this statement against. """ self.dialect = dialect self.statement = statement self.parameters = parameters - self.engine = engine + self.bind = bind self.can_execute = statement.supports_execution() def compile(self): @@ -465,9 +465,9 @@ class Compiled(sql.ClauseVisitor): def execute(self, *multiparams, **params): """Execute this compiled object.""" - e = self.engine + e = self.bind if e is None: - raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.") + raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.") return e.execute_compiled(self, *multiparams, **params) def scalar(self, *multiparams, **params): @@ -691,7 +691,7 @@ class Connection(Connectable): return self.execute(object, *multiparams, **params).scalar() def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs) + return self.dialect.compiler(statement, parameters, bind=self.engine, **kwargs) def execute(self, object, *multiparams, **params): for c in type(object).__mro__: @@ -945,14 +945,14 @@ class Engine(Connectable): connection.close() def _func(self): - return sql._FunctionGenerator(engine=self) + return sql._FunctionGenerator(bind=self) func = property(_func) def text(self, text, *args, **kwargs): """Return a sql.text() object for performing literal queries.""" - return sql.text(text, engine=self, *args, **kwargs) + return sql.text(text, bind=self, *args, **kwargs) def _run_visitor(self, visitorcallable, element, connection=None, **kwargs): if connection is None: @@ -1014,7 +1014,7 @@ class Engine(Connectable): return connection.execute_compiled(compiled, *multiparams, **params) def compiler(self, statement, parameters, **kwargs): - return self.dialect.compiler(statement, parameters, engine=self, **kwargs) + return self.dialect.compiler(statement, parameters, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" @@ -1510,7 +1510,7 @@ class DefaultRunner(schema.SchemaVisitor): return None def exec_default_sql(self, default): - c = sql.select([default.arg]).compile(engine=self.connection) + c = sql.select([default.arg]).compile(bind=self.connection) return self.connection.execute_compiled(c).scalar() def visit_column_onupdate(self, onupdate): diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index 9940610f75..90399b7b5c 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -266,7 +266,7 @@ directly. The engine's ``execute`` method corresponds to the one of a DBAPI cursor, and returns a ``ResultProxy`` that has ``fetch`` methods you would also see on a cursor:: - >>> rp = db.engine.execute('select name, email from users order by name') + >>> rp = db.bind.execute('select name, email from users order by name') >>> for name, email in rp.fetchall(): print name, email Bhargan Basepair basepair+nospam@example.edu Joe Student student@example.edu @@ -497,9 +497,10 @@ class SqlSoup: self.schema = None def engine(self): - return self._metadata._engine + return self._metadata.bind engine = property(engine) + bind = engine def delete(self, *args, **kwargs): objectstore.delete(*args, **kwargs) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 3a5f60c272..5718d49dd5 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.util import ExtensionCarrier from sqlalchemy.orm import sync from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, EXT_PASS, MapperExtension, SynonymProperty -import weakref +import weakref, warnings __all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry'] @@ -543,9 +543,11 @@ class Mapper(object): # against the "mapped_table" of this mapper. equivalent_columns = self._get_equivalent_columns() - primary_key = sql.ColumnCollection() + primary_key = sql.ColumnSet() for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]): + #primary_key.add(col) + #continue c = self.mapped_table.corresponding_column(col, raiseerr=False) if c is None: for cc in equivalent_columns[col]: @@ -690,6 +692,8 @@ class Mapper(object): prop = prop.copy() prop.set_parent(self) self.__props[column_key] = prop + if column in self.primary_key and prop.columns[-1] in self.primary_key: + warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'. Use explicit properties to give each column its own mapped attribute name." % (str(self), str(column), str(prop.columns[-1]), column_key))) prop.columns.append(column) self.__log("appending to existing ColumnProperty %s" % (column_key)) else: @@ -1360,7 +1364,7 @@ class Mapper(object): statement = table.delete(clause) c = connection.execute(statement, delete) if c.supports_sane_rowcount() and c.rowcount != len(delete): - raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete))) + raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete))) for obj in deleted_objects: for mapper in object_mapper(obj).iterate_to_root(): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 46ce21a793..17d8feabb1 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -38,21 +38,21 @@ class SessionTransaction(object): def _begin(self): return SessionTransaction(self.session, self) - def add(self, connectable): - if self.connections.has_key(connectable.engine): + def add(self, bind): + if self.connections.has_key(bind.engine): raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") - return self.get_or_add(connectable) + return self.get_or_add(bind) - def get_or_add(self, connectable): + def get_or_add(self, bind): # we reference the 'engine' attribute on the given object, which in the case of # Connection, ProxyEngine, Engine, whatever, should return the original # "Engine" object that is handling the connection. - if self.connections.has_key(connectable.engine): - return self.connections[connectable.engine][0] - e = connectable.engine - c = connectable.contextual_connect() + if self.connections.has_key(bind.engine): + return self.connections[bind.engine][0] + e = bind.engine + c = bind.contextual_connect() if not self.connections.has_key(e): - self.connections[e] = (c, c.begin(), c is not connectable) + self.connections[e] = (c, c.begin(), c is not bind) return self.connections[e][0] def commit(self): @@ -99,13 +99,13 @@ class Session(object): of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module. """ - def __init__(self, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False): + def __init__(self, bind=None, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False): if import_session is not None: self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map) else: self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map) - self.bind_to = bind_to + self.bind = bind or bind_to self.binds = {} self.echo_uow = echo_uow self.weak_identity_map = weak_identity_map @@ -122,6 +122,8 @@ class Session(object): def _set_echo_uow(self, value): self.uow.echo = value echo_uow = property(_get_echo_uow,_set_echo_uow) + + bind_to = property(lambda self:self.bind) def create_transaction(self, **kwargs): """Return a new ``SessionTransaction`` corresponding to an @@ -213,23 +215,23 @@ class Session(object): return _class_mapper(class_, entity_name = entity_name) - def bind_mapper(self, mapper, bindto): + def bind_mapper(self, mapper, bind): """Bind the given `mapper` to the given ``Engine`` or ``Connection``. All subsequent operations involving this ``Mapper`` will use the - given `bindto`. + given `bind`. """ - self.binds[mapper] = bindto + self.binds[mapper] = bind - def bind_table(self, table, bindto): + def bind_table(self, table, bind): """Bind the given `table` to the given ``Engine`` or ``Connection``. All subsequent operations involving this ``Table`` will use the - given `bindto`. + given `bind`. """ - self.binds[table] = bindto + self.binds[table] = bind def get_bind(self, mapper): """Return the ``Engine`` or ``Connection`` which is used to execute @@ -259,17 +261,17 @@ class Session(object): """ if mapper is None: - return self.bind_to + return self.bind elif self.binds.has_key(mapper): return self.binds[mapper] elif self.binds.has_key(mapper.mapped_table): return self.binds[mapper.mapped_table] - elif self.bind_to is not None: - return self.bind_to + elif self.bind is not None: + return self.bind else: - e = mapper.mapped_table.engine + e = mapper.mapped_table.bind if e is None: - raise exceptions.InvalidRequestError("Could not locate any Engine bound to mapper '%s'" % str(mapper)) + raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) return e def query(self, mapper_or_class, entity_name=None, **kwargs): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 9b9ae801a9..897f397b61 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -57,21 +57,27 @@ class SchemaItem(object): return None - def _get_engine(self): + def _get_engine(self, raiseerr=False): """Return the engine or None if no engine.""" - return self._derived_metadata().engine - - def get_engine(self, connectable=None): - """Return the engine or raise an error if no engine.""" - - if connectable is not None: - return connectable - e = self._get_engine() - if e is not None: - return e + if raiseerr: + m = self._derived_metadata() + e = m and m.bind or None + if e is None: + raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") + else: + return e else: - raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine") + m = self._derived_metadata() + return m and m.bind or None + + def get_engine(self): + """Return the engine or raise an error if no engine. + + Deprecated. use the "bind" attribute. + """ + + return self._get_engine(raiseerr=True) def _set_casing_strategy(self, kwargs, keyname='case_sensitive'): """Set the "case_sensitive" argument sent via keywords to the item's constructor. @@ -121,9 +127,9 @@ class SchemaItem(object): return self.__case_sensitive case_sensitive = property(_get_case_sensitive) - engine = property(lambda s:s._get_engine()) metadata = property(lambda s:s._derived_metadata()) - + bind = property(lambda s:s._get_engine()) + def _get_table_key(name, schema): if schema is None: return name @@ -159,7 +165,7 @@ class _TableSingleton(sql._FigureVisitName): if autoload_with: autoload_with.reflecttable(table) else: - metadata.get_engine().reflecttable(table) + metadata._get_engine(raiseerr=True).reflecttable(table) except exceptions.NoSuchTableError: del metadata.tables[key] raise @@ -269,7 +275,9 @@ class Table(SchemaItem, sql.TableClause): self.schema = kwargs.pop('schema', None) self.indexes = util.Set() self.constraints = util.Set() + self._columns = sql.ColumnCollection() self.primary_key = PrimaryKeyConstraint() + self._foreign_keys = util.OrderedSet() self.quote = kwargs.pop('quote', False) self.quote_schema = kwargs.pop('quote_schema', False) if self.schema is not None: @@ -289,6 +297,11 @@ class Table(SchemaItem, sql.TableClause): key = property(lambda self:_get_table_key(self.name, self.schema)) + def _export_columns(self, columns=None): + # override FromClause's collection initialization logic; TableClause and Table + # implement it differently + pass + def _get_case_sensitive_schema(self): try: return getattr(self, '_case_sensitive_schema') @@ -343,30 +356,37 @@ class Table(SchemaItem, sql.TableClause): else: return [] - def exists(self, connectable=None): + def exists(self, bind=None, connectable=None): """Return True if this table exists.""" - if connectable is None: - connectable = self.get_engine() + if connectable is not None: + bind = connectable + + if bind is None: + bind = self._get_engine(raiseerr=True) def do(conn): e = conn.engine return e.dialect.has_table(conn, self.name, schema=self.schema) - return connectable.run_callable(do) + return bind.run_callable(do) - def create(self, connectable=None, checkfirst=False): + def create(self, bind=None, checkfirst=False, connectable=None): """Issue a ``CREATE`` statement for this table. See also ``metadata.create_all()``.""" - self.metadata.create_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) + if connectable is not None: + bind = connectable + self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self]) - def drop(self, connectable=None, checkfirst=False): + def drop(self, bind=None, checkfirst=False, connectable=None): """Issue a ``DROP`` statement for this table. See also ``metadata.drop_all()``.""" - self.metadata.drop_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) + if connectable is not None: + bind = connectable + self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self]) def tometadata(self, metadata, schema=None): """Return a copy of this ``Table`` associated with a different ``MetaData``.""" @@ -527,8 +547,16 @@ class Column(SchemaItem, sql._ColumnClause): return self.table.metadata def _get_engine(self): - return self.table.engine + return self.table.bind + def references(self, column): + """return true if this column references the given column via foreign key""" + for fk in self.foreign_keys: + if fk.column is column: + return True + else: + return False + def append_foreign_key(self, fk): fk._set_parent(self) @@ -744,7 +772,7 @@ class DefaultGenerator(SchemaItem): def __init__(self, for_update=False, metadata=None): self.for_update = for_update - self._metadata = metadata + self._metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata') def _derived_metadata(self): try: @@ -763,8 +791,10 @@ class DefaultGenerator(SchemaItem): else: self.column.default = self - def execute(self, connectable=None, **kwargs): - return self.get_engine(connectable=connectable).execute_default(self, **kwargs) + def execute(self, bind=None, **kwargs): + if bind is None: + bind = self._get_engine(raiseerr=True) + return bind.execute_default(self, **kwargs) def __repr__(self): return "DefaultGenerator()" @@ -822,12 +852,15 @@ class Sequence(DefaultGenerator): super(Sequence, self)._set_parent(column) column.sequence = self - def create(self, connectable=None, checkfirst=True): - self.get_engine(connectable=connectable).create(self, checkfirst=checkfirst) - return self + def create(self, bind=None, checkfirst=True): + if bind is None: + bind = self._get_engine(raiseerr=True) + bind.create(self, checkfirst=checkfirst) - def drop(self, connectable=None, checkfirst=True): - self.get_engine(connectable=connectable).drop(self, checkfirst=checkfirst) + def drop(self, bind=None, checkfirst=True): + if bind is None: + bind = self._get_engine(raiseerr=True) + bind.drop(self, checkfirst=checkfirst) class Constraint(SchemaItem): @@ -1022,14 +1055,14 @@ class Index(SchemaItem): if connectable is not None: connectable.create(self) else: - self.get_engine().create(self) + self._get_engine(raiseerr=True).create(self) return self def drop(self, connectable=None): if connectable is not None: connectable.drop(self) else: - self.get_engine().drop(self) + self._get_engine(raiseerr=True).drop(self) def __str__(self): return repr(self) @@ -1045,31 +1078,23 @@ class MetaData(SchemaItem): __visit_name__ = 'metadata' - def __init__(self, engine_or_url=None, **kwargs): + def __init__(self, bind=None, **kwargs): """create a new MetaData object. - - url - a string or URL instance which will be passed to create_engine(), - along with \**kwargs - this MetaData will be bound to the resulting - engine. - engine - an Engine instance to which this MetaData will be bound. - + bind + an Engine, or a string or URL instance which will be passed + to create_engine(), along with \**kwargs - this MetaData will + be bound to the resulting engine. + case_sensitive popped from \**kwargs, indicates default case sensitive setting for all contained objects. defaults to True. """ - - if engine_or_url is None: - # limited backwards compatability - engine_or_url = kwargs.get('url', None) or kwargs.get('engine', None) + self.tables = {} - self._engine = None self._set_casing_strategy(kwargs) - if engine_or_url: - self.connect(engine_or_url, **kwargs) + self.bind = bind def __getstate__(self): return {'tables':self.tables, 'casesensitive':self._case_sensitive_setting} @@ -1077,38 +1102,32 @@ class MetaData(SchemaItem): def __setstate__(self, state): self.tables = state['tables'] self._case_sensitive_setting = state['casesensitive'] - self._engine = None + self._bind = None def is_bound(self): """return True if this MetaData is bound to an Engine.""" - return self._engine is not None + return self._bind is not None - def connect(self, engine_or_url, **kwargs): + def connect(self, bind, **kwargs): """bind this MetaData to an Engine. + + DEPRECATED. use metadata.bind = or metadata.bind = . - engine_or_url + bind a string, URL or Engine instance. If a string or URL, will be passed to create_engine() along with \**kwargs to produce the engine which to connect to. otherwise connects directly to the given Engine. - + """ from sqlalchemy.engine.url import URL - if isinstance(engine_or_url, (basestring, URL)): - self._engine = sqlalchemy.create_engine(engine_or_url, **kwargs) + if isinstance(bind, (basestring, URL)): + self._bind = sqlalchemy.create_engine(bind, **kwargs) else: - self._engine = engine_or_url + self._bind = bind - def _get_engine(self): - # we are checking is_bound() because _engine wires - # into SchemaItem's _engine mechanism, which raises an error, - # whereas we just want to return None. - if not self.is_bound(): - return None - return self._engine - - engine = property(_get_engine, connect) + bind = property(lambda self:self._bind, connect, doc="""an Engine or Connection to which this MetaData is bound. this is a settable property as well.""") def clear(self): self.tables.clear() @@ -1129,47 +1148,64 @@ class MetaData(SchemaItem): def _get_parent(self): return None - def create_all(self, connectable=None, tables=None, checkfirst=True): + def create_all(self, bind=None, tables=None, checkfirst=True, connectable=None): """Create all tables stored in this metadata. This will conditionally create tables depending on if they do not yet exist in the database. + bind + A ``Connectable`` used to access the database; if None, uses + the existing bind on this ``MetaData``, if any. + connectable - A ``Connectable`` used to access the database; or use the engine - bound to this ``MetaData``. + deprecated. synonymous with "bind" tables Optional list of tables, which is a subset of the total tables in the ``MetaData`` (others are ignored). """ - if connectable is None: - connectable = self.get_engine() - connectable.create(self, checkfirst=checkfirst, tables=tables) + if connectable is not None: + bind = connectable + if bind is None: + bind = self._get_engine(raiseerr=True) + bind.create(self, checkfirst=checkfirst, tables=tables) - def drop_all(self, connectable=None, tables=None, checkfirst=True): + def drop_all(self, bind=None, tables=None, checkfirst=True, connectable=None): """Drop all tables stored in this metadata. This will conditionally drop tables depending on if they currently exist in the database. + bind + A ``Connectable`` used to access the database; if None, uses + the existing bind on this ``MetaData``, if any. + connectable - A ``Connectable`` used to access the database; or use the engine - bound to this ``MetaData``. + deprecated. synonymous with "bind" tables Optional list of tables, which is a subset of the total tables in the ``MetaData`` (others are ignored). """ - if connectable is None: - connectable = self.get_engine() - connectable.drop(self, checkfirst=checkfirst, tables=tables) + if connectable is not None: + bind = connectable + if bind is None: + bind = self._get_engine(raiseerr=True) + bind.drop(self, checkfirst=checkfirst, tables=tables) def _derived_metadata(self): return self + def _get_engine(self, raiseerr=False): + if not self.is_bound(): + if raiseerr: + raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") + else: + return None + return self._bind class ThreadLocalMetaData(MetaData): """Build upon ``MetaData`` to provide the capability to bind to @@ -1209,13 +1245,15 @@ thread-local basis. for e in self.__engines.values(): e.dispose() - def _get_engine(self): + def _get_engine(self, raiseerr=False): if hasattr(self.context, '_engine'): return self.context._engine else: - return None - - engine = property(_get_engine, connect) + if raiseerr: + raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.") + else: + return None + bind = property(_get_engine, connect) class SchemaVisitor(sql.ClauseVisitor): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 32c20bc10f..c5eeda9c93 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -218,12 +218,12 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs): and oracle supports "nowait" which translates to ``FOR UPDATE NOWAIT``. - engine=None - an ``Engine`` instance to which the resulting ``Select`` + bind=None + an ``Engine`` or ``Connection`` instance to which the resulting ``Select`` object will be bound. The ``Select`` object will otherwise - automatically bind to whatever ``Engine`` instances can be located + automatically bind to whatever ``Connectable`` instances can be located within its contained ``ClauseElement`` members. - + limit=None a numerical value which usually compiles to a ``LIMIT`` expression in the resulting select. Databases that don't support ``LIMIT`` @@ -708,7 +708,7 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False): else: return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique) -def text(text, engine=None, *args, **kwargs): +def text(text, bind=None, *args, **kwargs): """Create literal text to be inserted into a query. When constructing a query from a ``select()``, ``update()``, @@ -723,9 +723,9 @@ def text(text, engine=None, *args, **kwargs): to specify bind parameters; they will be compiled to their engine-specific format. - engine - An optional engine to be used for this text query. - + bind + An optional connection or engine to be used for this text query. + bindparams A list of ``bindparam()`` instances which can be used to define the types and/or initial values for the bind parameters within @@ -742,7 +742,7 @@ def text(text, engine=None, *args, **kwargs): """ - return _TextClause(text, engine=engine, *args, **kwargs) + return _TextClause(text, bind=bind, *args, **kwargs) def null(): """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement.""" @@ -1040,22 +1040,20 @@ class ClauseElement(object): """ try: - if self._engine is not None: - return self._engine + if self._bind is not None: + return self._bind except AttributeError: pass for f in self._get_from_objects(): if f is self: continue - engine = f.engine + engine = f.bind if engine is not None: return engine else: return None - - engine = property(lambda s: s._find_engine(), - doc="""Attempts to locate a Engine within this ClauseElement - structure, or returns None if none found.""") + + bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""") def execute(self, *multiparams, **params): """Compile and execute this ``ClauseElement``.""" @@ -1064,7 +1062,7 @@ class ClauseElement(object): compile_params = multiparams[0] else: compile_params = params - return self.compile(engine=self.engine, parameters=compile_params).execute(*multiparams, **params) + return self.compile(bind=self.bind, parameters=compile_params).execute(*multiparams, **params) def scalar(self, *multiparams, **params): """Compile and execute this ``ClauseElement``, returning the @@ -1073,7 +1071,7 @@ class ClauseElement(object): return self.execute(*multiparams, **params).scalar() - def compile(self, engine=None, parameters=None, compiler=None, dialect=None): + def compile(self, bind=None, parameters=None, compiler=None, dialect=None): """Compile this SQL expression. Uses the given ``Compiler``, or the given ``AbstractDialect`` @@ -1102,10 +1100,10 @@ class ClauseElement(object): if compiler is None: if dialect is not None: compiler = dialect.compiler(self, parameters) - elif engine is not None: - compiler = engine.compiler(self, parameters) - elif self.engine is not None: - compiler = self.engine.compiler(self, parameters) + elif bind is not None: + compiler = bind.compiler(self, parameters) + elif self.bind is not None: + compiler = self.bind.compiler(self, parameters) if compiler is None: import sqlalchemy.ansisql as ansisql @@ -1473,6 +1471,25 @@ class ColumnCollection(util.OrderedProperties): # "True" value (i.e. a BinaryClause...) return col in util.Set(self) +class ColumnSet(util.OrderedSet): + def contains_column(self, col): + return col in self + + def extend(self, cols): + for col in cols: + self.add(col) + + def __add__(self, other): + return list(self) + list(other) + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement. @@ -1616,7 +1633,7 @@ class FromClause(Selectable): """) oid_column = property(_get_oid_column) - def _export_columns(self): + def _export_columns(self, columns=None): """Initialize column collections. The collections include the primary key, foreign keys, list of @@ -1629,14 +1646,17 @@ class FromClause(Selectable): its parent ``Selectable`` is this ``FromClause``. """ - if hasattr(self, '_columns'): + if hasattr(self, '_columns') and columns is None: # TODO: put a mutex here ? this is a key place for threading probs return self._columns = ColumnCollection() - self._primary_key = ColumnCollection() + self._primary_key = ColumnSet() self._foreign_keys = util.Set() self._orig_cols = {} - for co in self._flatten_exportable_columns(): + + if columns is None: + columns = self._flatten_exportable_columns() + for co in columns: cp = self._proxy_column(co) for ci in cp.orig_set: cx = self._orig_cols.get(ci) @@ -1756,8 +1776,8 @@ class _TextClause(ClauseElement): __visit_name__ = 'textclause' - def __init__(self, text = "", engine=None, bindparams=None, typemap=None): - self._engine = engine + def __init__(self, text = "", bind=None, bindparams=None, typemap=None): + self._bind = bind self.bindparams = {} self.typemap = typemap if typemap is not None: @@ -1883,7 +1903,7 @@ class _CalculatedClause(ColumnElement): def __init__(self, name, *clauses, **kwargs): self.name = name self.type = sqltypes.to_instance(kwargs.get('type', None)) - self._engine = kwargs.get('engine', None) + self._bind = kwargs.get('bind', None) self.group = kwargs.pop('group', True) self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) if self.group: @@ -1928,7 +1948,7 @@ class _Function(_CalculatedClause, FromClause): self.type = sqltypes.to_instance(kwargs.get('type', None)) self.packagenames = kwargs.get('packagenames', None) or [] kwargs['operator'] = ',' - self._engine = kwargs.get('engine', None) + self._bind = kwargs.get('bind', None) _CalculatedClause.__init__(self, name, **kwargs) for c in clauses: self.append(c) @@ -2091,15 +2111,38 @@ class Join(FromClause): encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace')) def _init_primary_key(self): - pkcol = util.OrderedSet() - for col in self._flatten_exportable_columns(): - if col.primary_key: - pkcol.add(col) - for col in list(pkcol): - for f in col.foreign_keys: - if f.column in pkcol: - pkcol.remove(col) - self.primary_key.extend(pkcol) + pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key]) + + equivs = {} + def add_equiv(a, b): + for x, y in ((a, b), (b, a)): + if x in equivs: + equivs[x].add(y) + else: + equivs[x] = util.Set([y]) + + class BinaryVisitor(ClauseVisitor): + def visit_binary(self, binary): + if binary.operator == '=': + add_equiv(binary.left, binary.right) + BinaryVisitor().traverse(self.onclause) + + for col in pkcol: + for fk in col.foreign_keys: + if fk.column in pkcol: + add_equiv(col, fk.column) + + omit = util.Set() + for col in pkcol: + p = col + for c in equivs.get(col, util.Set()): + if p.references(c) or (c.primary_key and not p.primary_key): + omit.add(p) + p = c + + self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit]) + + primary_key = property(lambda s:s.__primary_key) def _locate_oid_column(self): return self.left.oid_column @@ -2185,7 +2228,11 @@ class Join(FromClause): collist.append(c) self.__folded_equivalents = collist return self.__folded_equivalents - + + folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, " + "equated columns folded into one column, where 'equated' means they are " + "equated to each other in the ON clause of this join.") + def select(self, whereclause = None, fold_equivalents=False, **kwargs): """Create a ``Select`` from this ``Join``. @@ -2205,13 +2252,13 @@ class Join(FromClause): """ if fold_equivalents: - collist = self._get_folded_equivalents() + collist = self.folded_equivalents else: collist = [self.left, self.right] return select(collist, whereclause, from_obj=[self], **kwargs) - engine = property(lambda s:s.left.engine or s.right.engine) + bind = property(lambda s:s.left.bind or s.right.bind) def alias(self, name=None): """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it. @@ -2299,7 +2346,7 @@ class Alias(FromClause): def _group_parenthesized(self): return False - engine = property(lambda s: s.selectable.engine) + bind = property(lambda s: s.selectable.bind) class _Grouping(ColumnElement): def __init__(self, elem): @@ -2492,12 +2539,8 @@ class TableClause(FromClause): super(TableClause, self).__init__(name) self.name = self.fullname = name self.encodedname = self.name.encode('ascii', 'backslashreplace') - self._columns = ColumnCollection() - self._foreign_keys = util.OrderedSet() - self._primary_key = ColumnCollection() - for c in columns: - self.append_column(c) self._oid_column = _ColumnClause('oid', self, _is_oid=True) + self._export_columns(columns) def _clone(self): # TableClause is immutable @@ -2513,6 +2556,10 @@ class TableClause(FromClause): def _locate_oid_column(self): return self._oid_column + def _proxy_column(self, c): + self.append_column(c) + return c + def _orig_columns(self): try: return self._orig_cols @@ -2530,7 +2577,7 @@ class TableClause(FromClause): return [c for c in self.c] else: return [] - + def _exportable_columns(self): raise NotImplementedError() @@ -2571,12 +2618,12 @@ class TableClause(FromClause): class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" - def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, connectable=None, scalar=False, engine=None): + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, scalar=False): self.use_labels = use_labels self.for_update = for_update self._limit = limit self._offset = offset - self._engine = connectable or engine + self._bind = bind self.is_scalar = scalar if self.is_scalar: # allow corresponding_column to return None @@ -3001,14 +3048,14 @@ class Select(_SelectBaseMixin, FromClause): object, or searched within the from clauses for one. """ - if self._engine is not None: - return self._engine + if self._bind is not None: + return self._bind for f in self._froms: if f is self: continue - e = f.engine + e = f.bind if e is not None: - self._engine = e + self._bind = e return e # look through the columns (largely synomous with looking # through the FROMs except in the case of _CalculatedClause/_Function) @@ -3016,9 +3063,9 @@ class Select(_SelectBaseMixin, FromClause): for c in cc.columns: if getattr(c, 'table', None) is self: continue - e = c.engine + e = c.bind if e is not None: - self._engine = e + self._bind = e return e return None @@ -3078,7 +3125,7 @@ class _UpdateBase(ClauseElement): return parameters def _find_engine(self): - return self.table.engine + return self.table.bind class Insert(_UpdateBase): def __init__(self, table, values=None): diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 028baf8aaa..e711de3a3e 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -10,6 +10,7 @@ except ImportError: import dummy_thread as thread import dummy_threading as threading +from sqlalchemy import exceptions import md5 import sys import warnings @@ -159,6 +160,15 @@ def duck_type_collection(specimen, default=None): else: return default +def assert_arg_type(arg, argtype, name): + if isinstance(arg, argtype): + return arg + else: + if isinstance(argtype, tuple): + raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg)))) + else: + raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg)))) + def warn_exception(func, *args, **kwargs): """executes the given function, catches all exceptions and converts to a warning.""" try: diff --git a/test/engine/alltests.py b/test/engine/alltests.py index 722f06256d..a34a82ed75 100644 --- a/test/engine/alltests.py +++ b/test/engine/alltests.py @@ -7,6 +7,7 @@ def suite(): # connectivity, execution 'engine.parseconnect', 'engine.pool', + 'engine.bind', 'engine.reconnect', 'engine.execute', 'engine.metadata', diff --git a/test/engine/bind.py b/test/engine/bind.py new file mode 100644 index 0000000000..b928d3dd86 --- /dev/null +++ b/test/engine/bind.py @@ -0,0 +1,167 @@ +"""tests the "bind" attribute/argument across schema, SQL, and ORM sessions, +including the deprecated versions of these arguments""" + +import testbase +import unittest, sys, datetime +import tables +db = testbase.db +from sqlalchemy import * + +class BindTest(testbase.PersistTest): + def test_create_drop_explicit(self): + metadata = MetaData() + table = Table('test_table', metadata, + Column('foo', Integer)) + for bind in ( + testbase.db, + testbase.db.connect() + ): + for args in [ + ([], {'bind':bind}), + ([bind], {}) + ]: + metadata.create_all(*args[0], **args[1]) + assert table.exists(*args[0], **args[1]) + metadata.drop_all(*args[0], **args[1]) + table.create(*args[0], **args[1]) + table.drop(*args[0], **args[1]) + assert not table.exists(*args[0], **args[1]) + + def test_create_drop_err(self): + metadata = MetaData() + table = Table('test_table', metadata, + Column('foo', Integer)) + + for meth in [ + metadata.create_all, + table.exists, + metadata.drop_all, + table.create, + table.drop, + ]: + try: + meth() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "This SchemaItem is not connected to any Engine or Connection." + + def test_create_drop_bound(self): + + for meta in (MetaData,ThreadLocalMetaData): + for bind in ( + testbase.db, + testbase.db.connect() + ): + metadata = meta() + table = Table('test_table', metadata, + Column('foo', Integer)) + metadata.bind = bind + assert metadata.bind is table.bind is bind + metadata.create_all() + assert table.exists() + metadata.drop_all() + table.create() + table.drop() + assert not table.exists() + + metadata = meta() + table = Table('test_table', metadata, + Column('foo', Integer)) + + metadata.connect(bind) + assert metadata.bind is table.bind is bind + metadata.create_all() + assert table.exists() + metadata.drop_all() + table.create() + table.drop() + assert not table.exists() + + def test_create_drop_constructor_bound(self): + for bind in ( + testbase.db, + testbase.db.connect() + ): + for args in ( + ([bind], {}), + ([], {'bind':bind}), + ): + metadata = MetaData(*args[0], **args[1]) + table = Table('test_table', metadata, + Column('foo', Integer)) + + assert metadata.bind is table.bind is bind + metadata.create_all() + assert table.exists() + metadata.drop_all() + table.create() + table.drop() + assert not table.exists() + + + def test_clauseelement(self): + metadata = MetaData() + table = Table('test_table', metadata, + Column('foo', Integer)) + metadata.create_all(bind=testbase.db) + try: + for elem in [ + table.select, + lambda **kwargs:func.current_timestamp(**kwargs).select(), +# func.current_timestamp().select, + lambda **kwargs:text("select * from test_table", **kwargs) + ]: + for bind in ( + testbase.db, + testbase.db.connect() + ): + e = elem(bind=bind) + assert e.bind is bind + e.execute() + + try: + e = elem() + assert e.bind is None + e.execute() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "This Compiled object is not bound to any Engine or Connection." + + finally: + metadata.drop_all(bind=testbase.db) + + def test_session(self): + from sqlalchemy.orm import create_session, mapper + metadata = MetaData() + table = Table('test_table', metadata, + Column('foo', Integer, primary_key=True), + Column('data', String(30))) + class Foo(object): + pass + mapper(Foo, table) + metadata.create_all(bind=testbase.db) + try: + for bind in (testbase.db, testbase.db.connect()): + for args in ({'bind':bind},): + sess = create_session(**args) + assert sess.bind is bind + f = Foo() + sess.save(f) + sess.flush() + assert sess.get(Foo, f.foo) is f + + sess = create_session() + f = Foo() + sess.save(f) + try: + sess.flush() + assert False + except exceptions.InvalidRequestError, e: + assert str(e).startswith("Could not locate any Engine or Connection bound to mapper") + + finally: + metadata.drop_all(bind=testbase.db) + + +if __name__ == '__main__': + testbase.main() \ No newline at end of file diff --git a/test/engine/metadata.py b/test/engine/metadata.py index c3c6441ef1..28b0535a57 100644 --- a/test/engine/metadata.py +++ b/test/engine/metadata.py @@ -7,7 +7,7 @@ class MetaDataTest(testbase.PersistTest): metadata = MetaData() t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True), Column('col2', String(20))) - metadata.engine = testbase.db + metadata.bind = testbase.db metadata.create_all() try: assert t1.count().scalar() == 0 diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 672d1bcd7c..842be682d6 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -237,7 +237,7 @@ class ReflectionTest(PersistTest): PRIMARY KEY(id) )""") try: - metadata = MetaData(engine=testbase.db) + metadata = MetaData(bind=testbase.db) book = Table('book', metadata, autoload=True) assert book.c.id in book.primary_key assert book.c.series not in book.primary_key @@ -258,7 +258,7 @@ class ReflectionTest(PersistTest): PRIMARY KEY(id, isbn) )""") try: - metadata = MetaData(engine=testbase.db) + metadata = MetaData(bind=testbase.db) book = Table('book', metadata, autoload=True) assert book.c.id in book.primary_key assert book.c.isbn in book.primary_key @@ -363,17 +363,17 @@ class ReflectionTest(PersistTest): def test_pickle(): meta.connect(testbase.db) meta2 = pickle.loads(pickle.dumps(meta)) - assert meta2.engine is None + assert meta2.bind is None return (meta2.tables['mytable'], meta2.tables['othertable']) def test_pickle_via_reflect(): # this is the most common use case, pickling the results of a # database reflection - meta2 = MetaData(engine=testbase.db) + meta2 = MetaData(bind=testbase.db) t1 = Table('mytable', meta2, autoload=True) t2 = Table('othertable', meta2, autoload=True) meta3 = pickle.loads(pickle.dumps(meta2)) - assert meta3.engine is None + assert meta3.bind is None assert meta3.tables['mytable'] is not t1 return (meta3.tables['mytable'], meta3.tables['othertable']) diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index 1437cde1fb..3f61ec3691 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -170,7 +170,7 @@ class FlushTest(testbase.ORMTest): ) admins = Table('admin', metadata, - Column('id', Integer, primary_key=True), + Column('admin_id', Integer, primary_key=True), Column('user_id', Integer, ForeignKey('users.id')) ) @@ -237,6 +237,86 @@ class FlushTest(testbase.ORMTest): a.password = 'sadmin' sess.flush() assert user_roles.count().scalar() == 1 - + +class DistinctPKTest(testbase.ORMTest): + """test the construction of mapper.primary_key when an inheriting relationship + joins on a column other than primary key column.""" + keep_data = True + + def define_tables(self, metadata): + global person_table, employee_table, Person, Employee + + person_table = Table("persons", metadata, + Column("id", Integer, primary_key=True), + Column("name", String(80)), + ) + + employee_table = Table("employees", metadata, + Column("id", Integer, primary_key=True), + Column("salary", Integer), + Column("person_id", Integer, ForeignKey("persons.id")), + ) + + class Person(object): + def __init__(self, name): + self.name = name + + class Employee(Person): pass + + def insert_data(self): + person_insert = person_table.insert() + person_insert.execute(id=1, name='alice') + person_insert.execute(id=2, name='bob') + + employee_insert = employee_table.insert() + employee_insert.execute(id=2, salary=250, person_id=1) # alice + employee_insert.execute(id=3, salary=200, person_id=2) # bob + + def test_implicit(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper) + try: + print class_mapper(Employee).primary_key + assert list(class_mapper(Employee).primary_key) == [person_table.c.id, employee_table.c.id] + assert False + except RuntimeWarning, e: + assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name." + + def test_explicit_props(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id}) + self._do_test(True) + + def test_explicit_composite_pk(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id]) + try: + self._do_test(True) + assert False + except RuntimeWarning, e: + assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name." + + def test_explicit_pk(self): + person_mapper = mapper(Person, person_table) + mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id]) + self._do_test(False) + + def _do_test(self, composite): + session = create_session() + query = session.query(Employee) + + if composite: + alice1 = query.get([1,2]) + bob = query.get([2,3]) + alice2 = query.get([1,2]) + else: + alice1 = query.get(1) + bob = query.get(2) + alice2 = query.get(1) + + assert alice1.name == alice2.name == 'alice' + assert bob.name == 'bob' + + if __name__ == "__main__": testbase.main() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 26a0b80a44..eb0d110a16 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -150,6 +150,9 @@ class MapperTest(MapperSuperTest): def bad_expunge(foo): raise Exception("this exception should be stated as a warning") + import warnings + warnings.filterwarnings("always", r".*this exception should be stated as a warning") + sess.expunge = bad_expunge try: Foo(_sa_session=sess) @@ -660,7 +663,7 @@ class DeferredTest(MapperSuperTest): o2 = l[2] print o2.description - orderby = str(orders.default_order_by()[0].compile(engine=db)) + orderby = str(orders.default_order_by()[0].compile(bind=db)) self.assert_sql(db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3}) diff --git a/test/orm/query.py b/test/orm/query.py index f52d90c5d8..c0b0c8f84b 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -54,7 +54,7 @@ class QueryTest(testbase.ORMTest): def define_tables(self, meta): # a slight dirty trick here. meta.tables = metadata.tables - metadata.connect(meta.engine) + metadata.connect(meta.bind) def setup_mappers(self): mapper(User, users, properties={ diff --git a/test/orm/relationships.py b/test/orm/relationships.py index d95b7b4adf..80fe147275 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -561,19 +561,19 @@ class TypeMatchTest(testbase.ORMTest): def define_tables(self, metadata): global a, b, c, d a = Table("a", metadata, - Column('id', Integer, primary_key=True), + Column('aid', Integer, primary_key=True), Column('data', String(30))) b = Table("b", metadata, - Column('id', Integer, primary_key=True), - Column("a_id", Integer, ForeignKey("a.id")), + Column('bid', Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) c = Table("c", metadata, - Column('id', Integer, primary_key=True), - Column("b_id", Integer, ForeignKey("b.id")), + Column('cid', Integer, primary_key=True), + Column("b_id", Integer, ForeignKey("b.bid")), Column('data', String(30))) d = Table("d", metadata, - Column('id', Integer, primary_key=True), - Column("a_id", Integer, ForeignKey("a.id")), + Column('did', Integer, primary_key=True), + Column("a_id", Integer, ForeignKey("a.aid")), Column('data', String(30))) def test_o2m_oncascade(self): class A(object):pass diff --git a/test/sql/defaults.py b/test/sql/defaults.py index 2eeaef7cc1..07363a402e 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -25,26 +25,26 @@ class DefaultTest(PersistTest): # select "count(1)" returns different results on different DBs # also correct for "current_date" compatible as column default, value differences - currenttime = func.current_date(type=Date, engine=db); + currenttime = func.current_date(type=Date, bind=db); if is_oracle: ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar() - f = select([func.count(1) + 5], engine=db).scalar() - f2 = select([func.count(1) + 14], engine=db).scalar() + f = select([func.count(1) + 5], bind=db).scalar() + f2 = select([func.count(1) + 14], bind=db).scalar() # TODO: engine propigation across nested functions not working - currenttime = func.trunc(currenttime, literal_column("'DAY'"), engine=db) + currenttime = func.trunc(currenttime, literal_column("'DAY'"), bind=db) def1 = currenttime def2 = func.trunc(text("sysdate"), literal_column("'DAY'")) deftype = Date elif use_function_defaults: - f = select([func.count(1) + 5], engine=db).scalar() - f2 = select([func.count(1) + 14], engine=db).scalar() + f = select([func.count(1) + 5], bind=db).scalar() + f2 = select([func.count(1) + 14], bind=db).scalar() def1 = currenttime def2 = text("current_date") deftype = Date ts = db.func.current_date().scalar() else: - f = select([func.count(1) + 5], engine=db).scalar() - f2 = select([func.count(1) + 14], engine=db).scalar() + f = select([func.count(1) + 5], bind=db).scalar() + f2 = select([func.count(1) + 14], bind=db).scalar() def1 = def2 = "3" ts = 3 deftype = Integer @@ -257,7 +257,7 @@ class SequenceTest(PersistTest): @testbase.supported('postgres', 'oracle') def test_implicit_sequence_exec(self): - s = Sequence("my_sequence", metadata=testbase.db) + s = Sequence("my_sequence", metadata=MetaData(testbase.db)) s.create() try: x = s.execute() @@ -266,9 +266,9 @@ class SequenceTest(PersistTest): s.drop() @testbase.supported('postgres', 'oracle') - def test_explicit_sequence_exec(self): + def teststandalone_explicit(self): s = Sequence("my_sequence") - s.create(testbase.db) + s.create(bind=testbase.db) try: x = s.execute(testbase.db) self.assert_(x == 1) diff --git a/test/sql/labels.py b/test/sql/labels.py index 968e75dfc3..384fead50b 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -17,7 +17,7 @@ class LabelTypeTest(testbase.PersistTest): class LongLabelsTest(testbase.PersistTest): def setUpAll(self): global metadata, table1 - metadata = MetaData(engine=testbase.db) + metadata = MetaData(testbase.db) table1 = Table("some_large_named_table", metadata, Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True), Column("this_is_the_data_column", String(30)) diff --git a/test/sql/query.py b/test/sql/query.py index 76e07881b0..77cf91d434 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -235,7 +235,7 @@ class QueryTest(PersistTest): self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') - r = text("select * from query_users where user_id=2", engine=testbase.db).execute().fetchone() + r = text("select * from query_users where user_id=2", bind=testbase.db).execute().fetchone() self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack') diff --git a/test/sql/selectable.py b/test/sql/selectable.py index 50b2fa6b4a..bcf70bd206 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -1,176 +1,231 @@ -"""tests that various From objects properly export their columns, as well as --useable primary keys and foreign keys. Full relational algebra depends on --every selectable unit behaving nicely with others..""" - -import testbase -import unittest, sys, datetime -from sqlalchemy import * -from testbase import Table, Column - -db = testbase.db -metadata = MetaData(db) - - -table = Table('table1', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', String(20)), - Column('col3', Integer), - Column('colx', Integer), - -) - -table2 = Table('table2', metadata, - Column('col1', Integer, primary_key=True), - Column('col2', Integer, ForeignKey('table1.col1')), - Column('col3', String(20)), - Column('coly', Integer), -) - -class SelectableTest(testbase.AssertMixin): - def testdistance(self): - s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) - - # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far - #assert s.corresponding_column(table.c.col1) is s.c.col1 - assert s.corresponding_column(s.c.col1) is s.c.col1 - assert s.corresponding_column(s.c.c1) is s.c.c1 - - def testjoinagainstself(self): - jj = select([table.c.col1.label('bar_col1')]) - jjj = join(table, jj, table.c.col1==jj.c.bar_col1) - - # test column directly agaisnt itself - assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 - - assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1 - - # test alias of the join, targets the column with the least - # "distance" between the requested column and the returned column - # (i.e. there is less indirection between j2.c.table1_col1 and table.c.col1, than - # there is from j2.c.bar_col1 to table.c.col1) - j2 = jjj.alias('foo') - assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1 - - - def testjoinagainstjoin(self): - j = outerjoin(table, table2, table.c.col1==table2.c.col2) - jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') - jjj = join(table, jj, table.c.col1==jj.c.bar_col1) - assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 - - j2 = jjj.alias('foo') - print j2.corresponding_column(jjj.c.table1_col1) - assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1 - - assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 - - def testtablealias(self): - a = table.alias('a') - - j = join(a, table2) - - criterion = a.c.col1 == table2.c.col2 - print - print str(j) - self.assert_(criterion.compare(j.onclause)) - - def testunion(self): - # tests that we can correspond a column in a Select statement with a certain Table, against - # a column in a Union where one of its underlying Selects matches to that same Table - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ) - s1 = table.select(use_labels=True) - s2 = table2.select(use_labels=True) - print ["%d %s" % (id(c),c.key) for c in u.c] - c = u.corresponding_column(s1.c.table1_col2) - print "%d %s" % (id(c), c.key) - print id(u.corresponding_column(s1.c.table1_col2).table) - print id(u.c.col2.table) - assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 - assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 - - def testaliasunion(self): - # same as testunion, except its an alias of the union - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ).alias('analias') - s1 = table.select(use_labels=True) - s2 = table2.select(use_labels=True) - assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 - assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 - assert u.corresponding_column(s2.c.table2_coly) is u.c.coly - assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly - - def testselectunion(self): - # like testaliasunion, but off a Select off the union. - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ).alias('analias') - s = select([u]) - s1 = table.select(use_labels=True) - s2 = table2.select(use_labels=True) - assert s.corresponding_column(s1.c.table1_col2) is s.c.col2 - assert s.corresponding_column(s2.c.table2_col2) is s.c.col2 - - def testunionagainstjoin(self): - # same as testunion, except its an alias of the union - u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( - select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) - ).alias('analias') - j1 = table.join(table2) - assert u.corresponding_column(j1.c.table1_colx) is u.c.colx - assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx - - def testjoin(self): - a = join(table, table2) - print str(a.select(use_labels=True)) - b = table2.alias('b') - j = join(a, b) - print str(j) - criterion = a.c.table1_col1 == b.c.col2 - self.assert_(criterion.compare(j.onclause)) - - def testselectalias(self): - a = table.select().alias('a') - print str(a.select()) - j = join(a, table2) - - criterion = a.c.col1 == table2.c.col2 - print criterion - print j.onclause - self.assert_(criterion.compare(j.onclause)) - - def testselectlabels(self): - a = table.select(use_labels=True) - print str(a.select()) - j = join(a, table2) - - criterion = a.c.table1_col1 == table2.c.col2 - print - print str(j) - self.assert_(criterion.compare(j.onclause)) - - def testcolumnlabels(self): - a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')]) - print str(a) - print [c for c in a.columns] - print str(a.select()) - j = join(a, table2) - criterion = a.c.acol1 == table2.c.col2 - print str(j) - self.assert_(criterion.compare(j.onclause)) - - def testselectaliaslabels(self): - a = table2.select(use_labels=True).alias('a') - print str(a.select()) - j = join(a, table) - - criterion = table.c.col1 == a.c.table2_col2 - print str(criterion) - print str(j.onclause) - self.assert_(criterion.compare(j.onclause)) - -if __name__ == "__main__": - testbase.main() - +"""tests that various From objects properly export their columns, as well as +useable primary keys and foreign keys. Full relational algebra depends on +every selectable unit behaving nicely with others..""" + +import testbase +import unittest, sys, datetime +from sqlalchemy import * +from testbase import Table, Column + +db = testbase.db +metadata = MetaData(db) + + +table = Table('table1', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', String(20)), + Column('col3', Integer), + Column('colx', Integer), + +) + +table2 = Table('table2', metadata, + Column('col1', Integer, primary_key=True), + Column('col2', Integer, ForeignKey('table1.col1')), + Column('col3', String(20)), + Column('coly', Integer), +) + +class SelectableTest(testbase.AssertMixin): + def testdistance(self): + s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')]) + + # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far + #assert s.corresponding_column(table.c.col1) is s.c.col1 + assert s.corresponding_column(s.c.col1) is s.c.col1 + assert s.corresponding_column(s.c.c1) is s.c.c1 + + def testjoinagainstself(self): + jj = select([table.c.col1.label('bar_col1')]) + jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + + # test column directly agaisnt itself + assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + + assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1 + + # test alias of the join, targets the column with the least + # "distance" between the requested column and the returned column + # (i.e. there is less indirection between j2.c.table1_col1 and table.c.col1, than + # there is from j2.c.bar_col1 to table.c.col1) + j2 = jjj.alias('foo') + assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1 + + + def testjoinagainstjoin(self): + j = outerjoin(table, table2, table.c.col1==table2.c.col2) + jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo') + jjj = join(table, jj, table.c.col1==jj.c.bar_col1) + assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 + + j2 = jjj.alias('foo') + print j2.corresponding_column(jjj.c.table1_col1) + assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1 + + assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 + + def testtablealias(self): + a = table.alias('a') + + j = join(a, table2) + + criterion = a.c.col1 == table2.c.col2 + print + print str(j) + self.assert_(criterion.compare(j.onclause)) + + def testunion(self): + # tests that we can correspond a column in a Select statement with a certain Table, against + # a column in a Union where one of its underlying Selects matches to that same Table + u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ) + s1 = table.select(use_labels=True) + s2 = table2.select(use_labels=True) + print ["%d %s" % (id(c),c.key) for c in u.c] + c = u.corresponding_column(s1.c.table1_col2) + print "%d %s" % (id(c), c.key) + print id(u.corresponding_column(s1.c.table1_col2).table) + print id(u.c.col2.table) + assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 + assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 + + def testaliasunion(self): + # same as testunion, except its an alias of the union + u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ).alias('analias') + s1 = table.select(use_labels=True) + s2 = table2.select(use_labels=True) + assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 + assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 + assert u.corresponding_column(s2.c.table2_coly) is u.c.coly + assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly + + def testselectunion(self): + # like testaliasunion, but off a Select off the union. + u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ).alias('analias') + s = select([u]) + s1 = table.select(use_labels=True) + s2 = table2.select(use_labels=True) + assert s.corresponding_column(s1.c.table1_col2) is s.c.col2 + assert s.corresponding_column(s2.c.table2_col2) is s.c.col2 + + def testunionagainstjoin(self): + # same as testunion, except its an alias of the union + u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union( + select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly]) + ).alias('analias') + j1 = table.join(table2) + assert u.corresponding_column(j1.c.table1_colx) is u.c.colx + assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx + + def testjoin(self): + a = join(table, table2) + print str(a.select(use_labels=True)) + b = table2.alias('b') + j = join(a, b) + print str(j) + criterion = a.c.table1_col1 == b.c.col2 + self.assert_(criterion.compare(j.onclause)) + + def testselectalias(self): + a = table.select().alias('a') + print str(a.select()) + j = join(a, table2) + + criterion = a.c.col1 == table2.c.col2 + print criterion + print j.onclause + self.assert_(criterion.compare(j.onclause)) + + def testselectlabels(self): + a = table.select(use_labels=True) + print str(a.select()) + j = join(a, table2) + + criterion = a.c.table1_col1 == table2.c.col2 + print + print str(j) + self.assert_(criterion.compare(j.onclause)) + + def testcolumnlabels(self): + a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')]) + print str(a) + print [c for c in a.columns] + print str(a.select()) + j = join(a, table2) + criterion = a.c.acol1 == table2.c.col2 + print str(j) + self.assert_(criterion.compare(j.onclause)) + + def testselectaliaslabels(self): + a = table2.select(use_labels=True).alias('a') + print str(a.select()) + j = join(a, table) + + criterion = table.c.col1 == a.c.table2_col2 + print str(criterion) + print str(j.onclause) + self.assert_(criterion.compare(j.onclause)) + + +class PrimaryKeyTest(testbase.AssertMixin): + def test_join_pk_collapse_implicit(self): + """test that redundant columns in a join get 'collapsed' into a minimal primary key, + which is the root column along a chain of foreign key relationships.""" + + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True)) + c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True)) + d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True)) + + assert c.c.id.references(b.c.id) + assert not d.c.id.references(a.c.id) + + assert list(a.join(b).primary_key) == [a.c.id] + assert list(b.join(c).primary_key) == [b.c.id] + assert list(a.join(b).join(c).primary_key) == [a.c.id] + assert list(b.join(c).join(d).primary_key) == [b.c.id] + assert list(d.join(c).join(b).primary_key) == [b.c.id] + assert list(a.join(b).join(c).join(d).primary_key) == [a.c.id] + + def test_join_pk_collapse_explicit(self): + """test that redundant columns in a join get 'collapsed' into a minimal primary key, + which is the root column along a chain of explicit join conditions.""" + + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer)) + c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True), Column('x', Integer)) + d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True), Column('x', Integer)) + + print list(a.join(b, a.c.x==b.c.id).primary_key) + assert list(a.join(b, a.c.x==b.c.id).primary_key) == [b.c.id] + assert list(b.join(c, b.c.x==c.c.id).primary_key) == [b.c.id] + assert list(a.join(b).join(c, c.c.id==b.c.x).primary_key) == [a.c.id] + assert list(b.join(c, c.c.x==b.c.id).join(d).primary_key) == [c.c.id] + assert list(b.join(c, c.c.id==b.c.x).join(d).primary_key) == [b.c.id] + assert list(d.join(b, d.c.id==b.c.id).join(c, b.c.id==c.c.x).primary_key) == [c.c.id] + assert list(a.join(b).join(c, c.c.id==b.c.x).join(d).primary_key) == [a.c.id] + + assert list(a.join(b, and_(a.c.id==b.c.id, a.c.x==b.c.id)).primary_key) == [a.c.id] + + def test_init_doesnt_blowitaway(self): + meta = MetaData() + a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer)) + b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer)) + + j = a.join(b) + assert list(j.primary_key) == [a.c.id] + + j.foreign_keys + assert list(j.primary_key) == [a.c.id] + + +if __name__ == "__main__": + testbase.main() + diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 8d5848d1cd..24fbde3a25 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -247,7 +247,7 @@ class BinaryTest(AssertMixin): for stmt in ( binary_table.select(order_by=binary_table.c.primary_id), - text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, engine=testbase.db) + text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db) ): l = stmt.execute().fetchall() print type(stream1), type(l[0]['data']), type(l[0]['data_slice']) diff --git a/test/sql/unicode.py b/test/sql/unicode.py index c426a258c7..f885dc56ba 100644 --- a/test/sql/unicode.py +++ b/test/sql/unicode.py @@ -10,7 +10,7 @@ from testbase import Table, Column class UnicodeSchemaTest(testbase.PersistTest): def setUpAll(self): global metadata, t1, t2 - metadata = MetaData(engine=testbase.db) + metadata = MetaData(testbase.db) t1 = Table('unitable1', metadata, Column(u'méil', Integer, primary_key=True), Column(u'éXXm', Integer), diff --git a/test/testbase.py b/test/testbase.py index 41eb38ddfc..d1e901a2e5 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -10,6 +10,9 @@ import sqlalchemy from sqlalchemy import sql, schema, engine, pool, MetaData from sqlalchemy.orm import clear_mappers +import warnings +warnings.filterwarnings("error") + db = None metadata = None db_uri = None @@ -312,8 +315,11 @@ class ORMTest(AssertMixin): _otest_metadata = MetaData(db) self.define_tables(_otest_metadata) _otest_metadata.create_all() + self.insert_data() def define_tables(self, _otest_metadata): raise NotImplementedError() + def insert_data(self): + pass def get_metadata(self): return _otest_metadata def tearDownAll(self):