From: Federico Caselli Date: Thu, 3 Nov 2022 19:52:21 +0000 (+0100) Subject: Try running pyupgrade on the code X-Git-Tag: rel_2_0_0b4~50^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4eb4ceca36c7ce931ea65ac06d6ed08bf459fc66;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Try running pyupgrade on the code command run is "pyupgrade --py37-plus --keep-runtime-typing --keep-percent-format " pyupgrade will change assert_ to assertTrue. That was reverted since assertTrue does not exists in sqlalchemy fixtures Change-Id: Ie1ed2675c7b11d893d78e028aad0d1576baebb55 --- diff --git a/examples/association/dict_of_sets_with_default.py b/examples/association/dict_of_sets_with_default.py index 96e30c1e28..f515ab975b 100644 --- a/examples/association/dict_of_sets_with_default.py +++ b/examples/association/dict_of_sets_with_default.py @@ -87,7 +87,7 @@ if __name__ == "__main__": # only "A" is referenced explicitly. Using "collections", # we deal with a dict of key/sets of integers directly. - session.add_all([A(collections={"1": set([1, 2, 3])})]) + session.add_all([A(collections={"1": {1, 2, 3}})]) session.commit() a1 = session.query(A).first() diff --git a/examples/materialized_paths/materialized_paths.py b/examples/materialized_paths/materialized_paths.py index ad2a4f4a9a..f458270c72 100644 --- a/examples/materialized_paths/materialized_paths.py +++ b/examples/materialized_paths/materialized_paths.py @@ -86,7 +86,7 @@ class Node(Base): return len(self.path.split(".")) - 1 def __repr__(self): - return "Node(id={})".format(self.id) + return f"Node(id={self.id})" def __str__(self): root_depth = self.depth diff --git a/examples/versioned_rows/versioned_update_old_row.py b/examples/versioned_rows/versioned_update_old_row.py index e1e25704c9..41d3046a74 100644 --- a/examples/versioned_rows/versioned_update_old_row.py +++ b/examples/versioned_rows/versioned_update_old_row.py @@ -45,7 +45,7 @@ class VersionedStartEnd: # reduce some verbosity when we make a new object kw.setdefault("start", current_time() - datetime.timedelta(days=3)) kw.setdefault("end", current_time() + datetime.timedelta(days=3)) - super(VersionedStartEnd, self).__init__(**kw) + super().__init__(**kw) def new_version(self, session): diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 458663bb77..f920860cd8 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -51,7 +51,7 @@ class PyODBCConnector(Connector): dbapi: ModuleType def __init__(self, use_setinputsizes: bool = False, **kw: Any): - super(PyODBCConnector, self).__init__(**kw) + super().__init__(**kw) if use_setinputsizes: self.bind_typing = interfaces.BindTyping.SETINPUTSIZES @@ -83,7 +83,7 @@ class PyODBCConnector(Connector): token = "{%s}" % token.replace("}", "}}") return token - keys = dict((k, check_quote(v)) for k, v in keys.items()) + keys = {k: check_quote(v) for k, v in keys.items()} dsn_connection = "dsn" in keys or ( "host" in keys and "database" not in keys diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 53fe96c9ae..a0049c361e 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -965,190 +965,188 @@ MS_2008_VERSION = (10,) MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) -RESERVED_WORDS = set( - [ - "add", - "all", - "alter", - "and", - "any", - "as", - "asc", - "authorization", - "backup", - "begin", - "between", - "break", - "browse", - "bulk", - "by", - "cascade", - "case", - "check", - "checkpoint", - "close", - "clustered", - "coalesce", - "collate", - "column", - "commit", - "compute", - "constraint", - "contains", - "containstable", - "continue", - "convert", - "create", - "cross", - "current", - "current_date", - "current_time", - "current_timestamp", - "current_user", - "cursor", - "database", - "dbcc", - "deallocate", - "declare", - "default", - "delete", - "deny", - "desc", - "disk", - "distinct", - "distributed", - "double", - "drop", - "dump", - "else", - "end", - "errlvl", - "escape", - "except", - "exec", - "execute", - "exists", - "exit", - "external", - "fetch", - "file", - "fillfactor", - "for", - "foreign", - "freetext", - "freetexttable", - "from", - "full", - "function", - "goto", - "grant", - "group", - "having", - "holdlock", - "identity", - "identity_insert", - "identitycol", - "if", - "in", - "index", - "inner", - "insert", - "intersect", - "into", - "is", - "join", - "key", - "kill", - "left", - "like", - "lineno", - "load", - "merge", - "national", - "nocheck", - "nonclustered", - "not", - "null", - "nullif", - "of", - "off", - "offsets", - "on", - "open", - "opendatasource", - "openquery", - "openrowset", - "openxml", - "option", - "or", - "order", - "outer", - "over", - "percent", - "pivot", - "plan", - "precision", - "primary", - "print", - "proc", - "procedure", - "public", - "raiserror", - "read", - "readtext", - "reconfigure", - "references", - "replication", - "restore", - "restrict", - "return", - "revert", - "revoke", - "right", - "rollback", - "rowcount", - "rowguidcol", - "rule", - "save", - "schema", - "securityaudit", - "select", - "session_user", - "set", - "setuser", - "shutdown", - "some", - "statistics", - "system_user", - "table", - "tablesample", - "textsize", - "then", - "to", - "top", - "tran", - "transaction", - "trigger", - "truncate", - "tsequal", - "union", - "unique", - "unpivot", - "update", - "updatetext", - "use", - "user", - "values", - "varying", - "view", - "waitfor", - "when", - "where", - "while", - "with", - "writetext", - ] -) +RESERVED_WORDS = { + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", +} class REAL(sqltypes.REAL): @@ -1159,7 +1157,7 @@ class REAL(sqltypes.REAL): # it is only accepted as the word "REAL" in DDL, the numeric # precision value is not allowed to be present kw.setdefault("precision", 24) - super(REAL, self).__init__(**kw) + super().__init__(**kw) class TINYINT(sqltypes.Integer): @@ -1204,7 +1202,7 @@ class _MSDate(sqltypes.Date): class TIME(sqltypes.TIME): def __init__(self, precision=None, **kwargs): self.precision = precision - super(TIME, self).__init__() + super().__init__() __zero_date = datetime.date(1900, 1, 1) @@ -1273,7 +1271,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime): __visit_name__ = "DATETIME2" def __init__(self, precision=None, **kw): - super(DATETIME2, self).__init__(**kw) + super().__init__(**kw) self.precision = precision @@ -1281,7 +1279,7 @@ class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime): __visit_name__ = "DATETIMEOFFSET" def __init__(self, precision=None, **kw): - super(DATETIMEOFFSET, self).__init__(**kw) + super().__init__(**kw) self.precision = precision @@ -1339,7 +1337,7 @@ class TIMESTAMP(sqltypes._Binary): self.convert_int = convert_int def result_processor(self, dialect, coltype): - super_ = super(TIMESTAMP, self).result_processor(dialect, coltype) + super_ = super().result_processor(dialect, coltype) if self.convert_int: def process(value): @@ -1425,7 +1423,7 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): raise ValueError( "length must be None or 'max' when setting filestream" ) - super(VARBINARY, self).__init__(length=length) + super().__init__(length=length) class IMAGE(sqltypes.LargeBinary): @@ -1525,12 +1523,12 @@ class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): @overload def __init__( - self: "UNIQUEIDENTIFIER[_python_UUID]", as_uuid: Literal[True] = ... + self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ... ): ... @overload - def __init__(self: "UNIQUEIDENTIFIER[str]", as_uuid: Literal[False] = ...): + def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...): ... def __init__(self, as_uuid: bool = True): @@ -1972,7 +1970,7 @@ class MSExecutionContext(default.DefaultExecutionContext): and column.default.optional ): return None - return super(MSExecutionContext, self).get_insert_default(column) + return super().get_insert_default(column) class MSSQLCompiler(compiler.SQLCompiler): @@ -1990,7 +1988,7 @@ class MSSQLCompiler(compiler.SQLCompiler): def __init__(self, *args, **kwargs): self.tablealiases = {} - super(MSSQLCompiler, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def _with_legacy_schema_aliasing(fn): def decorate(self, *arg, **kw): @@ -2040,7 +2038,7 @@ class MSSQLCompiler(compiler.SQLCompiler): def get_select_precolumns(self, select, **kw): """MS-SQL puts TOP, it's version of LIMIT here""" - s = super(MSSQLCompiler, self).get_select_precolumns(select, **kw) + s = super().get_select_precolumns(select, **kw) if select._has_row_limiting_clause and self._use_top(select): # ODBC drivers and possibly others @@ -2186,20 +2184,20 @@ class MSSQLCompiler(compiler.SQLCompiler): @_with_legacy_schema_aliasing def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): if mssql_aliased is table or iscrud: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) + return super().visit_table(table, **kwargs) # alias schema-qualified tables alias = self._schema_aliased_table(table) if alias is not None: return self.process(alias, mssql_aliased=table, **kwargs) else: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) + return super().visit_table(table, **kwargs) @_with_legacy_schema_aliasing def visit_alias(self, alias, **kw): # translate for schema-qualified table aliases kw["mssql_aliased"] = alias.element - return super(MSSQLCompiler, self).visit_alias(alias, **kw) + return super().visit_alias(alias, **kw) @_with_legacy_schema_aliasing def visit_column(self, column, add_to_result_map=None, **kw): @@ -2220,9 +2218,9 @@ class MSSQLCompiler(compiler.SQLCompiler): column.type, ) - return super(MSSQLCompiler, self).visit_column(converted, **kw) + return super().visit_column(converted, **kw) - return super(MSSQLCompiler, self).visit_column( + return super().visit_column( column, add_to_result_map=add_to_result_map, **kw ) @@ -2264,7 +2262,7 @@ class MSSQLCompiler(compiler.SQLCompiler): ), **kwargs, ) - return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + return super().visit_binary(binary, **kwargs) def returning_clause( self, stmt, returning_cols, *, populate_result_map, **kw @@ -2328,9 +2326,7 @@ class MSSQLCompiler(compiler.SQLCompiler): if isinstance(column, expression.Function): return column.label(None) else: - return super(MSSQLCompiler, self).label_select_column( - select, column, asfrom - ) + return super().label_select_column(select, column, asfrom) def for_update_clause(self, select, **kw): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which @@ -2517,9 +2513,7 @@ class MSSQLStrictCompiler(MSSQLCompiler): # SQL Server wants single quotes around the date string. return "'" + str(value) + "'" else: - return super(MSSQLStrictCompiler, self).render_literal_value( - value, type_ - ) + return super().render_literal_value(value, type_) class MSDDLCompiler(compiler.DDLCompiler): @@ -2704,7 +2698,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_addextendedproperty 'MS_Description', " - "{0}, 'schema', {1}, 'table', {2}".format( + "{}, 'schema', {}, 'table', {}".format( self.sql_compiler.render_literal_value( create.element.comment, sqltypes.NVARCHAR() ), @@ -2718,7 +2712,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_dropextendedproperty 'MS_Description', 'schema', " - "{0}, 'table', {1}".format( + "{}, 'table', {}".format( self.preparer.quote_schema(schema_name), self.preparer.format_table(drop.element, use_schema=False), ) @@ -2729,7 +2723,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_addextendedproperty 'MS_Description', " - "{0}, 'schema', {1}, 'table', {2}, 'column', {3}".format( + "{}, 'schema', {}, 'table', {}, 'column', {}".format( self.sql_compiler.render_literal_value( create.element.comment, sqltypes.NVARCHAR() ), @@ -2746,7 +2740,7 @@ class MSDDLCompiler(compiler.DDLCompiler): schema_name = schema if schema else self.dialect.default_schema_name return ( "execute sp_dropextendedproperty 'MS_Description', 'schema', " - "{0}, 'table', {1}, 'column', {2}".format( + "{}, 'table', {}, 'column', {}".format( self.preparer.quote_schema(schema_name), self.preparer.format_table( drop.element.table, use_schema=False @@ -2760,9 +2754,7 @@ class MSDDLCompiler(compiler.DDLCompiler): if create.element.data_type is not None: data_type = create.element.data_type prefix = " AS %s" % self.type_compiler.process(data_type) - return super(MSDDLCompiler, self).visit_create_sequence( - create, prefix=prefix, **kw - ) + return super().visit_create_sequence(create, prefix=prefix, **kw) def visit_identity_column(self, identity, **kw): text = " IDENTITY" @@ -2777,7 +2769,7 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS def __init__(self, dialect): - super(MSIdentifierPreparer, self).__init__( + super().__init__( dialect, initial_quote="[", final_quote="]", @@ -3067,7 +3059,7 @@ class MSDialect(default.DefaultDialect): ) self.legacy_schema_aliasing = legacy_schema_aliasing - super(MSDialect, self).__init__(**opts) + super().__init__(**opts) self._json_serializer = json_serializer self._json_deserializer = json_deserializer @@ -3075,7 +3067,7 @@ class MSDialect(default.DefaultDialect): def do_savepoint(self, connection, name): # give the DBAPI a push connection.exec_driver_sql("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") - super(MSDialect, self).do_savepoint(connection, name) + super().do_savepoint(connection, name) def do_release_savepoint(self, connection, name): # SQL Server does not support RELEASE SAVEPOINT @@ -3083,7 +3075,7 @@ class MSDialect(default.DefaultDialect): def do_rollback(self, dbapi_connection): try: - super(MSDialect, self).do_rollback(dbapi_connection) + super().do_rollback(dbapi_connection) except self.dbapi.ProgrammingError as e: if self.ignore_no_transaction_on_rollback and re.match( r".*\b111214\b", str(e) @@ -3097,15 +3089,13 @@ class MSDialect(default.DefaultDialect): else: raise - _isolation_lookup = set( - [ - "SERIALIZABLE", - "READ UNCOMMITTED", - "READ COMMITTED", - "REPEATABLE READ", - "SNAPSHOT", - ] - ) + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + } def get_isolation_level_values(self, dbapi_connection): return list(self._isolation_lookup) @@ -3134,7 +3124,7 @@ class MSDialect(default.DefaultDialect): "SQL Server version." ) - view_name = "sys.{}".format(row[0]) + view_name = f"sys.{row[0]}" cursor.execute( """ @@ -3164,7 +3154,7 @@ class MSDialect(default.DefaultDialect): cursor.close() def initialize(self, connection): - super(MSDialect, self).initialize(connection) + super().initialize(connection) self._setup_version_attributes() self._setup_supports_nvarchar_max(connection) @@ -3298,7 +3288,7 @@ class MSDialect(default.DefaultDialect): connection.scalar( # U filters on user tables only. text("SELECT object_id(:table_name, 'U')"), - {"table_name": "tempdb.dbo.[{}]".format(tablename)}, + {"table_name": f"tempdb.dbo.[{tablename}]"}, ) ) else: diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 96d03a908c..5d859765c5 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -42,7 +42,7 @@ class _MSNumeric_pymssql(sqltypes.Numeric): class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): def __init__(self, dialect): - super(MSIdentifierPreparer_pymssql, self).__init__(dialect) + super().__init__(dialect) # pymssql has the very unusual behavior that it uses pyformat # yet does not require that percent signs be doubled self._double_percents = False @@ -119,9 +119,7 @@ class MSDialect_pymssql(MSDialect): dbapi_connection.autocommit(True) else: dbapi_connection.autocommit(False) - super(MSDialect_pymssql, self).set_isolation_level( - dbapi_connection, level - ) + super().set_isolation_level(dbapi_connection, level) dialect = MSDialect_pymssql diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 3b8caef3b0..07cbe3a73e 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -384,7 +384,7 @@ class _ms_numeric_pyodbc: def bind_processor(self, dialect): - super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect) + super_process = super().bind_processor(dialect) if not dialect._need_decimal_fix: return super_process @@ -570,7 +570,7 @@ class MSExecutionContext_pyodbc(MSExecutionContext): """ - super(MSExecutionContext_pyodbc, self).pre_exec() + super().pre_exec() # don't embed the scope_identity select into an # "INSERT .. DEFAULT VALUES" @@ -601,7 +601,7 @@ class MSExecutionContext_pyodbc(MSExecutionContext): self._lastrowid = int(row[0]) else: - super(MSExecutionContext_pyodbc, self).post_exec() + super().post_exec() class MSDialect_pyodbc(PyODBCConnector, MSDialect): @@ -648,9 +648,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): use_setinputsizes=True, **params, ): - super(MSDialect_pyodbc, self).__init__( - use_setinputsizes=use_setinputsizes, **params - ) + super().__init__(use_setinputsizes=use_setinputsizes, **params) self.use_scope_identity = ( self.use_scope_identity and self.dbapi @@ -674,9 +672,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): # SQL Server docs indicate this function isn't present prior to # 2008. Before we had the VARCHAR cast above, pyodbc would also # fail on this query. - return super(MSDialect_pyodbc, self)._get_server_version_info( - connection - ) + return super()._get_server_version_info(connection) else: version = [] r = re.compile(r"[.\-]") @@ -688,7 +684,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): return tuple(version) def on_connect(self): - super_ = super(MSDialect_pyodbc, self).on_connect() + super_ = super().on_connect() def on_connect(conn): if super_ is not None: @@ -723,9 +719,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): def do_executemany(self, cursor, statement, parameters, context=None): if self.fast_executemany: cursor.fast_executemany = True - super(MSDialect_pyodbc, self).do_executemany( - cursor, statement, parameters, context=context - ) + super().do_executemany(cursor, statement, parameters, context=context) def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): @@ -743,9 +737,7 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): "10054", }: return True - return super(MSDialect_pyodbc, self).is_disconnect( - e, connection, cursor - ) + return super().is_disconnect(e, connection, cursor) dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index 896c902272..79f865cf15 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -294,14 +294,12 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): return pool.AsyncAdaptedQueuePool def create_connect_args(self, url): - return super(MySQLDialect_aiomysql, self).create_connect_args( + return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_aiomysql, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True else: str_e = str(e).lower() diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index c8f29a2f12..df8965cbbd 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -304,14 +304,12 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): return pool.AsyncAdaptedQueuePool def create_connect_args(self, url): - return super(MySQLDialect_asyncmy, self).create_connect_args( + return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_asyncmy, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True else: str_e = str(e).lower() diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e8ddb6d1e9..2525c6c32e 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1354,7 +1354,7 @@ class MySQLCompiler(compiler.SQLCompiler): name_text = self.preparer.quote(column.name) clauses.append("%s = %s" % (name_text, value_text)) - non_matching = set(on_duplicate.update) - set(c.key for c in cols) + non_matching = set(on_duplicate.update) - {c.key for c in cols} if non_matching: util.warn( "Additional column names not matching " @@ -1503,7 +1503,7 @@ class MySQLCompiler(compiler.SQLCompiler): return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) def render_literal_value(self, value, type_): - value = super(MySQLCompiler, self).render_literal_value(value, type_) + value = super().render_literal_value(value, type_) if self.dialect._backslash_escapes: value = value.replace("\\", "\\\\") return value @@ -1534,7 +1534,7 @@ class MySQLCompiler(compiler.SQLCompiler): ) return select._distinct.upper() + " " - return super(MySQLCompiler, self).get_select_precolumns(select, **kw) + return super().get_select_precolumns(select, **kw) def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): if from_linter: @@ -1805,11 +1805,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts = [] - opts = dict( - (k[len(self.dialect.name) + 1 :].upper(), v) + opts = { + k[len(self.dialect.name) + 1 :].upper(): v for k, v in table.kwargs.items() if k.startswith("%s_" % self.dialect.name) - ) + } if table.comment is not None: opts["COMMENT"] = table.comment @@ -1963,9 +1963,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): return text def visit_primary_key_constraint(self, constraint): - text = super(MySQLDDLCompiler, self).visit_primary_key_constraint( - constraint - ) + text = super().visit_primary_key_constraint(constraint) using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) @@ -2305,7 +2303,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_enum(self, type_, **kw): if not type_.native_enum: - return super(MySQLTypeCompiler, self).visit_enum(type_) + return super().visit_enum(type_) else: return self._visit_enumerated_values("ENUM", type_, type_.enums) @@ -2351,9 +2349,7 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): else: quote = '"' - super(MySQLIdentifierPreparer, self).__init__( - dialect, initial_quote=quote, escape_quote=quote - ) + super().__init__(dialect, initial_quote=quote, escape_quote=quote) def _quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index 8dc96fb154..3504588772 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -84,7 +84,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): if elem == "": return elem else: - return super(ENUM, self)._object_value_for_elem(elem) + return super()._object_value_for_elem(elem) def __repr__(self): return util.generic_repr( @@ -153,15 +153,15 @@ class SET(_StringType): "setting retrieve_as_bitwise=True" ) if self.retrieve_as_bitwise: - self._bitmap = dict( - (value, 2**idx) for idx, value in enumerate(self.values) - ) + self._bitmap = { + value: 2**idx for idx, value in enumerate(self.values) + } self._bitmap.update( (2**idx, value) for idx, value in enumerate(self.values) ) length = max([len(v) for v in values] + [0]) kw.setdefault("length", length) - super(SET, self).__init__(**kw) + super().__init__(**kw) def column_expression(self, colexpr): if self.retrieve_as_bitwise: @@ -183,7 +183,7 @@ class SET(_StringType): return None else: - super_convert = super(SET, self).result_processor(dialect, coltype) + super_convert = super().result_processor(dialect, coltype) def process(value): if isinstance(value, str): @@ -201,7 +201,7 @@ class SET(_StringType): return process def bind_processor(self, dialect): - super_convert = super(SET, self).bind_processor(dialect) + super_convert = super().bind_processor(dialect) if self.retrieve_as_bitwise: def process(value): diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py index c8c6935174..561803a78d 100644 --- a/lib/sqlalchemy/dialects/mysql/expression.py +++ b/lib/sqlalchemy/dialects/mysql/expression.py @@ -107,9 +107,7 @@ class match(Generative, elements.BinaryExpression): if kw: raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw))) - super(match, self).__init__( - left, against, operators.match_op, modifiers=flags - ) + super().__init__(left, against, operators.match_op, modifiers=flags) @_generative def in_boolean_mode(self: Selfmatch) -> Selfmatch: diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index 6327d8687d..a3f288cebc 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -102,7 +102,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): return (99, 99, 99) def __init__(self, **kwargs): - super(MySQLDialect_mariadbconnector, self).__init__(**kwargs) + super().__init__(**kwargs) self.paramstyle = "qmark" if self.dbapi is not None: if self._dbapi_version < mariadb_cpy_minimum_version: @@ -117,9 +117,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): return __import__("mariadb") def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_mariadbconnector, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True elif isinstance(e, self.dbapi.Error): str_e = str(e).lower() @@ -188,9 +186,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): connection.autocommit = True else: connection.autocommit = False - super(MySQLDialect_mariadbconnector, self).set_isolation_level( - connection, level - ) + super().set_isolation_level(connection, level) def do_begin_twophase(self, connection, xid): connection.execute( diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 58e92c4ab7..f29a5008cd 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -167,24 +167,20 @@ class MySQLDialect_mysqlconnector(MySQLDialect): def _compat_fetchone(self, rp, charset=None): return rp.fetchone() - _isolation_lookup = set( - [ - "SERIALIZABLE", - "READ UNCOMMITTED", - "READ COMMITTED", - "REPEATABLE READ", - "AUTOCOMMIT", - ] - ) + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + } def _set_isolation_level(self, connection, level): if level == "AUTOCOMMIT": connection.autocommit = True else: connection.autocommit = False - super(MySQLDialect_mysqlconnector, self)._set_isolation_level( - connection, level - ) + super()._set_isolation_level(connection, level) dialect = MySQLDialect_mysqlconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 60b9cb1035..9eb1ef84ad 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -137,7 +137,7 @@ class MySQLDialect_mysqldb(MySQLDialect): preparer = MySQLIdentifierPreparer def __init__(self, **kwargs): - super(MySQLDialect_mysqldb, self).__init__(**kwargs) + super().__init__(**kwargs) self._mysql_dbapi_version = ( self._parse_dbapi_version(self.dbapi.__version__) if self.dbapi is not None and hasattr(self.dbapi, "__version__") @@ -165,7 +165,7 @@ class MySQLDialect_mysqldb(MySQLDialect): return __import__("MySQLdb") def on_connect(self): - super_ = super(MySQLDialect_mysqldb, self).on_connect() + super_ = super().on_connect() def on_connect(conn): if super_ is not None: @@ -221,9 +221,7 @@ class MySQLDialect_mysqldb(MySQLDialect): ] else: additional_tests = [] - return super(MySQLDialect_mysqldb, self)._check_unicode_returns( - connection, additional_tests - ) + return super()._check_unicode_returns(connection, additional_tests) def create_connect_args(self, url, _translate_args=None): if _translate_args is None: @@ -324,9 +322,7 @@ class MySQLDialect_mysqldb(MySQLDialect): dbapi_connection.autocommit(True) else: dbapi_connection.autocommit(False) - super(MySQLDialect_mysqldb, self).set_isolation_level( - dbapi_connection, level - ) + super().set_isolation_level(dbapi_connection, level) dialect = MySQLDialect_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 66d2f32420..8a194d7fbb 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -65,14 +65,12 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): def create_connect_args(self, url, _translate_args=None): if _translate_args is None: _translate_args = dict(username="user") - return super(MySQLDialect_pymysql, self).create_connect_args( + return super().create_connect_args( url, _translate_args=_translate_args ) def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_pymysql, self).is_disconnect( - e, connection, cursor - ): + if super().is_disconnect(e, connection, cursor): return True elif isinstance(e, self.dbapi.Error): str_e = str(e).lower() diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 2d31dfe5fb..f9464f39f0 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -118,7 +118,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): return None def on_connect(self): - super_ = super(MySQLDialect_pyodbc, self).on_connect() + super_ = super().on_connect() def on_connect(conn): if super_ is not None: diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index 44bc62179d..fa1b7e0b76 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -340,9 +340,9 @@ class MySQLTableDefinitionParser: buffer = [] for row in columns: - (name, col_type, nullable, default, extra) = [ + (name, col_type, nullable, default, extra) = ( row[i] for i in (0, 1, 2, 4, 5) - ] + ) line = [" "] line.append(self.preparer.quote_identifier(name)) diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py index a74fba1771..5a96b890bb 100644 --- a/lib/sqlalchemy/dialects/mysql/types.py +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -25,7 +25,7 @@ class _NumericType: def __init__(self, unsigned=False, zerofill=False, **kw): self.unsigned = unsigned self.zerofill = zerofill - super(_NumericType, self).__init__(**kw) + super().__init__(**kw) def __repr__(self): return util.generic_repr( @@ -43,9 +43,7 @@ class _FloatType(_NumericType, sqltypes.Float): "You must specify both precision and scale or omit " "both altogether." ) - super(_FloatType, self).__init__( - precision=precision, asdecimal=asdecimal, **kw - ) + super().__init__(precision=precision, asdecimal=asdecimal, **kw) self.scale = scale def __repr__(self): @@ -57,7 +55,7 @@ class _FloatType(_NumericType, sqltypes.Float): class _IntegerType(_NumericType, sqltypes.Integer): def __init__(self, display_width=None, **kw): self.display_width = display_width - super(_IntegerType, self).__init__(**kw) + super().__init__(**kw) def __repr__(self): return util.generic_repr( @@ -87,7 +85,7 @@ class _StringType(sqltypes.String): self.unicode = unicode self.binary = binary self.national = national - super(_StringType, self).__init__(**kw) + super().__init__(**kw) def __repr__(self): return util.generic_repr( @@ -123,7 +121,7 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC): numeric. """ - super(NUMERIC, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -149,7 +147,7 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL): numeric. """ - super(DECIMAL, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -183,7 +181,7 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE): numeric. """ - super(DOUBLE, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -217,7 +215,7 @@ class REAL(_FloatType, sqltypes.REAL): numeric. """ - super(REAL, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -243,7 +241,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT): numeric. """ - super(FLOAT, self).__init__( + super().__init__( precision=precision, scale=scale, asdecimal=asdecimal, **kw ) @@ -269,7 +267,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER): numeric. """ - super(INTEGER, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class BIGINT(_IntegerType, sqltypes.BIGINT): @@ -290,7 +288,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT): numeric. """ - super(BIGINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class MEDIUMINT(_IntegerType): @@ -311,7 +309,7 @@ class MEDIUMINT(_IntegerType): numeric. """ - super(MEDIUMINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class TINYINT(_IntegerType): @@ -332,7 +330,7 @@ class TINYINT(_IntegerType): numeric. """ - super(TINYINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class SMALLINT(_IntegerType, sqltypes.SMALLINT): @@ -353,7 +351,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT): numeric. """ - super(SMALLINT, self).__init__(display_width=display_width, **kw) + super().__init__(display_width=display_width, **kw) class BIT(sqltypes.TypeEngine): @@ -417,7 +415,7 @@ class TIME(sqltypes.TIME): MySQL Connector/Python. """ - super(TIME, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.fsp = fsp def result_processor(self, dialect, coltype): @@ -462,7 +460,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): MySQL Connector/Python. """ - super(TIMESTAMP, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.fsp = fsp @@ -487,7 +485,7 @@ class DATETIME(sqltypes.DATETIME): MySQL Connector/Python. """ - super(DATETIME, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.fsp = fsp @@ -533,7 +531,7 @@ class TEXT(_StringType, sqltypes.TEXT): only the collation of character data. """ - super(TEXT, self).__init__(length=length, **kw) + super().__init__(length=length, **kw) class TINYTEXT(_StringType): @@ -565,7 +563,7 @@ class TINYTEXT(_StringType): only the collation of character data. """ - super(TINYTEXT, self).__init__(**kwargs) + super().__init__(**kwargs) class MEDIUMTEXT(_StringType): @@ -597,7 +595,7 @@ class MEDIUMTEXT(_StringType): only the collation of character data. """ - super(MEDIUMTEXT, self).__init__(**kwargs) + super().__init__(**kwargs) class LONGTEXT(_StringType): @@ -629,7 +627,7 @@ class LONGTEXT(_StringType): only the collation of character data. """ - super(LONGTEXT, self).__init__(**kwargs) + super().__init__(**kwargs) class VARCHAR(_StringType, sqltypes.VARCHAR): @@ -661,7 +659,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): only the collation of character data. """ - super(VARCHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class CHAR(_StringType, sqltypes.CHAR): @@ -682,7 +680,7 @@ class CHAR(_StringType, sqltypes.CHAR): compatible with the national character set. """ - super(CHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) @classmethod def _adapt_string_for_cast(self, type_): @@ -728,7 +726,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): """ kwargs["national"] = True - super(NVARCHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class NCHAR(_StringType, sqltypes.NCHAR): @@ -754,7 +752,7 @@ class NCHAR(_StringType, sqltypes.NCHAR): """ kwargs["national"] = True - super(NCHAR, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class TINYBLOB(sqltypes._Binary): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 6481ae4838..0d51bf73d5 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -813,7 +813,7 @@ class OracleCompiler(compiler.SQLCompiler): def __init__(self, *args, **kwargs): self.__wheres = {} - super(OracleCompiler, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def visit_mod_binary(self, binary, operator, **kw): return "mod(%s, %s)" % ( @@ -852,15 +852,13 @@ class OracleCompiler(compiler.SQLCompiler): return "" def visit_function(self, func, **kw): - text = super(OracleCompiler, self).visit_function(func, **kw) + text = super().visit_function(func, **kw) if kw.get("asfrom", False): text = "TABLE (%s)" % func return text def visit_table_valued_column(self, element, **kw): - text = super(OracleCompiler, self).visit_table_valued_column( - element, **kw - ) + text = super().visit_table_valued_column(element, **kw) text = "COLUMN_VALUE " + text return text @@ -1331,9 +1329,7 @@ class OracleDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) def get_identity_options(self, identity_options): - text = super(OracleDDLCompiler, self).get_identity_options( - identity_options - ) + text = super().get_identity_options(identity_options) text = text.replace("NO MINVALUE", "NOMINVALUE") text = text.replace("NO MAXVALUE", "NOMAXVALUE") text = text.replace("NO CYCLE", "NOCYCLE") @@ -1386,9 +1382,7 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer): def format_savepoint(self, savepoint): name = savepoint.ident.lstrip("_") - return super(OracleIdentifierPreparer, self).format_savepoint( - savepoint, name - ) + return super().format_savepoint(savepoint, name) class OracleExecutionContext(default.DefaultExecutionContext): @@ -1489,7 +1483,7 @@ class OracleDialect(default.DefaultDialect): ) = enable_offset_fetch def initialize(self, connection): - super(OracleDialect, self).initialize(connection) + super().initialize(connection) # Oracle 8i has RETURNING: # https://docs.oracle.com/cd/A87860_01/doc/index.htm diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 5a0c0e160a..0be309cd4b 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -672,9 +672,7 @@ class _OracleBinary(_LOBDataType, sqltypes.LargeBinary): if not dialect.auto_convert_lobs: return None else: - return super(_OracleBinary, self).result_processor( - dialect, coltype - ) + return super().result_processor(dialect, coltype) class _OracleInterval(oracle.INTERVAL): diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index 60a8ebcb50..5cea62b9f8 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -35,12 +35,10 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer): if asdecimal is None: asdecimal = bool(scale and scale > 0) - super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal - ) + super().__init__(precision=precision, scale=scale, asdecimal=asdecimal) def adapt(self, impltype): - ret = super(NUMBER, self).adapt(impltype) + ret = super().adapt(impltype) # leave a hint for the DBAPI handler ret._is_oracle_number = True return ret diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index 92341d2dac..4bb1026a56 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -60,15 +60,13 @@ class _PsycopgHStore(HSTORE): if dialect._has_native_hstore: return None else: - return super(_PsycopgHStore, self).bind_processor(dialect) + return super().bind_processor(dialect) def result_processor(self, dialect, coltype): if dialect._has_native_hstore: return None else: - return super(_PsycopgHStore, self).result_processor( - dialect, coltype - ) + return super().result_processor(dialect, coltype) class _PsycopgARRAY(PGARRAY): diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 3132e875e6..e130eccc2f 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -101,7 +101,7 @@ class array(expression.ExpressionClauseList[_T]): def __init__(self, clauses, **kw): type_arg = kw.pop("type_", None) - super(array, self).__init__(operators.comma_op, *clauses, **kw) + super().__init__(operators.comma_op, *clauses, **kw) self._type_tuple = [arg.type for arg in self.clauses] diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index cd161d28e0..751dc3dcf3 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -560,7 +560,7 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): __slots__ = ("_rowbuffer",) def __init__(self, adapt_connection): - super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection) + super().__init__(adapt_connection) self._rowbuffer = None def close(self): @@ -863,9 +863,7 @@ class AsyncAdapt_asyncpg_dbapi: class InvalidCachedStatementError(NotSupportedError): def __init__(self, message): - super( - AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self - ).__init__( + super().__init__( message + " (SQLAlchemy asyncpg dialect will now invalidate " "all prepared caches in response to this exception)", ) @@ -1095,7 +1093,7 @@ class PGDialect_asyncpg(PGDialect): """ - super_connect = super(PGDialect_asyncpg, self).on_connect() + super_connect = super().on_connect() def connect(conn): conn.await_(self.setup_asyncpg_json_codec(conn)) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index a908ed6b78..49ee89daac 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1469,112 +1469,110 @@ from ...util.typing import TypedDict IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) -RESERVED_WORDS = set( - [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "leading", - "limit", - "localtime", - "localtimestamp", - "new", - "not", - "null", - "of", - "off", - "offset", - "old", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", - "authorization", - "between", - "binary", - "cross", - "current_schema", - "freeze", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "notnull", - "outer", - "over", - "overlaps", - "right", - "similar", - "verbose", - ] -) +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "leading", + "limit", + "localtime", + "localtimestamp", + "new", + "not", + "null", + "of", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", + "authorization", + "between", + "binary", + "cross", + "current_schema", + "freeze", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "notnull", + "outer", + "over", + "overlaps", + "right", + "similar", + "verbose", +} colspecs = { sqltypes.ARRAY: _array.ARRAY, @@ -1801,7 +1799,7 @@ class PGCompiler(compiler.SQLCompiler): ) def render_literal_value(self, value, type_): - value = super(PGCompiler, self).render_literal_value(value, type_) + value = super().render_literal_value(value, type_) if self.dialect._backslash_escapes: value = value.replace("\\", "\\\\") @@ -2108,14 +2106,12 @@ class PGDDLCompiler(compiler.DDLCompiler): "create_constraint=False on this Enum datatype." ) - text = super(PGDDLCompiler, self).visit_check_constraint(constraint) + text = super().visit_check_constraint(constraint) text += self._define_constraint_validity(constraint) return text def visit_foreign_key_constraint(self, constraint): - text = super(PGDDLCompiler, self).visit_foreign_key_constraint( - constraint - ) + text = super().visit_foreign_key_constraint(constraint) text += self._define_constraint_validity(constraint) return text @@ -2353,9 +2349,7 @@ class PGDDLCompiler(compiler.DDLCompiler): create.element.data_type ) - return super(PGDDLCompiler, self).visit_create_sequence( - create, prefix=prefix, **kw - ) + return super().visit_create_sequence(create, prefix=prefix, **kw) def _can_comment_on_constraint(self, ddl_instance): constraint = ddl_instance.element @@ -2478,7 +2472,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_enum(self, type_, **kw): if not type_.native_enum or not self.dialect.supports_native_enum: - return super(PGTypeCompiler, self).visit_enum(type_, **kw) + return super().visit_enum(type_, **kw) else: return self.visit_ENUM(type_, **kw) @@ -2803,7 +2797,7 @@ class PGExecutionContext(default.DefaultExecutionContext): return self._execute_scalar(exc, column.type) - return super(PGExecutionContext, self).get_insert_default(column) + return super().get_insert_default(column) class PGReadOnlyConnectionCharacteristic( @@ -2945,7 +2939,7 @@ class PGDialect(default.DefaultDialect): self._json_serializer = json_serializer def initialize(self, connection): - super(PGDialect, self).initialize(connection) + super().initialize(connection) # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index b79b4a30ec..645bedf177 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -266,7 +266,7 @@ class OnConflictDoUpdate(OnConflictClause): set_=None, where=None, ): - super(OnConflictDoUpdate, self).__init__( + super().__init__( constraint=constraint, index_elements=index_elements, index_where=index_where, diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index ebaad27342..b0d8ef3457 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -243,7 +243,7 @@ class ExcludeConstraint(ColumnCollectionConstraint): self.ops = kw.get("ops", {}) def _set_parent(self, table, **kw): - super(ExcludeConstraint, self)._set_parent(table) + super()._set_parent(table) self._render_exprs = [ ( diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index a8b03bd482..c68671918c 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -221,7 +221,7 @@ class JSON(sqltypes.JSON): .. versionadded:: 1.1 """ - super(JSON, self).__init__(none_as_null=none_as_null) + super().__init__(none_as_null=none_as_null) if astext_type is not None: self.astext_type = astext_type diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index f844e92130..79b567f089 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -31,8 +31,8 @@ class NamedType(sqltypes.TypeEngine): """Base for named types.""" __abstract__ = True - DDLGenerator: Type["NamedTypeGenerator"] - DDLDropper: Type["NamedTypeDropper"] + DDLGenerator: Type[NamedTypeGenerator] + DDLDropper: Type[NamedTypeDropper] create_type: bool def create(self, bind, checkfirst=True, **kw): diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index cb5cab178e..5acd50710f 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -651,7 +651,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): ) def initialize(self, connection): - super(PGDialect_psycopg2, self).initialize(connection) + super().initialize(connection) self._has_native_hstore = ( self.use_native_hstore and self._hstore_oids(connection.connection.dbapi_connection) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 81b6771872..72703ff814 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -128,7 +128,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): .. versionadded:: 1.4 """ - super(TIMESTAMP, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.precision = precision @@ -147,7 +147,7 @@ class TIME(sqltypes.TIME): .. versionadded:: 1.4 """ - super(TIME, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) self.precision = precision diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c3cb10cefa..4e5808f623 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -921,9 +921,7 @@ from ...types import VARCHAR # noqa class _SQliteJson(JSON): def result_processor(self, dialect, coltype): - default_processor = super(_SQliteJson, self).result_processor( - dialect, coltype - ) + default_processor = super().result_processor(dialect, coltype) def process(value): try: @@ -942,7 +940,7 @@ class _DateTimeMixin: _storage_format = None def __init__(self, storage_format=None, regexp=None, **kw): - super(_DateTimeMixin, self).__init__(**kw) + super().__init__(**kw) if regexp is not None: self._reg = re.compile(regexp) if storage_format is not None: @@ -978,7 +976,7 @@ class _DateTimeMixin: kw["storage_format"] = self._storage_format if self._reg: kw["regexp"] = self._reg - return super(_DateTimeMixin, self).adapt(cls, **kw) + return super().adapt(cls, **kw) def literal_processor(self, dialect): bp = self.bind_processor(dialect) @@ -1037,7 +1035,7 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): def __init__(self, *args, **kwargs): truncate_microseconds = kwargs.pop("truncate_microseconds", False) - super(DATETIME, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if truncate_microseconds: assert "storage_format" not in kwargs, ( "You can specify only " @@ -1215,7 +1213,7 @@ class TIME(_DateTimeMixin, sqltypes.Time): def __init__(self, *args, **kwargs): truncate_microseconds = kwargs.pop("truncate_microseconds", False) - super(TIME, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if truncate_microseconds: assert "storage_format" not in kwargs, ( "You can specify only " @@ -1337,7 +1335,7 @@ class SQLiteCompiler(compiler.SQLCompiler): def visit_cast(self, cast, **kwargs): if self.dialect.supports_cast: - return super(SQLiteCompiler, self).visit_cast(cast, **kwargs) + return super().visit_cast(cast, **kwargs) else: return self.process(cast.clause, **kwargs) @@ -1610,9 +1608,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): ): return None - text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint( - constraint - ) + text = super().visit_primary_key_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ "on_conflict" @@ -1628,9 +1624,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_unique_constraint(self, constraint): - text = super(SQLiteDDLCompiler, self).visit_unique_constraint( - constraint - ) + text = super().visit_unique_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ "on_conflict" @@ -1648,9 +1642,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_check_constraint(self, constraint): - text = super(SQLiteDDLCompiler, self).visit_check_constraint( - constraint - ) + text = super().visit_check_constraint(constraint) on_conflict_clause = constraint.dialect_options["sqlite"][ "on_conflict" @@ -1662,9 +1654,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_column_check_constraint(self, constraint): - text = super(SQLiteDDLCompiler, self).visit_column_check_constraint( - constraint - ) + text = super().visit_column_check_constraint(constraint) if constraint.dialect_options["sqlite"]["on_conflict"] is not None: raise exc.CompileError( @@ -1682,9 +1672,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if local_table.schema != remote_table.schema: return None else: - return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint( - constraint - ) + return super().visit_foreign_key_constraint(constraint) def define_constraint_remote_table(self, constraint, table, preparer): """Format the remote table clause of a CREATE CONSTRAINT clause.""" @@ -1741,7 +1729,7 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): not isinstance(type_, _DateTimeMixin) or type_.format_is_text_affinity ): - return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) + return super().visit_DATETIME(type_) else: return "DATETIME_CHAR" @@ -1750,7 +1738,7 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): not isinstance(type_, _DateTimeMixin) or type_.format_is_text_affinity ): - return super(SQLiteTypeCompiler, self).visit_DATE(type_) + return super().visit_DATE(type_) else: return "DATE_CHAR" @@ -1759,7 +1747,7 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): not isinstance(type_, _DateTimeMixin) or type_.format_is_text_affinity ): - return super(SQLiteTypeCompiler, self).visit_TIME(type_) + return super().visit_TIME(type_) else: return "TIME_CHAR" @@ -1771,127 +1759,125 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set( - [ - "add", - "after", - "all", - "alter", - "analyze", - "and", - "as", - "asc", - "attach", - "autoincrement", - "before", - "begin", - "between", - "by", - "cascade", - "case", - "cast", - "check", - "collate", - "column", - "commit", - "conflict", - "constraint", - "create", - "cross", - "current_date", - "current_time", - "current_timestamp", - "database", - "default", - "deferrable", - "deferred", - "delete", - "desc", - "detach", - "distinct", - "drop", - "each", - "else", - "end", - "escape", - "except", - "exclusive", - "exists", - "explain", - "false", - "fail", - "for", - "foreign", - "from", - "full", - "glob", - "group", - "having", - "if", - "ignore", - "immediate", - "in", - "index", - "indexed", - "initially", - "inner", - "insert", - "instead", - "intersect", - "into", - "is", - "isnull", - "join", - "key", - "left", - "like", - "limit", - "match", - "natural", - "not", - "notnull", - "null", - "of", - "offset", - "on", - "or", - "order", - "outer", - "plan", - "pragma", - "primary", - "query", - "raise", - "references", - "reindex", - "rename", - "replace", - "restrict", - "right", - "rollback", - "row", - "select", - "set", - "table", - "temp", - "temporary", - "then", - "to", - "transaction", - "trigger", - "true", - "union", - "unique", - "update", - "using", - "vacuum", - "values", - "view", - "virtual", - "when", - "where", - ] - ) + reserved_words = { + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + "autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "exists", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + } class SQLiteExecutionContext(default.DefaultExecutionContext): @@ -2454,17 +2440,14 @@ class SQLiteDialect(default.DefaultDialect): # the names as well. SQLite saves the DDL in whatever format # it was typed in as, so need to be liberal here. - keys_by_signature = dict( - ( - fk_sig( - fk["constrained_columns"], - fk["referred_table"], - fk["referred_columns"], - ), - fk, - ) + keys_by_signature = { + fk_sig( + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ): fk for fk in fks.values() - ) + } table_data = self._get_table_sql(connection, table_name, schema=schema) diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index 9e9e68330f..0777c92618 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -202,7 +202,7 @@ class OnConflictDoUpdate(OnConflictClause): set_=None, where=None, ): - super(OnConflictDoUpdate, self).__init__( + super().__init__( index_elements=index_elements, index_where=index_where, ) diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 53e4b0d1bf..5c07f487c8 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -124,9 +124,7 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): return pool.SingletonThreadPool def on_connect_url(self, url): - super_on_connect = super( - SQLiteDialect_pysqlcipher, self - ).on_connect_url(url) + super_on_connect = super().on_connect_url(url) # pull the info we need from the URL early. Even though URL # is immutable, we don't want any in-place changes to the URL @@ -151,9 +149,7 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): def create_connect_args(self, url): plain_url = url._replace(password=None) plain_url = plain_url.difference_update_query(self.pragmas) - return super(SQLiteDialect_pysqlcipher, self).create_connect_args( - plain_url - ) + return super().create_connect_args(plain_url) dialect = SQLiteDialect_pysqlcipher diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 19949441fb..4475ccae7a 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -536,9 +536,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): dbapi_connection.isolation_level = None else: dbapi_connection.isolation_level = "" - return super(SQLiteDialect_pysqlite, self).set_isolation_level( - dbapi_connection, level - ) + return super().set_isolation_level(dbapi_connection, level) def on_connect(self): def regexp(a, b): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index b2f6b29b78..b686de0d6a 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2854,7 +2854,7 @@ class TwoPhaseTransaction(RootTransaction): def __init__(self, connection: Connection, xid: Any): self._is_prepared = False self.xid = xid - super(TwoPhaseTransaction, self).__init__(connection) + super().__init__(connection) def prepare(self) -> None: """Prepare this :class:`.TwoPhaseTransaction`. diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 1ad8c90e76..c8736392f0 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -115,7 +115,7 @@ def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: "is deprecated and will be removed in a future release. ", ), ) -def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> Engine: +def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: """Create a new :class:`_engine.Engine` instance. The standard calling form is to send the :ref:`URL ` as the @@ -806,11 +806,11 @@ def engine_from_config( """ - options = dict( - (key[len(prefix) :], configuration[key]) + options = { + key[len(prefix) :]: configuration[key] for key in configuration if key.startswith(prefix) - ) + } options["_coerce_config"] = True options.update(kwargs) url = options.pop("url") diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index f22e89fbeb..33ee7866cd 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1230,15 +1230,11 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): def soft_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(BufferedRowCursorFetchStrategy, self).soft_close( - result, dbapi_cursor - ) + super().soft_close(result, dbapi_cursor) def hard_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(BufferedRowCursorFetchStrategy, self).hard_close( - result, dbapi_cursor - ) + super().hard_close(result, dbapi_cursor) def fetchone(self, result, dbapi_cursor, hard_close=False): if not self._rowbuffer: @@ -1307,15 +1303,11 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): def soft_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(FullyBufferedCursorFetchStrategy, self).soft_close( - result, dbapi_cursor - ) + super().soft_close(result, dbapi_cursor) def hard_close(self, result, dbapi_cursor): self._rowbuffer.clear() - super(FullyBufferedCursorFetchStrategy, self).hard_close( - result, dbapi_cursor - ) + super().hard_close(result, dbapi_cursor) def fetchone(self, result, dbapi_cursor, hard_close=False): if self._rowbuffer: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index e5d613dd58..3cc9cab8b3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -200,9 +200,7 @@ class DefaultDialect(Dialect): supports_sane_rowcount = True supports_sane_multi_rowcount = True - colspecs: MutableMapping[ - Type["TypeEngine[Any]"], Type["TypeEngine[Any]"] - ] = {} + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] = {} default_paramstyle = "named" supports_default_values = False @@ -1486,21 +1484,17 @@ class DefaultExecutionContext(ExecutionContext): use_server_side = self.execution_options.get( "stream_results", True ) and ( - ( - self.compiled - and isinstance( - self.compiled.statement, expression.Selectable - ) - or ( - ( - not self.compiled - or isinstance( - self.compiled.statement, expression.TextClause - ) + self.compiled + and isinstance(self.compiled.statement, expression.Selectable) + or ( + ( + not self.compiled + or isinstance( + self.compiled.statement, expression.TextClause ) - and self.unicode_statement - and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement) ) + and self.unicode_statement + and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement) ) ) else: @@ -1938,15 +1932,12 @@ class DefaultExecutionContext(ExecutionContext): ] ) else: - parameters = dict( - ( - key, - processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key], - ) + parameters = { + key: processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] for key in compiled_params - ) + } return self._execute_scalar( str(compiled), type_, parameters=parameters ) diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index e10fab831d..2f5efce259 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -757,7 +757,7 @@ class Dialect(EventTarget): # create_engine() -> isolation_level currently goes here _on_connect_isolation_level: Optional[IsolationLevel] - execution_ctx_cls: Type["ExecutionContext"] + execution_ctx_cls: Type[ExecutionContext] """a :class:`.ExecutionContext` class used to handle statement execution""" execute_sequence_format: Union[ @@ -963,7 +963,7 @@ class Dialect(EventTarget): """target database, when given a CTE with an INSERT statement, needs the CTE to be below the INSERT""" - colspecs: MutableMapping[Type["TypeEngine[Any]"], Type["TypeEngine[Any]"]] + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] """A dictionary of TypeEngine classes from sqlalchemy.types mapped to subclasses that are specific to the dialect class. This dictionary is class-level only and is not accessed from the @@ -1160,12 +1160,12 @@ class Dialect(EventTarget): _bind_typing_render_casts: bool - _type_memos: MutableMapping[TypeEngine[Any], "_TypeMemoDict"] + _type_memos: MutableMapping[TypeEngine[Any], _TypeMemoDict] def _builtin_onconnect(self) -> Optional[_ListenerFnType]: raise NotImplementedError() - def create_connect_args(self, url: "URL") -> ConnectArgsType: + def create_connect_args(self, url: URL) -> ConnectArgsType: """Build DB-API compatible connection arguments. Given a :class:`.URL` object, returns a tuple @@ -1217,7 +1217,7 @@ class Dialect(EventTarget): raise NotImplementedError() @classmethod - def type_descriptor(cls, typeobj: "TypeEngine[_T]") -> "TypeEngine[_T]": + def type_descriptor(cls, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: """Transform a generic type to a dialect-specific type. Dialect classes will usually use the @@ -2155,7 +2155,7 @@ class Dialect(EventTarget): self, cursor: DBAPICursor, statement: str, - context: Optional["ExecutionContext"] = None, + context: Optional[ExecutionContext] = None, ) -> None: """Provide an implementation of ``cursor.execute(statement)``. @@ -2210,7 +2210,7 @@ class Dialect(EventTarget): """ raise NotImplementedError() - def on_connect_url(self, url: "URL") -> Optional[Callable[[Any], Any]]: + def on_connect_url(self, url: URL) -> Optional[Callable[[Any], Any]]: """return a callable which sets up a newly created DBAPI connection. This method is a new hook that supersedes the @@ -2556,7 +2556,7 @@ class Dialect(EventTarget): """ @classmethod - def engine_created(cls, engine: "Engine") -> None: + def engine_created(cls, engine: Engine) -> None: """A convenience hook called before returning the final :class:`_engine.Engine`. diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index f744d53ad6..d1669cc3cf 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -568,9 +568,9 @@ class Inspector(inspection.Inspectable["Inspector"]): schema_fkeys = self.get_multi_foreign_keys(schname, **kw) tnames.extend(schema_fkeys) for (_, tname), fkeys in schema_fkeys.items(): - fknames_for_table[(schname, tname)] = set( - [fk["name"] for fk in fkeys] - ) + fknames_for_table[(schname, tname)] = { + fk["name"] for fk in fkeys + } for fkey in fkeys: if ( tname != fkey["referred_table"] @@ -1517,11 +1517,11 @@ class Inspector(inspection.Inspectable["Inspector"]): # intended for reflection, e.g. oracle_resolve_synonyms. # these are unconditionally passed to related Table # objects - reflection_options = dict( - (k, table.dialect_kwargs.get(k)) + reflection_options = { + k: table.dialect_kwargs.get(k) for k in dialect.reflection_options if k in table.dialect_kwargs - ) + } table_key = (schema, table_name) if _reflect_info is None or table_key not in _reflect_info.columns: @@ -1644,8 +1644,8 @@ class Inspector(inspection.Inspectable["Inspector"]): coltype = col_d["type"] - col_kw = dict( - (k, col_d[k]) # type: ignore[literal-required] + col_kw = { + k: col_d[k] # type: ignore[literal-required] for k in [ "nullable", "autoincrement", @@ -1655,7 +1655,7 @@ class Inspector(inspection.Inspectable["Inspector"]): "comment", ] if k in col_d - ) + } if "dialect_options" in col_d: col_kw.update(col_d["dialect_options"]) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index bcd2f0ea9d..392cefa020 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -2342,7 +2342,7 @@ class ChunkedIteratorResult(IteratorResult[_TP]): return self def _soft_close(self, hard: bool = False, **kw: Any) -> None: - super(ChunkedIteratorResult, self)._soft_close(hard=hard, **kw) + super()._soft_close(hard=hard, **kw) self.chunks = lambda size: [] # type: ignore def _fetchmany_impl( @@ -2370,7 +2370,7 @@ class MergedResult(IteratorResult[_TP]): self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]] ): self._results = results - super(MergedResult, self).__init__( + super().__init__( cursor_metadata, itertools.chain.from_iterable( r._raw_row_iterator() for r in results diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 3d43414104..a06b05940f 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -58,7 +58,7 @@ def _legacy_signature( def _wrap_fn_for_legacy( - dispatch_collection: "_ClsLevelDispatch[_ET]", + dispatch_collection: _ClsLevelDispatch[_ET], fn: _ListenerFnType, argspec: FullArgSpec, ) -> _ListenerFnType: @@ -120,7 +120,7 @@ def _indent(text: str, indent: str) -> str: def _standard_listen_example( - dispatch_collection: "_ClsLevelDispatch[_ET]", + dispatch_collection: _ClsLevelDispatch[_ET], sample_target: Any, fn: _ListenerFnType, ) -> str: @@ -161,7 +161,7 @@ def _standard_listen_example( def _legacy_listen_examples( - dispatch_collection: "_ClsLevelDispatch[_ET]", + dispatch_collection: _ClsLevelDispatch[_ET], sample_target: str, fn: _ListenerFnType, ) -> str: @@ -189,8 +189,8 @@ def _legacy_listen_examples( def _version_signature_changes( - parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], - dispatch_collection: "_ClsLevelDispatch[_ET]", + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + dispatch_collection: _ClsLevelDispatch[_ET], ) -> str: since, args, conv = dispatch_collection.legacy_signatures[0] return ( @@ -219,8 +219,8 @@ def _version_signature_changes( def _augment_fn_docs( - dispatch_collection: "_ClsLevelDispatch[_ET]", - parent_dispatch_cls: Type["_HasEventsDispatch[_ET]"], + dispatch_collection: _ClsLevelDispatch[_ET], + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], fn: _ListenerFnType, ) -> str: header = ( diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index 88edba3284..fa46a46c4b 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -49,7 +49,7 @@ class HasDescriptionCode: code = kw.pop("code", None) if code is not None: self.code = code - super(HasDescriptionCode, self).__init__(*arg, **kw) + super().__init__(*arg, **kw) def _code_str(self) -> str: if not self.code: @@ -65,7 +65,7 @@ class HasDescriptionCode: ) def __str__(self) -> str: - message = super(HasDescriptionCode, self).__str__() + message = super().__str__() if self.code: message = "%s %s" % (message, self._code_str()) return message @@ -134,9 +134,7 @@ class ObjectNotExecutableError(ArgumentError): """ def __init__(self, target: Any): - super(ObjectNotExecutableError, self).__init__( - "Not an executable object: %r" % target - ) + super().__init__("Not an executable object: %r" % target) self.target = target def __reduce__(self) -> Union[str, Tuple[Any, ...]]: @@ -223,7 +221,7 @@ class UnsupportedCompilationError(CompileError): element_type: Type[ClauseElement], message: Optional[str] = None, ): - super(UnsupportedCompilationError, self).__init__( + super().__init__( "Compiler %r can't render element of type %s%s" % (compiler, element_type, ": %s" % message if message else "") ) @@ -557,7 +555,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> StatementError: ... @@ -572,7 +570,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> DontWrapMixin: ... @@ -587,7 +585,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> BaseException: ... @@ -601,7 +599,7 @@ class DBAPIError(StatementError): dbapi_base_err: Type[Exception], hide_parameters: bool = False, connection_invalidated: bool = False, - dialect: Optional["Dialect"] = None, + dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, ) -> Union[BaseException, DontWrapMixin]: # Don't ever wrap these, just return them directly as if @@ -792,7 +790,7 @@ class Base20DeprecationWarning(SADeprecationWarning): def __str__(self) -> str: return ( - super(Base20DeprecationWarning, self).__str__() + super().__str__() + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)" ) diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index bfec091376..f4adf3d297 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -1071,7 +1071,7 @@ class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]): if obj is None: return self else: - return super(AmbiguousAssociationProxyInstance, self).get(obj) + return super().get(obj) def __eq__(self, obj: object) -> NoReturn: self._ambiguous() diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 6d441c9e34..6eb30ba4c6 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -884,12 +884,12 @@ class AutomapBase: cls.metadata.reflect(autoload_with, **opts) with _CONFIGURE_MUTEX: - table_to_map_config = dict( - (m.local_table, m) + table_to_map_config = { + m.local_table: m for m in _DeferredMapperConfig.classes_for_base( cls, sort=False ) - ) + } many_to_many = [] diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index 7093de7325..48e57e2bc8 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -525,15 +525,13 @@ class Result: # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set( - [ - _get_params[col].key - for col, value in zip( - mapper.primary_key, primary_key_identity - ) - if value is None - ] - ) + nones = { + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + } _lcl_get_clause = sql_util.adapt_criterion_to_null( _lcl_get_clause, nones ) @@ -562,14 +560,12 @@ class Result: setup, tuple(elem is None for elem in primary_key_identity) ) - params = dict( - [ - (_get_params[primary_key].key, id_val) - for id_val, primary_key in zip( - primary_key_identity, mapper.primary_key - ) - ] - ) + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } result = list(bq.for_session(self.session).params(**params)) l = len(result) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 7afe2343d2..8f6e2ffcd9 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -28,7 +28,7 @@ __all__ = ["ShardedSession", "ShardedQuery"] class ShardedQuery(Query): def __init__(self, *args, **kwargs): - super(ShardedQuery, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.id_chooser = self.session.id_chooser self.query_chooser = self.session.query_chooser self.execute_chooser = self.session.execute_chooser @@ -88,7 +88,7 @@ class ShardedSession(Session): """ query_chooser = kwargs.pop("query_chooser", None) - super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + super().__init__(query_cls=query_cls, **kwargs) event.listen( self, "do_orm_execute", execute_and_instances, retval=True @@ -138,7 +138,7 @@ class ShardedSession(Session): """ if identity_token is not None: - return super(ShardedSession, self)._identity_lookup( + return super()._identity_lookup( mapper, primary_key_identity, identity_token=identity_token, @@ -149,7 +149,7 @@ class ShardedSession(Session): if lazy_loaded_from: q = q._set_lazyload_from(lazy_loaded_from) for shard_id in self.id_chooser(q, primary_key_identity): - obj = super(ShardedSession, self)._identity_lookup( + obj = super()._identity_lookup( mapper, primary_key_identity, identity_token=shard_id, diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index 5c5d267364..ce63123653 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -278,13 +278,9 @@ class index_property(hybrid_property): # noqa """ if mutable: - super(index_property, self).__init__( - self.fget, self.fset, self.fdel, self.expr - ) + super().__init__(self.fget, self.fset, self.fdel, self.expr) else: - super(index_property, self).__init__( - self.fget, None, None, self.expr - ) + super().__init__(self.fget, None, None, self.expr) self.attr_name = attr_name self.index = index self.default = default diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index 427e151dac..f36087ad93 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -165,7 +165,7 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): return factories def unregister(self, class_): - super(ExtendedInstrumentationRegistry, self).unregister(class_) + super().unregister(class_) if class_ in self._manager_finders: del self._manager_finders[class_] del self._state_finders[class_] @@ -321,7 +321,7 @@ class _ClassInstrumentationAdapter(ClassManager): self._adapted.instrument_attribute(self.class_, key, inst) def post_configure_attribute(self, key): - super(_ClassInstrumentationAdapter, self).post_configure_attribute(key) + super().post_configure_attribute(key) self._adapted.post_configure_attribute(self.class_, key, self[key]) def install_descriptor(self, key, inst): diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index bfc3459d03..f392a85a75 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -63,7 +63,7 @@ def apply_mypy_mapped_attr( ): break else: - util.fail(api, "Can't find mapped attribute {}".format(name), cls) + util.fail(api, f"Can't find mapped attribute {name}", cls) return None if stmt.type is None: diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 44a1768a89..a32bc9b529 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -90,7 +90,7 @@ class SQLAlchemyAttribute: info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface, - ) -> "SQLAlchemyAttribute": + ) -> SQLAlchemyAttribute: data = data.copy() typ = deserialize_and_fixup_type(data.pop("type"), api) return cls(typ=typ, info=info, **data) @@ -238,8 +238,7 @@ def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: and isinstance(stmt.expr[0], NameExpr) and stmt.expr[0].fullname == "typing.TYPE_CHECKING" ): - for substmt in stmt.body[0].body: - yield substmt + yield from stmt.body[0].body else: yield stmt diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index f08ffc68dc..b0615d95d6 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -143,7 +143,7 @@ def ordering_list( count_from: Optional[int] = None, ordering_func: Optional[OrderingFunc] = None, reorder_on_append: bool = False, -) -> Callable[[], "OrderingList"]: +) -> Callable[[], OrderingList]: """Prepares an :class:`OrderingList` factory for use in mapper definitions. Returns an object suitable for use as an argument to a Mapper @@ -335,29 +335,29 @@ class OrderingList(List[_T]): self._set_order_value(entity, should_be) def append(self, entity): - super(OrderingList, self).append(entity) + super().append(entity) self._order_entity(len(self) - 1, entity, self.reorder_on_append) def _raw_append(self, entity): """Append without any ordering behavior.""" - super(OrderingList, self).append(entity) + super().append(entity) _raw_append = collection.adds(1)(_raw_append) def insert(self, index, entity): - super(OrderingList, self).insert(index, entity) + super().insert(index, entity) self._reorder() def remove(self, entity): - super(OrderingList, self).remove(entity) + super().remove(entity) adapter = collection_adapter(self) if adapter and adapter._referenced_by_owner: self._reorder() def pop(self, index=-1): - entity = super(OrderingList, self).pop(index) + entity = super().pop(index) self._reorder() return entity @@ -375,18 +375,18 @@ class OrderingList(List[_T]): self.__setitem__(i, entity[i]) else: self._order_entity(index, entity, True) - super(OrderingList, self).__setitem__(index, entity) + super().__setitem__(index, entity) def __delitem__(self, index): - super(OrderingList, self).__delitem__(index) + super().__delitem__(index) self._reorder() def __setslice__(self, start, end, values): - super(OrderingList, self).__setslice__(start, end, values) + super().__setslice__(start, end, values) self._reorder() def __delslice__(self, start, end): - super(OrderingList, self).__delslice__(start, end) + super().__delslice__(start, end) self._reorder() def __reduce__(self): diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index f7050b93fb..dd295c3ed4 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -63,10 +63,10 @@ def _add_default_handler(logger: logging.Logger) -> None: logger.addHandler(handler) -_logged_classes: Set[Type["Identified"]] = set() +_logged_classes: Set[Type[Identified]] = set() -def _qual_logger_name_for_cls(cls: Type["Identified"]) -> str: +def _qual_logger_name_for_cls(cls: Type[Identified]) -> str: return ( getattr(cls, "_sqla_logger_namespace", None) or cls.__module__ + "." + cls.__name__ diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 854bad986a..2c77111c1d 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1200,7 +1200,7 @@ class ScalarAttributeImpl(AttributeImpl): __slots__ = "_replace_token", "_append_token", "_remove_token" def __init__(self, *arg, **kw): - super(ScalarAttributeImpl, self).__init__(*arg, **kw) + super().__init__(*arg, **kw) self._replace_token = self._append_token = AttributeEventToken( self, OP_REPLACE ) @@ -1628,7 +1628,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): compare_function=None, **kwargs, ): - super(CollectionAttributeImpl, self).__init__( + super().__init__( class_, key, callable_, diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index cfe4880039..181dbd4a28 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -242,11 +242,11 @@ def _bulk_update( search_keys = {mapper._version_id_prop.key}.union(search_keys) def _changed_dict(mapper, state): - return dict( - (k, v) + return { + k: v for k, v in state.dict.items() if k in state.committed_state or k in search_keys - ) + } if isstates: if update_changed_only: @@ -1701,7 +1701,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): value_evaluators[key] = _evaluator evaluated_keys = list(value_evaluators.keys()) - attrib = set(k for k, v in resolved_keys_as_propnames) + attrib = {k for k, v in resolved_keys_as_propnames} states = set() for obj, state, dict_ in matched_objects: diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 99a51c998f..b957dc5d49 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -187,9 +187,9 @@ class _MultipleClassMarker(ClsRegistryToken): on_remove: Optional[Callable[[], None]] = None, ): self.on_remove = on_remove - self.contents = set( - [weakref.ref(item, self._remove_item) for item in classes] - ) + self.contents = { + weakref.ref(item, self._remove_item) for item in classes + } _registries.add(self) def remove_item(self, cls: Type[Any]) -> None: @@ -224,13 +224,11 @@ class _MultipleClassMarker(ClsRegistryToken): # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, # [ticket:3208] - modules = set( - [ - cls.__module__ - for cls in [ref() for ref in self.contents] - if cls is not None - ] - ) + modules = { + cls.__module__ + for cls in [ref() for ref in self.contents] + if cls is not None + } if item.__module__ in modules: util.warn( "This declarative base already contains a class with the " diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 60ccecdb7c..621b3e5d74 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -838,12 +838,10 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): return self def get_children(self, **kw): - for elem in itertools.chain.from_iterable( + yield from itertools.chain.from_iterable( element._from_objects for element in self._raw_columns - ): - yield elem - for elem in super(FromStatement, self).get_children(**kw): - yield elem + ) + yield from super().get_children(**kw) @property def _all_selected_columns(self): @@ -1245,14 +1243,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ): ens = element._annotations["entity_namespace"] if not ens.is_mapper and not ens.is_aliased_class: - for elem in _select_iterables([element]): - yield elem + yield from _select_iterables([element]) else: - for elem in _select_iterables(ens._all_column_expressions): - yield elem + yield from _select_iterables(ens._all_column_expressions) else: - for elem in _select_iterables([element]): - yield elem + yield from _select_iterables([element]) @classmethod def get_columns_clause_froms(cls, statement): diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index c233298b9f..268a1d57a7 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -342,9 +342,7 @@ class _ImperativeMapperConfig(_MapperConfig): table: Optional[FromClause], mapper_kw: _MapperKwArgs, ): - super(_ImperativeMapperConfig, self).__init__( - registry, cls_, mapper_kw - ) + super().__init__(registry, cls_, mapper_kw) self.local_table = self.set_cls_attribute("__table__", table) @@ -480,7 +478,7 @@ class _ClassScanMapperConfig(_MapperConfig): self.clsdict_view = ( util.immutabledict(dict_) if dict_ else util.EMPTY_DICT ) - super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) + super().__init__(registry, cls_, mapper_kw) self.registry = registry self.persist_selectable = None @@ -1636,13 +1634,11 @@ class _ClassScanMapperConfig(_MapperConfig): inherited_table = inherited_mapper.local_table if "exclude_properties" not in mapper_args: - mapper_args["exclude_properties"] = exclude_properties = set( - [ - c.key - for c in inherited_table.c - if c not in inherited_mapper._columntoproperty - ] - ).union(inherited_mapper.exclude_properties or ()) + mapper_args["exclude_properties"] = exclude_properties = { + c.key + for c in inherited_table.c + if c not in inherited_mapper._columntoproperty + }.union(inherited_mapper.exclude_properties or ()) exclude_properties.difference_update( [c.key for c in self.declared_columns] ) @@ -1758,7 +1754,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): if not sort: return classes_for_base - all_m_by_cls = dict((m.cls, m) for m in classes_for_base) + all_m_by_cls = {m.cls: m for m in classes_for_base} tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = [] for m_cls in all_m_by_cls: @@ -1771,7 +1767,7 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: self._configs.pop(self._cls, None) - return super(_DeferredMapperConfig, self).map(mapper_kw) + return super().map(mapper_kw) def _add_attribute( diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 73e2ee9349..32de155a15 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -111,7 +111,7 @@ class InstrumentationEvents(event.Events): @classmethod def _clear(cls): - super(InstrumentationEvents, cls)._clear() + super()._clear() instrumentation._instrumentation_factory.dispatch._clear() def class_instrument(self, cls): @@ -266,7 +266,7 @@ class InstanceEvents(event.Events): @classmethod def _clear(cls): - super(InstanceEvents, cls)._clear() + super()._clear() _InstanceEventsHold._clear() def first_init(self, manager, cls): @@ -798,7 +798,7 @@ class MapperEvents(event.Events): @classmethod def _clear(cls): - super(MapperEvents, cls)._clear() + super()._clear() _MapperEventsHold._clear() def instrument_class(self, mapper, class_): diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 33de2aee90..dfe09fcbd3 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -399,8 +399,7 @@ class ClassManager( if mgr is not None and mgr is not self: yield mgr if recursive: - for m in mgr.subclass_managers(True): - yield m + yield from mgr.subclass_managers(True) def post_configure_attribute(self, key): _instrumentation_factory.dispatch.attribute_instrument( diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 64f2542fda..edfa61287f 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -535,15 +535,11 @@ def load_on_pk_identity( # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set( - [ - _get_params[col].key - for col, value in zip( - mapper.primary_key, primary_key_identity - ) - if value is None - ] - ) + nones = { + _get_params[col].key + for col, value in zip(mapper.primary_key, primary_key_identity) + if value is None + } _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones) @@ -558,14 +554,12 @@ def load_on_pk_identity( sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}), ) - params = dict( - [ - (_get_params[primary_key].key, id_val) - for id_val, primary_key in zip( - primary_key_identity, mapper.primary_key - ) - ] - ) + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } else: params = None diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py index 1aa864f7e5..8bacb87df4 100644 --- a/lib/sqlalchemy/orm/mapped_collection.py +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -175,7 +175,7 @@ class _AttrGetter: def attribute_keyed_dict( attr_name: str, *, ignore_unpopulated_attribute: bool = False -) -> Type["KeyFuncDict"]: +) -> Type[KeyFuncDict]: """A dictionary-based collection type with attribute-based keying. .. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to @@ -226,7 +226,7 @@ def keyfunc_mapping( keyfunc: Callable[[Any], _KT], *, ignore_unpopulated_attribute: bool = False, -) -> Type["KeyFuncDict[_KT, Any]"]: +) -> Type[KeyFuncDict[_KT, Any]]: """A dictionary-based collection type with arbitrary keying. .. versionchanged:: 2.0 Renamed :data:`.mapped_collection` to diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5f7ff43e42..d15c882c40 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -143,8 +143,7 @@ def _all_registries() -> Set[registry]: def _unconfigured_mappers() -> Iterator[Mapper[Any]]: for reg in _all_registries(): - for mapper in reg._mappers_to_configure(): - yield mapper + yield from reg._mappers_to_configure() _already_compiling = False @@ -905,8 +904,8 @@ class Mapper( with_polymorphic: Optional[ Tuple[ - Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], - Optional["FromClause"], + Union[Literal["*"], Sequence[Union[Mapper[Any], Type[Any]]]], + Optional[FromClause], ] ] @@ -2518,105 +2517,85 @@ class Mapper( @HasMemoized_ro_memoized_attribute def _insert_cols_evaluating_none(self): - return dict( - ( - table, - frozenset( - col for col in columns if col.type.should_evaluate_none - ), + return { + table: frozenset( + col for col in columns if col.type.should_evaluate_none ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _insert_cols_as_none(self): - return dict( - ( - table, - frozenset( - col.key - for col in columns - if not col.primary_key - and not col.server_default - and not col.default - and not col.type.should_evaluate_none - ), + return { + table: frozenset( + col.key + for col in columns + if not col.primary_key + and not col.server_default + and not col.default + and not col.type.should_evaluate_none ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _propkey_to_col(self): - return dict( - ( - table, - dict( - (self._columntoproperty[col].key, col) for col in columns - ), - ) + return { + table: {self._columntoproperty[col].key: col for col in columns} for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _pk_keys_by_table(self): - return dict( - (table, frozenset([col.key for col in pks])) + return { + table: frozenset([col.key for col in pks]) for table, pks in self._pks_by_table.items() - ) + } @HasMemoized.memoized_attribute def _pk_attr_keys_by_table(self): - return dict( - ( - table, - frozenset([self._columntoproperty[col].key for col in pks]), - ) + return { + table: frozenset([self._columntoproperty[col].key for col in pks]) for table, pks in self._pks_by_table.items() - ) + } @HasMemoized.memoized_attribute def _server_default_cols( self, ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: - return dict( - ( - table, - frozenset( - [ - col - for col in cast("Iterable[Column[Any]]", columns) - if col.server_default is not None - or ( - col.default is not None - and col.default.is_clause_element - ) - ] - ), + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_default is not None + or ( + col.default is not None + and col.default.is_clause_element + ) + ] ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _server_onupdate_default_cols( self, ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: - return dict( - ( - table, - frozenset( - [ - col - for col in cast("Iterable[Column[Any]]", columns) - if col.server_onupdate is not None - or ( - col.onupdate is not None - and col.onupdate.is_clause_element - ) - ] - ), + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_onupdate is not None + or ( + col.onupdate is not None + and col.onupdate.is_clause_element + ) + ] ) for table, columns in self._cols_by_table.items() - ) + } @HasMemoized.memoized_attribute def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]: diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index dfb61c28ac..77532f3233 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -458,12 +458,12 @@ def _collect_update_commands( if bulk: # keys here are mapped attribute keys, so # look at mapper attribute keys for pk - params = dict( - (propkey_to_col[propkey].key, state_dict[propkey]) + params = { + propkey_to_col[propkey].key: state_dict[propkey] for propkey in set(propkey_to_col) .intersection(state_dict) .difference(mapper._pk_attr_keys_by_table[table]) - ) + } has_all_defaults = True else: params = {} @@ -542,12 +542,12 @@ def _collect_update_commands( if bulk: # keys here are mapped attribute keys, so # look at mapper attribute keys for pk - pk_params = dict( - (propkey_to_col[propkey]._label, state_dict.get(propkey)) + pk_params = { + propkey_to_col[propkey]._label: state_dict.get(propkey) for propkey in set(propkey_to_col).intersection( mapper._pk_attr_keys_by_table[table] ) - ) + } else: pk_params = {} for col in pks: @@ -1689,7 +1689,7 @@ def _connections_for_states(base_mapper, uowtransaction, states): def _sort_states(mapper, states): pending = set(states) - persistent = set(s for s in pending if s.key is not None) + persistent = {s for s in pending if s.key is not None} pending.difference_update(persistent) try: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 0cbd3f7135..c1da267f4d 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -151,9 +151,7 @@ class ColumnProperty( doc: Optional[str] = None, _instrument: bool = True, ): - super(ColumnProperty, self).__init__( - attribute_options=attribute_options - ) + super().__init__(attribute_options=attribute_options) columns = (column,) + additional_columns self.columns = [ coercions.expect(roles.LabeledColumnExprRole, c) for c in columns @@ -211,7 +209,7 @@ class ColumnProperty( column.name = key @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: return self @property @@ -601,7 +599,7 @@ class MappedColumn( return self.column.name @property - def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: if self.deferred: return ColumnProperty( self.column, diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 2d97754f40..0d8d21df09 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -3281,7 +3281,7 @@ class BulkUpdate(BulkUD): values: Dict[_DMLColumnArgument, Any], update_kwargs: Optional[Dict[Any, Any]], ): - super(BulkUpdate, self).__init__(query) + super().__init__(query) self.values = values self.update_kwargs = update_kwargs diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 986093e025..020fae600e 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2966,15 +2966,13 @@ class JoinCondition: # 2. columns that are FK but are not remote (e.g. local) # suggest manytoone. - manytoone_local = set( - [ - c - for c in self._gather_columns_with_annotation( - self.primaryjoin, "foreign" - ) - if "remote" not in c._annotations - ] - ) + manytoone_local = { + c + for c in self._gather_columns_with_annotation( + self.primaryjoin, "foreign" + ) + if "remote" not in c._annotations + } # 3. if both collections are present, remove columns that # refer to themselves. This is for the case of @@ -3204,13 +3202,11 @@ class JoinCondition: self, clause: ColumnElement[Any], *annotation: Iterable[str] ) -> Set[ColumnElement[Any]]: annotation_set = set(annotation) - return set( - [ - cast(ColumnElement[Any], col) - for col in visitors.iterate(clause, {}) - if annotation_set.issubset(col._annotations) - ] - ) + return { + cast(ColumnElement[Any], col) + for col in visitors.iterate(clause, {}) + if annotation_set.issubset(col._annotations) + } def join_targets( self, diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index b65774f0af..efa0dc680e 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -157,7 +157,7 @@ class UninstrumentedColumnLoader(LoaderStrategy): __slots__ = ("columns",) def __init__(self, parent, strategy_key): - super(UninstrumentedColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.columns = self.parent_property.columns def setup_query( @@ -197,7 +197,7 @@ class ColumnLoader(LoaderStrategy): __slots__ = "columns", "is_composite" def __init__(self, parent, strategy_key): - super(ColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.columns = self.parent_property.columns self.is_composite = hasattr(self.parent_property, "composite_class") @@ -285,7 +285,7 @@ class ColumnLoader(LoaderStrategy): @properties.ColumnProperty.strategy_for(query_expression=True) class ExpressionColumnLoader(ColumnLoader): def __init__(self, parent, strategy_key): - super(ExpressionColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) # compare to the "default" expression that is mapped in # the column. If it's sql.null, we don't need to render @@ -381,7 +381,7 @@ class DeferredColumnLoader(LoaderStrategy): __slots__ = "columns", "group", "raiseload" def __init__(self, parent, strategy_key): - super(DeferredColumnLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) if hasattr(self.parent_property, "composite_class"): raise NotImplementedError( "Deferred loading for composite " "types not implemented yet" @@ -582,7 +582,7 @@ class AbstractRelationshipLoader(LoaderStrategy): __slots__ = "mapper", "target", "uselist", "entity" def __init__(self, parent, strategy_key): - super(AbstractRelationshipLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.mapper = self.parent_property.mapper self.entity = self.parent_property.entity self.target = self.parent_property.target @@ -682,7 +682,7 @@ class LazyLoader( def __init__( self, parent: RelationshipProperty[Any], strategy_key: Tuple[Any, ...] ): - super(LazyLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self._raise_always = self.strategy_opts["lazy"] == "raise" self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql" @@ -1431,7 +1431,7 @@ class SubqueryLoader(PostLoader): __slots__ = ("join_depth",) def __init__(self, parent, strategy_key): - super(SubqueryLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth def init_class_attribute(self, mapper): @@ -1560,7 +1560,7 @@ class SubqueryLoader(PostLoader): elif distinct_target_key is None: # if target_cols refer to a non-primary key or only # part of a composite primary key, set the q as distinct - for t in set(c.table for c in target_cols): + for t in {c.table for c in target_cols}: if not set(target_cols).issuperset(t.primary_key): q._distinct = True break @@ -2078,7 +2078,7 @@ class JoinedLoader(AbstractRelationshipLoader): __slots__ = "join_depth", "_aliased_class_pool" def __init__(self, parent, strategy_key): - super(JoinedLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth self._aliased_class_pool = [] @@ -2832,7 +2832,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): _chunksize = 500 def __init__(self, parent, strategy_key): - super(SelectInLoader, self).__init__(parent, strategy_key) + super().__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth is_m2o = self.parent_property.direction is interfaces.MANYTOONE diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index 23b3466f59..1c48bc4767 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -1872,7 +1872,7 @@ class _AttributeStrategyLoad(_LoadElement): ), ] - _of_type: Union["Mapper[Any]", "AliasedInsp[Any]", None] + _of_type: Union[Mapper[Any], AliasedInsp[Any], None] _path_with_polymorphic_path: Optional[PathRegistry] is_class_strategy = False diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 5e66653a38..9a8c02b6b6 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -411,9 +411,9 @@ class UOWTransaction: if cycles: # if yes, break the per-mapper actions into # per-state actions - convert = dict( - (rec, set(rec.per_state_flush_actions(self))) for rec in cycles - ) + convert = { + rec: set(rec.per_state_flush_actions(self)) for rec in cycles + } # rewrite the existing dependencies to point to # the per-state actions for those per-mapper actions @@ -435,9 +435,9 @@ class UOWTransaction: for dep in convert[edge[1]]: self.dependencies.add((edge[0], dep)) - return set( - [a for a in self.postsort_actions.values() if not a.disabled] - ).difference(cycles) + return { + a for a in self.postsort_actions.values() if not a.disabled + }.difference(cycles) def execute(self) -> None: postsort_actions = self._generate_actions() @@ -478,9 +478,9 @@ class UOWTransaction: return states = set(self.states) - isdel = set( + isdel = { s for (s, (isdelete, listonly)) in self.states.items() if isdelete - ) + } other = states.difference(isdel) if isdel: self.session._remove_newly_deleted(isdel) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 50eba5d4c4..e5bdbaa4f3 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1011,7 +1011,7 @@ class AliasedInsp( our_classes = util.to_set( mp.class_ for mp in self.with_polymorphic_mappers ) - new_classes = set([mp.class_ for mp in other.with_polymorphic_mappers]) + new_classes = {mp.class_ for mp in other.with_polymorphic_mappers} if our_classes == new_classes: return other else: @@ -1278,8 +1278,7 @@ class LoaderCriteriaOption(CriteriaOption): def _all_mappers(self) -> Iterator[Mapper[Any]]: if self.entity: - for mp_ent in self.entity.mapper.self_and_descendants: - yield mp_ent + yield from self.entity.mapper.self_and_descendants else: assert self.root_entity stack = list(self.root_entity.__subclasses__()) @@ -1290,8 +1289,7 @@ class LoaderCriteriaOption(CriteriaOption): inspection.inspect(subclass, raiseerr=False), ) if ent: - for mp in ent.mapper.self_and_descendants: - yield mp + yield from ent.mapper.self_and_descendants else: stack.extend(subclass.__subclasses__()) diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 8b8f6b010e..7c5281beeb 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -1127,7 +1127,7 @@ def label( name: str, element: _ColumnExpressionArgument[_T], type_: Optional[_TypeEngineArgument[_T]] = None, -) -> "Label[_T]": +) -> Label[_T]: """Return a :class:`Label` object for the given :class:`_expression.ColumnElement`. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 34b2951137..c818911696 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -291,16 +291,14 @@ def _cloned_intersection(a, b): """ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( - elem for elem in a if all_overlap.intersection(elem._cloned_set) - ) + return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} def _cloned_difference(a, b): all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) - return set( + return { elem for elem in a if not all_overlap.intersection(elem._cloned_set) - ) + } class _DialectArgView(MutableMapping[str, Any]): diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 8074bcf8b1..f48a3ccb00 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -782,7 +782,7 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl): else: advice = None - return super(ExpressionElementImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) @@ -1096,7 +1096,7 @@ class LabeledColumnExprImpl(ExpressionElementImpl): if isinstance(resolved, roles.ExpressionElementRole): return resolved.label(None) else: - new = super(LabeledColumnExprImpl, self)._implicit_coercions( + new = super()._implicit_coercions( element, resolved, argname=argname, **kw ) if isinstance(new, roles.ExpressionElementRole): @@ -1123,7 +1123,7 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): f"{', '.join(repr(e) for e in element)})?" ) - return super(ColumnsClauseImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) @@ -1370,7 +1370,7 @@ class CompoundElementImpl(_NoTextCoercion, RoleImpl): ) else: advice = None - return super(CompoundElementImpl, self)._raise_for_expected( + return super()._raise_for_expected( element, argname=argname, resolved=resolved, advice=advice, **kw ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 9a00afc91c..17aafddadb 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -115,104 +115,102 @@ if typing.TYPE_CHECKING: _FromHintsType = Dict["FromClause", str] -RESERVED_WORDS = set( - [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "authorization", - "between", - "binary", - "both", - "case", - "cast", - "check", - "collate", - "column", - "constraint", - "create", - "cross", - "current_date", - "current_role", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "for", - "foreign", - "freeze", - "from", - "full", - "grant", - "group", - "having", - "ilike", - "in", - "initially", - "inner", - "intersect", - "into", - "is", - "isnull", - "join", - "leading", - "left", - "like", - "limit", - "localtime", - "localtimestamp", - "natural", - "new", - "not", - "notnull", - "null", - "off", - "offset", - "old", - "on", - "only", - "or", - "order", - "outer", - "overlaps", - "placing", - "primary", - "references", - "right", - "select", - "session_user", - "set", - "similar", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "verbose", - "when", - "where", - ] -) +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", +} LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I) @@ -505,8 +503,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): "between each element to resolve." ) froms_str = ", ".join( - '"{elem}"'.format(elem=self.froms[from_]) - for from_ in froms + f'"{self.froms[from_]}"' for from_ in froms ) message = template.format( froms=froms_str, start=self.froms[start_with] @@ -1259,11 +1256,8 @@ class SQLCompiler(Compiled): # mypy is not able to see the two value types as the above Union, # it just sees "object". don't know how to resolve - return dict( - ( - key, - value, - ) # type: ignore + return { + key: value # type: ignore for key, value in ( ( self.bind_names[bindparam], @@ -1277,7 +1271,7 @@ class SQLCompiler(Compiled): for bindparam in self.bind_names ) if value is not None - ) + } def is_subquery(self): return len(self.stack) > 1 @@ -4147,17 +4141,12 @@ class SQLCompiler(Compiled): def _setup_select_hints( self, select: Select[Any] ) -> Tuple[str, _FromHintsType]: - byfrom = dict( - [ - ( - from_, - hinttext - % {"name": from_._compiler_dispatch(self, ashint=True)}, - ) - for (from_, dialect), hinttext in select._hints.items() - if dialect in ("*", self.dialect.name) - ] - ) + byfrom = { + from_: hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)} + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + } hint_text = self.get_select_hint_text(byfrom) return hint_text, byfrom @@ -4583,13 +4572,11 @@ class SQLCompiler(Compiled): ) def _setup_crud_hints(self, stmt, table_text): - dialect_hints = dict( - [ - (table, hint_text) - for (table, dialect), hint_text in stmt._hints.items() - if dialect in ("*", self.dialect.name) - ] - ) + dialect_hints = { + table: hint_text + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + } if stmt.table in dialect_hints: table_text = self.format_from_hint_text( table_text, stmt.table, dialect_hints[stmt.table], True @@ -5318,9 +5305,7 @@ class StrSQLCompiler(SQLCompiler): if not isinstance(compiler, StrSQLCompiler): return compiler.process(element) - return super(StrSQLCompiler, self).visit_unsupported_compilation( - element, err - ) + return super().visit_unsupported_compilation(element, err) def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( @@ -6603,14 +6588,14 @@ class IdentifierPreparer: @util.memoized_property def _r_identifiers(self): - initial, final, escaped_final = [ + initial, final, escaped_final = ( re.escape(s) for s in ( self.initial_quote, self.final_quote, self._escape_identifier(self.final_quote), ) - ] + ) r = re.compile( r"(?:" r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 31d127c2c3..017ff7baa0 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -227,15 +227,15 @@ def _get_crud_params( parameters = {} elif stmt_parameter_tuples: assert spd is not None - parameters = dict( - (_column_as_key(key), REQUIRED) + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys if key not in spd - ) + } else: - parameters = dict( - (_column_as_key(key), REQUIRED) for key in compiler.column_keys - ) + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys + } # create a list of column assignment clauses as tuples values: List[_CrudParamElement] = [] @@ -1278,10 +1278,10 @@ def _get_update_multitable_params( values, kw, ): - normalized_params = dict( - (coercions.expect(roles.DMLColumnRole, c), param) + normalized_params = { + coercions.expect(roles.DMLColumnRole, c): param for c, param in stmt_parameter_tuples - ) + } include_table = compile_state.include_table_with_column_exprs diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index fa0c25b1d3..ecdc2eb63d 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -176,7 +176,7 @@ class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): """ _ddl_if: Optional[DDLIf] = None - target: Optional["SchemaItem"] = None + target: Optional[SchemaItem] = None def _execute_on_connection( self, connection, distilled_params, execution_options @@ -1179,12 +1179,10 @@ class SchemaDropper(InvokeDropDDLBase): def sort_tables( - tables: Iterable["Table"], - skip_fn: Optional[Callable[["ForeignKeyConstraint"], bool]] = None, - extra_dependencies: Optional[ - typing_Sequence[Tuple["Table", "Table"]] - ] = None, -) -> List["Table"]: + tables: Iterable[Table], + skip_fn: Optional[Callable[[ForeignKeyConstraint], bool]] = None, + extra_dependencies: Optional[typing_Sequence[Tuple[Table, Table]]] = None, +) -> List[Table]: """Sort a collection of :class:`_schema.Table` objects based on dependency. diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 2d3e3598b8..c279e344b5 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -1179,7 +1179,7 @@ class Insert(ValuesBase): ) def __init__(self, table: _DMLTableArgument): - super(Insert, self).__init__(table) + super().__init__(table) @_generative def inline(self: SelfInsert) -> SelfInsert: @@ -1498,7 +1498,7 @@ class Update(DMLWhereBase, ValuesBase): ) def __init__(self, table: _DMLTableArgument): - super(Update, self).__init__(table) + super().__init__(table) @_generative def ordered_values( diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 044bdf585a..d9a1a93580 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3035,7 +3035,7 @@ class BooleanClauseList(ExpressionClauseList[bool]): if not self.clauses: return self else: - return super(BooleanClauseList, self).self_group(against=against) + return super().self_group(against=against) and_ = BooleanClauseList.and_ @@ -3082,7 +3082,7 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): ] self.type = sqltypes.TupleType(*[arg.type for arg in init_clauses]) - super(Tuple, self).__init__(*init_clauses) + super().__init__(*init_clauses) @property def _select_iterable(self) -> _SelectIterable: @@ -3753,8 +3753,8 @@ class BinaryExpression(OperatorExpression[_T]): if typing.TYPE_CHECKING: def __invert__( - self: "BinaryExpression[_T]", - ) -> "BinaryExpression[_T]": + self: BinaryExpression[_T], + ) -> BinaryExpression[_T]: ... @util.ro_non_memoized_property @@ -3772,7 +3772,7 @@ class BinaryExpression(OperatorExpression[_T]): modifiers=self.modifiers, ) else: - return super(BinaryExpression, self)._negate() + return super()._negate() class Slice(ColumnElement[Any]): @@ -4617,7 +4617,7 @@ class ColumnClause( if self.table is not None: return self.table.entity_namespace else: - return super(ColumnClause, self).entity_namespace + return super().entity_namespace def _clone(self, detect_subquery_cols=False, **kw): if ( @@ -4630,7 +4630,7 @@ class ColumnClause( new = table.c.corresponding_column(self) return new - return super(ColumnClause, self)._clone(**kw) + return super()._clone(**kw) @HasMemoized_ro_memoized_attribute def _from_objects(self) -> List[FromClause]: @@ -4993,7 +4993,7 @@ class AnnotatedColumnElement(Annotated): self.__dict__.pop(attr) def _with_annotations(self, values): - clone = super(AnnotatedColumnElement, self)._with_annotations(values) + clone = super()._with_annotations(values) clone.__dict__.pop("comparator", None) return clone @@ -5032,7 +5032,7 @@ class _truncated_label(quoted_name): def __new__(cls, value: str, quote: Optional[bool] = None) -> Any: quote = getattr(value, "quote", quote) # return super(_truncated_label, cls).__new__(cls, value, quote, True) - return super(_truncated_label, cls).__new__(cls, value, quote) + return super().__new__(cls, value, quote) def __reduce__(self) -> Any: return self.__class__, (str(self), self.quote) diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fad7c28eb5..5ed89bc824 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -167,9 +167,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): @property def _proxy_key(self): - return super(FunctionElement, self)._proxy_key or getattr( - self, "name", None - ) + return super()._proxy_key or getattr(self, "name", None) def _execute_on_connection( self, @@ -660,7 +658,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ): return Grouping(self) else: - return super(FunctionElement, self).self_group(against=against) + return super().self_group(against=against) @property def entity_namespace(self): @@ -1198,7 +1196,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]): ] kwargs.setdefault("type_", _type_from_args(fn_args)) kwargs["_parsed_args"] = fn_args - super(ReturnTypeFromArgs, self).__init__(*fn_args, **kwargs) + super().__init__(*fn_args, **kwargs) class coalesce(ReturnTypeFromArgs[_T]): @@ -1304,7 +1302,7 @@ class count(GenericFunction[int]): def __init__(self, expression=None, **kwargs): if expression is None: expression = literal_column("*") - super(count, self).__init__(expression, **kwargs) + super().__init__(expression, **kwargs) class current_date(AnsiFunction[datetime.date]): @@ -1411,7 +1409,7 @@ class array_agg(GenericFunction[_T]): type_from_args, dimensions=1 ) kwargs["_parsed_args"] = fn_args - super(array_agg, self).__init__(*fn_args, **kwargs) + super().__init__(*fn_args, **kwargs) class OrderedSetAgg(GenericFunction[_T]): diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index bbfaf47e1b..26e3a21bb4 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -439,7 +439,7 @@ class DeferredLambdaElement(LambdaElement): lambda_args: Tuple[Any, ...] = (), ): self.lambda_args = lambda_args - super(DeferredLambdaElement, self).__init__(fn, role, opts) + super().__init__(fn, role, opts) def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) @@ -483,7 +483,7 @@ class DeferredLambdaElement(LambdaElement): def _copy_internals( self, clone=_clone, deferred_copy_internals=None, **kw ): - super(DeferredLambdaElement, self)._copy_internals( + super()._copy_internals( clone=clone, deferred_copy_internals=deferred_copy_internals, # **kw opts=kw, diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 55c275741f..2d1f9caa17 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -66,11 +66,11 @@ class OperatorType(Protocol): def __call__( self, - left: "Operators", + left: Operators, right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> "Operators": + ) -> Operators: ... @@ -184,7 +184,7 @@ class Operators: precedence: int = 0, is_comparison: bool = False, return_type: Optional[ - Union[Type["TypeEngine[Any]"], "TypeEngine[Any]"] + Union[Type[TypeEngine[Any]], TypeEngine[Any]] ] = None, python_impl: Optional[Callable[..., Any]] = None, ) -> Callable[[Any], Operators]: @@ -397,7 +397,7 @@ class custom_op(OperatorType, Generic[_T]): precedence: int = 0, is_comparison: bool = False, return_type: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + Union[Type[TypeEngine[_T]], TypeEngine[_T]] ] = None, natural_self_precedent: bool = False, eager_grouping: bool = False, diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index cd10d0c4a5..f76fc447c0 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -920,11 +920,11 @@ class Table( :attr:`_schema.Table.indexes` """ - return set( + return { fkc.constraint for fkc in self.foreign_keys if fkc.constraint is not None - ) + } def _init_existing(self, *args: Any, **kwargs: Any) -> None: autoload_with = kwargs.pop("autoload_with", None) @@ -1895,7 +1895,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): # name = None is expected to be an interim state # note this use case is legacy now that ORM declarative has a # dedicated "column" construct local to the ORM - super(Column, self).__init__(name, type_) # type: ignore + super().__init__(name, type_) # type: ignore self.key = key if key is not None else name # type: ignore self.primary_key = primary_key @@ -3573,7 +3573,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: column = parent assert isinstance(column, Column) - super(Sequence, self)._set_parent(column) + super()._set_parent(column) column._on_table_attach(self._set_table) def _copy(self) -> Sequence: @@ -3712,7 +3712,7 @@ class DefaultClause(FetchedValue): _reflected: bool = False, ) -> None: util.assert_arg_type(arg, (str, ClauseElement, TextClause), "arg") - super(DefaultClause, self).__init__(for_update) + super().__init__(for_update) self.arg = arg self.reflected = _reflected @@ -3914,9 +3914,9 @@ class ColumnCollectionMixin: # issue #3411 - don't do the per-column auto-attach if some of the # columns are specified as strings. - has_string_cols = set( + has_string_cols = { c for c in self._pending_colargs if c is not None - ).difference(col_objs) + }.difference(col_objs) if not has_string_cols: def _col_attached(column: Column[Any], table: Table) -> None: @@ -4434,7 +4434,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): return self.elements[0].column.table def _validate_dest_table(self, table: Table) -> None: - table_keys = set([elem._table_key() for elem in self.elements]) + table_keys = {elem._table_key() for elem in self.elements} if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] raise exc.ArgumentError( @@ -4624,7 +4624,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): **dialect_kw: Any, ) -> None: self._implicit_generated = _implicit_generated - super(PrimaryKeyConstraint, self).__init__( + super().__init__( *columns, name=name, deferrable=deferrable, @@ -4636,7 +4636,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: table = parent assert isinstance(table, Table) - super(PrimaryKeyConstraint, self)._set_parent(table) + super()._set_parent(table) if table.primary_key is not self: table.constraints.discard(table.primary_key) @@ -5219,13 +5219,9 @@ class MetaData(HasSchemaAttr): for fk in removed.foreign_keys: fk._remove_from_metadata(self) if self._schemas: - self._schemas = set( - [ - t.schema - for t in self.tables.values() - if t.schema is not None - ] - ) + self._schemas = { + t.schema for t in self.tables.values() if t.schema is not None + } def __getstate__(self) -> Dict[str, Any]: return { diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8c64dea9d3..fcffc324fb 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1301,12 +1301,12 @@ class Join(roles.DMLTableRole, FromClause): # run normal _copy_internals. the clones for # left and right will come from the clone function's # cache - super(Join, self)._copy_internals(clone=clone, **kw) + super()._copy_internals(clone=clone, **kw) self._reset_memoizations() def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: - super(Join, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) self.left._refresh_for_new_column(column) self.right._refresh_for_new_column(column) @@ -1467,7 +1467,7 @@ class Join(roles.DMLTableRole, FromClause): # "consider_as_foreign_keys". if consider_as_foreign_keys: for const in list(constraints): - if set(f.parent for f in const.elements) != set( + if {f.parent for f in const.elements} != set( consider_as_foreign_keys ): del constraints[const] @@ -1475,7 +1475,7 @@ class Join(roles.DMLTableRole, FromClause): # if still multiple constraints, but # they all refer to the exact same end result, use it. if len(constraints) > 1: - dedupe = set(tuple(crit) for crit in constraints.values()) + dedupe = {tuple(crit) for crit in constraints.values()} if len(dedupe) == 1: key = list(constraints)[0] constraints = {key: constraints[key]} @@ -1621,7 +1621,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): self.name = name def _refresh_for_new_column(self, column: ColumnElement[Any]) -> None: - super(AliasedReturnsRows, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) self.element._refresh_for_new_column(column) def _populate_column_collection(self): @@ -1654,7 +1654,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): ) -> None: existing_element = self.element - super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw) + super()._copy_internals(clone=clone, **kw) # the element clone is usually against a Table that returns the # same object. don't reset exported .c. collections and other @@ -1752,7 +1752,7 @@ class TableValuedAlias(LateralFromClause, Alias): table_value_type=None, joins_implicitly=False, ): - super(TableValuedAlias, self)._init(selectable, name=name) + super()._init(selectable, name=name) self.joins_implicitly = joins_implicitly self._tableval_type = ( @@ -1959,7 +1959,7 @@ class TableSample(FromClauseAlias): self.sampling = sampling self.seed = seed - super(TableSample, self)._init(selectable, name=name) + super()._init(selectable, name=name) def _get_method(self): return self.sampling @@ -2044,7 +2044,7 @@ class CTE( self._prefixes = _prefixes if _suffixes: self._suffixes = _suffixes - super(CTE, self)._init(selectable, name=name) + super()._init(selectable, name=name) def _populate_column_collection(self): if self._cte_alias is not None: @@ -2945,7 +2945,7 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): return None def __init__(self, name: str, *columns: ColumnClause[Any], **kw: Any): - super(TableClause, self).__init__() + super().__init__() self.name = name self._columns = DedupeColumnCollection() self.primary_key = ColumnSet() # type: ignore @@ -3156,7 +3156,7 @@ class Values(Generative, LateralFromClause): name: Optional[str] = None, literal_binds: bool = False, ): - super(Values, self).__init__() + super().__init__() self._column_args = columns if name is None: self._unnamed = True @@ -4188,7 +4188,7 @@ class CompoundSelectState(CompileState): # TODO: this is hacky and slow hacky_subquery = self.statement.subquery() hacky_subquery.named_with_column = False - d = dict((c.key, c) for c in hacky_subquery.c) + d = {c.key: c for c in hacky_subquery.c} return d, d, d @@ -4369,7 +4369,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows): ) def _refresh_for_new_column(self, column): - super(CompoundSelect, self)._refresh_for_new_column(column) + super()._refresh_for_new_column(column) for select in self.selects: select._refresh_for_new_column(column) @@ -4689,16 +4689,16 @@ class SelectState(util.MemoizedSlots, CompileState): Dict[str, ColumnElement[Any]], Dict[str, ColumnElement[Any]], ]: - with_cols: Dict[str, ColumnElement[Any]] = dict( - (c._tq_label or c.key, c) # type: ignore + with_cols: Dict[str, ColumnElement[Any]] = { + c._tq_label or c.key: c # type: ignore for c in self.statement._all_selected_columns if c._allow_label_resolve - ) - only_froms: Dict[str, ColumnElement[Any]] = dict( - (c.key, c) # type: ignore + } + only_froms: Dict[str, ColumnElement[Any]] = { + c.key: c # type: ignore for c in _select_iterables(self.froms) if c._allow_label_resolve - ) + } only_cols: Dict[str, ColumnElement[Any]] = with_cols.copy() for key, value in only_froms.items(): with_cols.setdefault(key, value) @@ -5569,7 +5569,7 @@ class Select( # 2. copy FROM collections, adding in joins that we've created. existing_from_obj = [clone(f, **kw) for f in self._from_obj] add_froms = ( - set(f for f in new_froms.values() if isinstance(f, Join)) + {f for f in new_froms.values() if isinstance(f, Join)} .difference(all_the_froms) .difference(existing_from_obj) ) @@ -5589,15 +5589,13 @@ class Select( # correlate_except, setup_joins, these clone normally. For # column-expression oriented things like raw_columns, where_criteria, # order by, we get this from the new froms. - super(Select, self)._copy_internals( - clone=clone, omit_attrs=("_from_obj",), **kw - ) + super()._copy_internals(clone=clone, omit_attrs=("_from_obj",), **kw) self._reset_memoizations() def get_children(self, **kw: Any) -> Iterable[ClauseElement]: return itertools.chain( - super(Select, self).get_children( + super().get_children( omit_attrs=("_from_obj", "_correlate", "_correlate_except"), **kw, ), diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b98a16b6fb..624b7d16ef 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -134,9 +134,7 @@ class Concatenable(TypeEngineMixin): ): return operators.concat_op, self.expr.type else: - return super(Concatenable.Comparator, self)._adapt_expression( - op, other_comparator - ) + return super()._adapt_expression(op, other_comparator) comparator_factory: _ComparatorFactory[Any] = Comparator @@ -319,7 +317,7 @@ class Unicode(String): Parameters are the same as that of :class:`.String`. """ - super(Unicode, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class UnicodeText(Text): @@ -344,7 +342,7 @@ class UnicodeText(Text): Parameters are the same as that of :class:`_expression.TextClause`. """ - super(UnicodeText, self).__init__(length=length, **kwargs) + super().__init__(length=length, **kwargs) class Integer(HasExpressionLookup, TypeEngine[int]): @@ -930,7 +928,7 @@ class _Binary(TypeEngine[bytes]): if isinstance(value, str): return self else: - return super(_Binary, self).coerce_compared_value(op, value) + return super().coerce_compared_value(op, value) def get_dbapi_type(self, dbapi): return dbapi.BINARY @@ -1450,7 +1448,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): self._valid_lookup[None] = self._object_lookup[None] = None - super(Enum, self).__init__(length=length) + super().__init__(length=length) if self.enum_class: kw.setdefault("name", self.enum_class.__name__.lower()) @@ -1551,9 +1549,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): op: OperatorType, other_comparator: TypeEngine.Comparator[Any], ) -> Tuple[OperatorType, TypeEngine[Any]]: - op, typ = super(Enum.Comparator, self)._adapt_expression( - op, other_comparator - ) + op, typ = super()._adapt_expression(op, other_comparator) if op is operators.concat_op: typ = String(self.type.length) return op, typ @@ -1618,7 +1614,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): def adapt(self, impltype, **kw): kw["_enums"] = self._enums_argument kw["_disable_warnings"] = True - return super(Enum, self).adapt(impltype, **kw) + return super().adapt(impltype, **kw) def _should_create_constraint(self, compiler, **kw): if not self._is_impl_for_variant(compiler.dialect, kw): @@ -1649,7 +1645,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): assert e.table is table def literal_processor(self, dialect): - parent_processor = super(Enum, self).literal_processor(dialect) + parent_processor = super().literal_processor(dialect) def process(value): value = self._db_value_for_elem(value) @@ -1660,7 +1656,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return process def bind_processor(self, dialect): - parent_processor = super(Enum, self).bind_processor(dialect) + parent_processor = super().bind_processor(dialect) def process(value): value = self._db_value_for_elem(value) @@ -1671,7 +1667,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return process def result_processor(self, dialect, coltype): - parent_processor = super(Enum, self).result_processor(dialect, coltype) + parent_processor = super().result_processor(dialect, coltype) def process(value): if parent_processor: @@ -1690,7 +1686,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): if self.enum_class: return self.enum_class else: - return super(Enum, self).python_type + return super().python_type class PickleType(TypeDecorator[object]): @@ -1739,7 +1735,7 @@ class PickleType(TypeDecorator[object]): self.protocol = protocol self.pickler = pickler or pickle self.comparator = comparator - super(PickleType, self).__init__() + super().__init__() if impl: # custom impl is not necessarily a LargeBinary subclass. @@ -2000,7 +1996,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): support a "day precision" parameter, i.e. Oracle. """ - super(Interval, self).__init__() + super().__init__() self.native = native self.second_precision = second_precision self.day_precision = day_precision @@ -3005,7 +3001,7 @@ class ARRAY( def _set_parent_with_dispatch(self, parent): """Support SchemaEventTarget""" - super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True) + super()._set_parent_with_dispatch(parent, outer=True) if isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent_with_dispatch(parent) @@ -3249,7 +3245,7 @@ class TIMESTAMP(DateTime): """ - super(TIMESTAMP, self).__init__(timezone=timezone) + super().__init__(timezone=timezone) def get_dbapi_type(self, dbapi): return dbapi.TIMESTAMP @@ -3464,7 +3460,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): @overload def __init__( - self: "Uuid[_python_UUID]", + self: Uuid[_python_UUID], as_uuid: Literal[True] = ..., native_uuid: bool = ..., ): @@ -3472,7 +3468,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): @overload def __init__( - self: "Uuid[str]", + self: Uuid[str], as_uuid: Literal[False] = ..., native_uuid: bool = ..., ): @@ -3628,11 +3624,11 @@ class UUID(Uuid[_UUID_RETURN]): __visit_name__ = "UUID" @overload - def __init__(self: "UUID[_python_UUID]", as_uuid: Literal[True] = ...): + def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...): ... @overload - def __init__(self: "UUID[str]", as_uuid: Literal[False] = ...): + def __init__(self: UUID[str], as_uuid: Literal[False] = ...): ... def __init__(self, as_uuid: bool = True): diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 1354073214..866c0ccde4 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -301,9 +301,7 @@ class _CopyInternalsTraversal(HasTraversalDispatch): def visit_string_clauseelement_dict( self, attrname, parent, element, clone=_clone, **kw ): - return dict( - (key, clone(value, **kw)) for key, value in element.items() - ) + return {key: clone(value, **kw) for key, value in element.items()} def visit_setup_join_tuple( self, attrname, parent, element, clone=_clone, **kw diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index cd57ee3b64..c3768c6c63 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1399,7 +1399,7 @@ class Emulated(TypeEngineMixin): def _is_native_for_emulated( typ: Type[Union[TypeEngine[Any], TypeEngineMixin]], -) -> TypeGuard["Type[NativeForEmulated]"]: +) -> TypeGuard[Type[NativeForEmulated]]: return hasattr(typ, "adapt_emulated_to_native") @@ -1673,9 +1673,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): if TYPE_CHECKING: assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types - return super(TypeDecorator.Comparator, self).operate( - op, *other, **kwargs - ) + return super().operate(op, *other, **kwargs) def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any @@ -1683,9 +1681,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): if TYPE_CHECKING: assert isinstance(self.expr.type, TypeDecorator) kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types - return super(TypeDecorator.Comparator, self).reverse_operate( - op, other, **kwargs - ) + return super().reverse_operate(op, other, **kwargs) @property def comparator_factory( # type: ignore # mypy properties bug diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index ec8ea757f2..14cbe24562 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -316,8 +316,7 @@ def visit_binary_product( if isinstance(element, ColumnClause): yield element for elem in element.get_children(): - for e in visit(elem): - yield e + yield from visit(elem) list(visit(expr)) visit = None # type: ignore # remove gc cycles @@ -433,12 +432,10 @@ def expand_column_list_from_order_by(collist, order_by): in the collist. """ - cols_already_present = set( - [ - col.element if col._order_by_label_element is not None else col - for col in collist - ] - ) + cols_already_present = { + col.element if col._order_by_label_element is not None else col + for col in collist + } to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by])) @@ -463,13 +460,10 @@ def clause_is_present(clause, search): def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]: if isinstance(clause, Join): - for t in tables_from_leftmost(clause.left): - yield t - for t in tables_from_leftmost(clause.right): - yield t + yield from tables_from_leftmost(clause.left) + yield from tables_from_leftmost(clause.right) elif isinstance(clause, FromGrouping): - for t in tables_from_leftmost(clause.element): - yield t + yield from tables_from_leftmost(clause.element) else: yield clause @@ -592,7 +586,7 @@ class _repr_row(_repr_base): __slots__ = ("row",) - def __init__(self, row: "Row[Any]", max_chars: int = 300): + def __init__(self, row: Row[Any], max_chars: int = 300): self.row = row self.max_chars = max_chars @@ -775,7 +769,7 @@ class _repr_params(_repr_base): ) return text - def _repr_param_tuple(self, params: "Sequence[Any]") -> str: + def _repr_param_tuple(self, params: Sequence[Any]) -> str: trunc = self.trunc ( diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 2fda1e9cbe..d183372c34 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -269,9 +269,9 @@ class DialectSQL(CompiledSQL): return received_stmt == stmt def _received_statement(self, execute_observed): - received_stmt, received_params = super( - DialectSQL, self - )._received_statement(execute_observed) + received_stmt, received_params = super()._received_statement( + execute_observed + ) # TODO: why do we need this part? for real_stmt in execute_observed.statements: @@ -392,15 +392,15 @@ class EachOf(AssertRule): if self.rules and not self.rules[0].is_consumed: self.rules[0].no_more_statements() elif self.rules: - super(EachOf, self).no_more_statements() + super().no_more_statements() class Conditional(EachOf): def __init__(self, condition, rules, else_rules): if condition: - super(Conditional, self).__init__(*rules) + super().__init__(*rules) else: - super(Conditional, self).__init__(*else_rules) + super().__init__(*else_rules) class Or(AllOf): diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index c083f4e732..0a60a20d33 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -285,21 +285,21 @@ def reconnecting_engine(url=None, options=None): @typing.overload def testing_engine( - url: Optional["URL"] = None, + url: Optional[URL] = None, options: Optional[Dict[str, Any]] = None, asyncio: Literal[False] = False, transfer_staticpool: bool = False, -) -> "Engine": +) -> Engine: ... @typing.overload def testing_engine( - url: Optional["URL"] = None, + url: Optional[URL] = None, options: Optional[Dict[str, Any]] = None, asyncio: Literal[True] = True, transfer_staticpool: bool = False, -) -> "AsyncEngine": +) -> AsyncEngine: ... diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 25c6a04822..3cb060d018 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -129,10 +129,8 @@ class compound: for fail in self.fails: if fail(config): print( - ( - "%s failed as expected (%s): %s " - % (name, fail._as_string(config), ex) - ) + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), ex) ) break else: diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index dcee3f18bf..12b5acba46 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -814,7 +814,7 @@ class DeclarativeMappedTest(MappedTest): # sets up cls.Basic which is helpful for things like composite # classes - super(DeclarativeMappedTest, cls)._with_register_classes(fn) + super()._with_register_classes(fn) if cls._tables_metadata.tables and cls.run_create_tables: cls._tables_metadata.create_all(config.db) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 0a70f4008e..d590ecbe43 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -49,7 +49,7 @@ def pytest_addoption(parser): required=False, help=None, # noqa ): - super(CallableAction, self).__init__( + super().__init__( option_strings=option_strings, dest=dest, nargs=0, @@ -210,7 +210,7 @@ def pytest_collection_modifyitems(session, config, items): and not item.getparent(pytest.Class).name.startswith("_") ] - test_classes = set(item.getparent(pytest.Class) for item in items) + test_classes = {item.getparent(pytest.Class) for item in items} def collect(element): for inst_or_fn in element.collect(): diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 7672bcde5b..dfc3f28f6a 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -195,7 +195,7 @@ class ProfileStatsFile: def _read(self): try: profile_f = open(self.fname) - except IOError: + except OSError: return for lineno, line in enumerate(profile_f): line = line.strip() @@ -212,7 +212,7 @@ class ProfileStatsFile: profile_f.close() def _write(self): - print(("Writing profile file %s" % self.fname)) + print("Writing profile file %s" % self.fname) profile_f = open(self.fname, "w") profile_f.write(self._header()) for test_key in sorted(self.data): @@ -293,7 +293,7 @@ def count_functions(variance=0.05): else: line_no, expected_count = expected - print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) + print("Pstats calls: %d Expected %s" % (callcount, expected_count)) stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort)) stats.print_stats() if _profile_stats.dump: diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 33e395c480..01cec1fb06 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -1,4 +1,3 @@ -#! coding: utf-8 # mypy: ignore-errors diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 68d1c13faa..bf745095d2 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1871,14 +1871,12 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest): # "unique constraints" are actually unique indexes (with possible # exception of a unique that is a dupe of another one in the case # of Oracle). make sure # they aren't duplicated. - idx_names = set([idx.name for idx in reflected.indexes]) - uq_names = set( - [ - uq.name - for uq in reflected.constraints - if isinstance(uq, sa.UniqueConstraint) - ] - ).difference(["unique_c_a_b"]) + idx_names = {idx.name for idx in reflected.indexes} + uq_names = { + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + }.difference(["unique_c_a_b"]) assert not idx_names.intersection(uq_names) if names_that_duplicate_index: @@ -2519,10 +2517,10 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): ) t.create(connection) eq_( - dict( - (col["name"], col["nullable"]) + { + col["name"]: col["nullable"] for col in inspect(connection).get_columns("t") - ), + }, {"a": True, "b": False}, ) @@ -2613,7 +2611,7 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): # that can reflect these, since alembic looks for this opts = insp.get_foreign_keys("table")[0]["options"] - eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + eq_({k: opts[k] for k in opts if opts[k]}, {}) opts = insp.get_foreign_keys("user")[0]["options"] eq_(opts, expected) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 838b740fd8..6394e4b9a1 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -552,7 +552,7 @@ class FetchLimitOffsetTest(fixtures.TablesTest): .offset(2) ).fetchall() eq_(fa[0], (3, 3, 4)) - eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) @testing.requires.fetch_ties @testing.requires.fetch_offset_with_options @@ -623,7 +623,7 @@ class FetchLimitOffsetTest(fixtures.TablesTest): .offset(2) ).fetchall() eq_(fa[0], (3, 3, 4)) - eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + eq_(set(fa), {(3, 3, 4), (4, 4, 5), (5, 4, 6)}) class SameNamedSchemaTableTest(fixtures.TablesTest): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 25ed041c2a..36fd7f247c 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -832,8 +832,8 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): result = {row[0] for row in connection.execute(t.select())} output = set(output) if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) + result = {filter_(x) for x in result} + output = {filter_(x) for x in output} eq_(result, output) if check_scale: eq_([str(x) for x in result], [str(x) for x in output]) @@ -969,13 +969,11 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): @testing.requires.precision_numerics_general def test_precision_decimal(self, do_numeric_test): - numbers = set( - [ - decimal.Decimal("54.234246451650"), - decimal.Decimal("0.004354"), - decimal.Decimal("900.0"), - ] - ) + numbers = { + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + } do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) @@ -988,52 +986,46 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): """ - numbers = set( - [ - decimal.Decimal("1E-2"), - decimal.Decimal("1E-3"), - decimal.Decimal("1E-4"), - decimal.Decimal("1E-5"), - decimal.Decimal("1E-6"), - decimal.Decimal("1E-7"), - decimal.Decimal("1E-8"), - decimal.Decimal("0.01000005940696"), - decimal.Decimal("0.00000005940696"), - decimal.Decimal("0.00000000000696"), - decimal.Decimal("0.70000000000696"), - decimal.Decimal("696E-12"), - ] - ) + numbers = { + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + } do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) @testing.requires.precision_numerics_enotation_large def test_enotation_decimal_large(self, do_numeric_test): """test exceedingly large decimals.""" - numbers = set( - [ - decimal.Decimal("4E+8"), - decimal.Decimal("5748E+15"), - decimal.Decimal("1.521E+15"), - decimal.Decimal("00000000000000.1E+12"), - ] - ) + numbers = { + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + } do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) @testing.requires.precision_numerics_many_significant_digits def test_many_significant_digits(self, do_numeric_test): - numbers = set( - [ - decimal.Decimal("31943874831932418390.01"), - decimal.Decimal("319438950232418390.273596"), - decimal.Decimal("87673.594069654243"), - ] - ) + numbers = { + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + } do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) @testing.requires.precision_numerics_retains_significant_digits def test_numeric_no_decimal(self, do_numeric_test): - numbers = set([decimal.Decimal("1.000")]) + numbers = {decimal.Decimal("1.000")} do_numeric_test( Numeric(precision=5, scale=3), numbers, numbers, check_scale=True ) @@ -1258,7 +1250,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def default(self, o): if isinstance(o, decimal.Decimal): return str(o) - return super(DecimalEncoder, self).default(o) + return super().default(o) json_data = json.dumps(data_element, cls=DecimalEncoder) diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 54be2e4e5b..22df745900 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -185,9 +185,7 @@ class Properties(Generic[_T]): return iter(list(self._data.values())) def __dir__(self) -> List[str]: - return dir(super(Properties, self)) + [ - str(k) for k in self._data.keys() - ] + return dir(super()) + [str(k) for k in self._data.keys()] def __add__(self, other: Properties[_F]) -> List[Union[_T, _F]]: return list(self) + list(other) # type: ignore @@ -477,8 +475,7 @@ def flatten_iterator(x: Iterable[_T]) -> Iterator[_T]: elem: _T for elem in x: if not isinstance(elem, str) and hasattr(elem, "__iter__"): - for y in flatten_iterator(elem): - yield y + yield from flatten_iterator(elem) else: yield elem @@ -504,7 +501,7 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): capacity: int threshold: float - size_alert: Optional[Callable[["LRUCache[_KT, _VT]"], None]] + size_alert: Optional[Callable[[LRUCache[_KT, _VT]], None]] def __init__( self, diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py index 969e8d92ea..ec94630193 100644 --- a/lib/sqlalchemy/util/_concurrency_py3k.py +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -32,7 +32,7 @@ if typing.TYPE_CHECKING: dead: bool gr_context: Optional[Context] - def __init__(self, fn: Callable[..., Any], driver: "greenlet"): + def __init__(self, fn: Callable[..., Any], driver: greenlet): ... def throw(self, *arg: Any) -> Any: diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 24f9bcf106..6517e381cb 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -64,11 +64,11 @@ def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec: if inspect.ismethod(func): func = func.__func__ if not inspect.isfunction(func): - raise TypeError("{!r} is not a Python function".format(func)) + raise TypeError(f"{func!r} is not a Python function") co = func.__code__ if not inspect.iscode(co): - raise TypeError("{!r} is not a code object".format(co)) + raise TypeError(f"{co!r} is not a code object") nargs = co.co_argcount names = co.co_varnames diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 051a8c89e1..8df4950a39 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1522,7 +1522,7 @@ class classproperty(property): fget: Callable[[Any], Any] def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any): - super(classproperty, self).__init__(fget, *arg, **kw) + super().__init__(fget, *arg, **kw) self.__doc__ = fget.__doc__ def __get__(self, obj: Any, cls: Optional[type] = None) -> Any: @@ -1793,7 +1793,7 @@ class _hash_limit_string(str): interpolated = (value % args) + ( " (this warning may be suppressed after %d occurrences)" % num ) - self = super(_hash_limit_string, cls).__new__(cls, interpolated) + self = super().__new__(cls, interpolated) self._hash = hash("%s_%d" % (value, hash(interpolated) % num)) return self diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 620e3bbb71..96aa5db2f3 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -72,8 +72,7 @@ def sort( """ for set_ in sort_as_subsets(tuples, allitems): - for s in set_: - yield s + yield from set_ def find_cycles( diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index ebafbc2b11..1162a54afd 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1664,7 +1664,7 @@ class CycleTest(_fixtures.FixtureTest): go() def test_visit_binary_product(self): - a, b, q, e, f, j, r = [column(chr_) for chr_ in "abqefjr"] + a, b, q, e, f, j, r = (column(chr_) for chr_ in "abqefjr") from sqlalchemy import and_, func from sqlalchemy.sql.util import visit_binary_product diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 23f03cc04f..a2d98fd53d 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -372,10 +372,10 @@ class DeferOptionsTest(NoCache, fixtures.MappedTest): [ A( id=i, - **dict( - (letter, "%s%d" % (letter, i)) + **{ + letter: "%s%d" % (letter, i) for letter in ["x", "y", "z", "p", "q", "r"] - ), + }, ) for i in range(1, 1001) ] diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 479a2472f4..f712b729cf 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -47,20 +47,20 @@ class ResultSetTest(fixtures.TablesTest, AssertsExecutionResults): conn.execute( t.insert(), [ - dict( - ("field%d" % fnum, "value%d" % fnum) + { + "field%d" % fnum: "value%d" % fnum for fnum in range(NUM_FIELDS) - ) + } for r_num in range(NUM_RECORDS) ], ) conn.execute( t2.insert(), [ - dict( - ("field%d" % fnum, "value%d" % fnum) + { + "field%d" % fnum: "value%d" % fnum for fnum in range(NUM_FIELDS) - ) + } for r_num in range(NUM_RECORDS) ], ) diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 9250c33341..ac95d7ab3c 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -122,19 +122,17 @@ class DependencySortTest(fixtures.TestBase): list(topological.sort(tuples, allitems)) assert False except exc.CircularDependencyError as err: - eq_(err.cycles, set(["node1", "node3", "node2", "node5", "node4"])) + eq_(err.cycles, {"node1", "node3", "node2", "node5", "node4"}) eq_( err.edges, - set( - [ - ("node3", "node1"), - ("node4", "node1"), - ("node2", "node3"), - ("node1", "node2"), - ("node4", "node5"), - ("node5", "node4"), - ] - ), + { + ("node3", "node1"), + ("node4", "node1"), + ("node2", "node3"), + ("node1", "node2"), + ("node4", "node5"), + ("node5", "node4"), + }, ) def test_raise_on_cycle_two(self): @@ -159,18 +157,16 @@ class DependencySortTest(fixtures.TestBase): list(topological.sort(tuples, allitems)) assert False except exc.CircularDependencyError as err: - eq_(err.cycles, set(["node1", "node3", "node2"])) + eq_(err.cycles, {"node1", "node3", "node2"}) eq_( err.edges, - set( - [ - ("node3", "node1"), - ("node2", "node3"), - ("node3", "node2"), - ("node1", "node2"), - ("node2", "node4"), - ] - ), + { + ("node3", "node1"), + ("node2", "node3"), + ("node3", "node2"), + ("node1", "node2"), + ("node2", "node4"), + }, ) def test_raise_on_cycle_three(self): @@ -225,7 +221,7 @@ class DependencySortTest(fixtures.TestBase): ] eq_( topological.find_cycles(tuples, self._nodes_from_tuples(tuples)), - set([node1, node2, node3]), + {node1, node2, node3}, ) def test_find_multiple_cycles_one(self): @@ -252,23 +248,29 @@ class DependencySortTest(fixtures.TestBase): (node3, node1), (node3, node2), ] - allnodes = set( - [node1, node2, node3, node4, node5, node6, node7, node8, node9] - ) + allnodes = { + node1, + node2, + node3, + node4, + node5, + node6, + node7, + node8, + node9, + } eq_( topological.find_cycles(tuples, allnodes), - set( - [ - "node8", - "node1", - "node2", - "node5", - "node4", - "node7", - "node6", - "node9", - ] - ), + { + "node8", + "node1", + "node2", + "node5", + "node4", + "node7", + "node6", + "node9", + }, ) def test_find_multiple_cycles_two(self): @@ -287,11 +289,11 @@ class DependencySortTest(fixtures.TestBase): (node2, node4), (node4, node1), ] - allnodes = set([node1, node2, node3, node4, node5, node6]) + allnodes = {node1, node2, node3, node4, node5, node6} # node6 only became present here once [ticket:2282] was addressed. eq_( topological.find_cycles(tuples, allnodes), - set(["node1", "node2", "node4", "node6"]), + {"node1", "node2", "node4", "node6"}, ) def test_find_multiple_cycles_three(self): @@ -312,7 +314,7 @@ class DependencySortTest(fixtures.TestBase): (node5, node6), (node6, node2), ] - allnodes = set([node1, node2, node3, node4, node5, node6]) + allnodes = {node1, node2, node3, node4, node5, node6} eq_(topological.find_cycles(tuples, allnodes), allnodes) def test_find_multiple_cycles_four(self): @@ -350,22 +352,20 @@ class DependencySortTest(fixtures.TestBase): allnodes = ["node%d" % i for i in range(1, 21)] eq_( topological.find_cycles(tuples, allnodes), - set( - [ - "node11", - "node10", - "node13", - "node15", - "node14", - "node17", - "node19", - "node20", - "node8", - "node1", - "node3", - "node2", - "node4", - "node6", - ] - ), + { + "node11", + "node10", + "node13", + "node15", + "node14", + "node17", + "node19", + "node20", + "node8", + "node1", + "node3", + "node2", + "node4", + "node6", + }, ) diff --git a/test/base/test_except.py b/test/base/test_except.py index 77f5c731ac..a458afb97c 100644 --- a/test/base/test_except.py +++ b/test/base/test_except.py @@ -1,5 +1,3 @@ -#! coding:utf-8 - """Tests exceptions and DB-API exception wrapping.""" from itertools import product diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 349ee8c058..b979d43bcc 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,5 +1,3 @@ -#! coding: utf-8 - import copy import inspect from pathlib import Path @@ -543,7 +541,7 @@ class ToListTest(fixtures.TestBase): eq_(util.to_list("xyz"), ["xyz"]) def test_from_set(self): - spec = util.to_list(set([1, 2, 3])) + spec = util.to_list({1, 2, 3}) assert isinstance(spec, list) eq_(sorted(spec), [1, 2, 3]) @@ -567,7 +565,7 @@ class ToListTest(fixtures.TestBase): class ColumnCollectionCommon(testing.AssertsCompiledSQL): def _assert_collection_integrity(self, coll): - eq_(coll._colset, set(c for k, c, _ in coll._collection)) + eq_(coll._colset, {c for k, c, _ in coll._collection}) d = {} for k, col, _ in coll._collection: d.setdefault(k, (k, col)) @@ -1964,7 +1962,7 @@ class IdentitySetTest(fixtures.TestBase): assert True try: - s = set([o1, o2]) + s = {o1, o2} s |= ids assert False except TypeError: @@ -2019,7 +2017,7 @@ class OrderedIdentitySetTest(fixtures.TestBase): class DictlikeIteritemsTest(fixtures.TestBase): - baseline = set([("a", 1), ("b", 2), ("c", 3)]) + baseline = {("a", 1), ("b", 2), ("c", 3)} def _ok(self, instance): iterator = util.dictlike_iteritems(instance) @@ -2966,7 +2964,7 @@ class GenericReprTest(fixtures.TestBase): self.e = e self.f = f self.g = g - super(Bar, self).__init__(**kw) + super().__init__(**kw) eq_( util.generic_repr( @@ -2989,7 +2987,7 @@ class GenericReprTest(fixtures.TestBase): class Bar(Foo): def __init__(self, b=3, c=4, **kw): self.c = c - super(Bar, self).__init__(b=b, **kw) + super().__init__(b=b, **kw) eq_( util.generic_repr(Bar(a="a", b="b", c="c"), to_inspect=[Bar, Foo]), @@ -3125,7 +3123,7 @@ class AsInterfaceTest(fixtures.TestBase): def assertAdapted(obj, *methods): assert isinstance(obj, type) - found = set([m for m in dir(obj) if not m.startswith("_")]) + found = {m for m in dir(obj) if not m.startswith("_")} for method in methods: assert method in found found.remove(method) @@ -3163,7 +3161,7 @@ class AsInterfaceTest(fixtures.TestBase): class TestClassHierarchy(fixtures.TestBase): def test_object(self): - eq_(set(util.class_hierarchy(object)), set((object,))) + eq_(set(util.class_hierarchy(object)), {object}) def test_single(self): class A: @@ -3172,14 +3170,14 @@ class TestClassHierarchy(fixtures.TestBase): class B: pass - eq_(set(util.class_hierarchy(A)), set((A, object))) - eq_(set(util.class_hierarchy(B)), set((B, object))) + eq_(set(util.class_hierarchy(A)), {A, object}) + eq_(set(util.class_hierarchy(B)), {B, object}) class C(A, B): pass - eq_(set(util.class_hierarchy(A)), set((A, B, C, object))) - eq_(set(util.class_hierarchy(B)), set((A, B, C, object))) + eq_(set(util.class_hierarchy(A)), {A, B, C, object}) + eq_(set(util.class_hierarchy(B)), {A, B, C, object}) class TestClassProperty(fixtures.TestBase): @@ -3190,7 +3188,7 @@ class TestClassProperty(fixtures.TestBase): class B(A): @classproperty def something(cls): - d = dict(super(B, cls).something) + d = dict(super().something) d.update({"bazz": 2}) return d @@ -3319,7 +3317,7 @@ class BackslashReplaceTest(fixtures.TestBase): def test_utf8_to_utf8(self): eq_( compat.decode_backslashreplace( - "some message méil".encode("utf-8"), "utf-8" + "some message méil".encode(), "utf-8" ), "some message méil", ) diff --git a/test/base/test_warnings.py b/test/base/test_warnings.py index e951fcafcd..ee286a7bc9 100644 --- a/test/base/test_warnings.py +++ b/test/base/test_warnings.py @@ -42,7 +42,7 @@ class WarnDeprecatedLimitedTest(fixtures.TestBase): class ClsWarningTest(fixtures.TestBase): @testing.fixture def dep_cls_fixture(self): - class Connectable(object): + class Connectable: """a docstring""" some_member = "foo" @@ -63,7 +63,7 @@ class ClsWarningTest(fixtures.TestBase): import inspect - class PlainClass(object): + class PlainClass: some_member = "bar" pc_keys = dict(inspect.getmembers(PlainClass())) diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index b575595ac2..00bbc2af45 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -1,4 +1,3 @@ -# -*- encoding: utf-8 from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import Computed diff --git a/test/dialect/mssql/test_deprecations.py b/test/dialect/mssql/test_deprecations.py index 972ce413ba..019712376b 100644 --- a/test/dialect/mssql/test_deprecations.py +++ b/test/dialect/mssql/test_deprecations.py @@ -1,4 +1,3 @@ -# -*- encoding: utf-8 from unittest.mock import Mock from sqlalchemy import Column diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index d19e591b49..6b895e3f1c 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -1,5 +1,3 @@ -# -*- encoding: utf-8 - from decimal import Decimal import re from unittest.mock import Mock diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index 29bf4c812e..b65e274455 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -1,4 +1,3 @@ -# -*- encoding: utf-8 import decimal from sqlalchemy import and_ diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index f682538b37..1716d68e39 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -1,4 +1,3 @@ -# -*- encoding: utf-8 import datetime import decimal import random @@ -514,7 +513,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): m2 = MetaData() t2 = Table("t", m2, autoload_with=connection) - eq_(set(list(t2.indexes)[0].columns), set([t2.c["x"], t2.c.y])) + eq_(set(list(t2.indexes)[0].columns), {t2.c["x"], t2.c.y}) def test_indexes_cols_with_commas(self, metadata, connection): @@ -530,7 +529,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): m2 = MetaData() t2 = Table("t", m2, autoload_with=connection) - eq_(set(list(t2.indexes)[0].columns), set([t2.c["x, col"], t2.c.y])) + eq_(set(list(t2.indexes)[0].columns), {t2.c["x, col"], t2.c.y}) def test_indexes_cols_with_spaces(self, metadata, connection): @@ -546,7 +545,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): m2 = MetaData() t2 = Table("t", m2, autoload_with=connection) - eq_(set(list(t2.indexes)[0].columns), set([t2.c["x col"], t2.c.y])) + eq_(set(list(t2.indexes)[0].columns), {t2.c["x col"], t2.c.y}) def test_indexes_with_filtered(self, metadata, connection): diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py index eb14cb30f7..867e422020 100644 --- a/test/dialect/mssql/test_types.py +++ b/test/dialect/mssql/test_types.py @@ -1,4 +1,3 @@ -# -*- encoding: utf-8 import codecs import datetime import decimal diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index a4a0b24e49..414f73ad76 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -1,5 +1,3 @@ -# coding: utf-8 - from sqlalchemy import BLOB from sqlalchemy import BOOLEAN from sqlalchemy import Boolean diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index d79f2629fa..ed0fc6faca 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -1,5 +1,3 @@ -# coding: utf-8 - import datetime from sqlalchemy import bindparam diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index f56cd98aa3..0ce3611826 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -1,5 +1,3 @@ -# coding: utf-8 - from sqlalchemy import all_ from sqlalchemy import and_ from sqlalchemy import any_ diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index 8f093f1344..5e582a4923 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -1,5 +1,3 @@ -# coding: utf-8 - import re from sqlalchemy import BigInteger @@ -875,10 +873,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): # MySQL converts unique constraints into unique indexes. # separately we get both - indexes = dict((i["name"], i) for i in insp.get_indexes("mysql_uc")) - constraints = set( + indexes = {i["name"]: i for i in insp.get_indexes("mysql_uc")} + constraints = { i["name"] for i in insp.get_unique_constraints("mysql_uc") - ) + } self.assert_("uc_a" in indexes) self.assert_(indexes["uc_a"]["unique"]) @@ -888,8 +886,8 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): # more "official" MySQL construct reflected = Table("mysql_uc", MetaData(), autoload_with=testing.db) - indexes = dict((i.name, i) for i in reflected.indexes) - constraints = set(uc.name for uc in reflected.constraints) + indexes = {i.name: i for i in reflected.indexes} + constraints = {uc.name for uc in reflected.constraints} self.assert_("uc_a" in indexes) self.assert_(indexes["uc_a"].unique) @@ -1259,10 +1257,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): m.create_all(connection) eq_( - dict( - (rec["name"], rec) + { + rec["name"]: rec for rec in inspect(connection).get_foreign_keys("t2") - ), + }, { "cap_t1id_fk": { "name": "cap_t1id_fk", diff --git a/test/dialect/mysql/test_types.py b/test/dialect/mysql/test_types.py index 76e800e86d..eca1051e21 100644 --- a/test/dialect/mysql/test_types.py +++ b/test/dialect/mysql/test_types.py @@ -1,4 +1,3 @@ -# coding: utf-8 from collections import OrderedDict import datetime import decimal @@ -990,14 +989,14 @@ class EnumSetTest( t.insert(), [ {"id": 1, "data": set()}, - {"id": 2, "data": set([""])}, - {"id": 3, "data": set(["a", ""])}, - {"id": 4, "data": set(["b"])}, + {"id": 2, "data": {""}}, + {"id": 3, "data": {"a", ""}}, + {"id": 4, "data": {"b"}}, ], ) eq_( connection.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, set()), (2, set()), (3, set(["a"])), (4, set(["b"]))], + [(1, set()), (2, set()), (3, {"a"}), (4, {"b"})], ) def test_bitwise_required_for_empty(self): @@ -1023,18 +1022,18 @@ class EnumSetTest( t.insert(), [ {"id": 1, "data": set()}, - {"id": 2, "data": set([""])}, - {"id": 3, "data": set(["a", ""])}, - {"id": 4, "data": set(["b"])}, + {"id": 2, "data": {""}}, + {"id": 3, "data": {"a", ""}}, + {"id": 4, "data": {"b"}}, ], ) eq_( connection.execute(t.select().order_by(t.c.id)).fetchall(), [ (1, set()), - (2, set([""])), - (3, set(["a", ""])), - (4, set(["b"])), + (2, {""}), + (3, {"a", ""}), + (4, {"b"}), ], ) @@ -1052,18 +1051,18 @@ class EnumSetTest( expected = [ ( - set(["a"]), - set(["a"]), - set(["a"]), - set(["'a'"]), - set(["a", "b"]), + {"a"}, + {"a"}, + {"a"}, + {"'a'"}, + {"a", "b"}, ), ( - set(["b"]), - set(["b"]), - set(["b"]), - set(["b"]), - set(["a", "b"]), + {"b"}, + {"b"}, + {"b"}, + {"b"}, + {"a", "b"}, ), ] res = connection.execute(set_table.select()).fetchall() @@ -1079,13 +1078,11 @@ class EnumSetTest( ) set_table.create(connection) - connection.execute( - set_table.insert(), {"data": set(["réveillé", "drôle"])} - ) + connection.execute(set_table.insert(), {"data": {"réveillé", "drôle"}}) row = connection.execute(set_table.select()).first() - eq_(row, (1, set(["réveillé", "drôle"]))) + eq_(row, (1, {"réveillé", "drôle"})) def test_int_roundtrip(self, metadata, connection): set_table = self._set_fixture_one(metadata) @@ -1097,11 +1094,11 @@ class EnumSetTest( eq_( res, ( - set(["a"]), - set(["b"]), - set(["a", "b"]), - set(["'a'", "b"]), - set([]), + {"a"}, + {"b"}, + {"a", "b"}, + {"'a'", "b"}, + set(), ), ) @@ -1129,24 +1126,24 @@ class EnumSetTest( connection.execute(table.delete()) roundtrip([None, None, None], [None] * 3) - roundtrip(["", "", ""], [set([])] * 3) - roundtrip([set(["dq"]), set(["a"]), set(["5"])]) - roundtrip(["dq", "a", "5"], [set(["dq"]), set(["a"]), set(["5"])]) - roundtrip([1, 1, 1], [set(["dq"]), set(["a"]), set(["5"])]) - roundtrip([set(["dq", "sq"]), None, set(["9", "5", "7"])]) + roundtrip(["", "", ""], [set()] * 3) + roundtrip([{"dq"}, {"a"}, {"5"}]) + roundtrip(["dq", "a", "5"], [{"dq"}, {"a"}, {"5"}]) + roundtrip([1, 1, 1], [{"dq"}, {"a"}, {"5"}]) + roundtrip([{"dq", "sq"}, None, {"9", "5", "7"}]) connection.execute( set_table.insert(), [ - {"s3": set(["5"])}, - {"s3": set(["5", "7"])}, - {"s3": set(["5", "7", "9"])}, - {"s3": set(["7", "9"])}, + {"s3": {"5"}}, + {"s3": {"5", "7"}}, + {"s3": {"5", "7", "9"}}, + {"s3": {"7", "9"}}, ], ) rows = connection.execute( select(set_table.c.s3).where( - set_table.c.s3.in_([set(["5"]), ["5", "7"]]) + set_table.c.s3.in_([{"5"}, ["5", "7"]]) ) ).fetchall() diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 8981e74e8c..dff8584e33 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -1,4 +1,3 @@ -# coding: utf-8 from sqlalchemy import and_ from sqlalchemy import bindparam from sqlalchemy import cast diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 9c2496b05d..4370992e82 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -1,5 +1,3 @@ -# coding: utf-8 - from multiprocessing import get_context import re from unittest import mock diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index 901db9f4e8..60a05a6b69 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -1,6 +1,3 @@ -# coding: utf-8 - - from sqlalchemy import CHAR from sqlalchemy import Double from sqlalchemy import exc @@ -426,7 +423,7 @@ class SystemTableTablenamesTest(fixtures.TestBase): set(insp.get_table_names()).intersection( ["my_table", "foo_table"] ), - set(["my_table", "foo_table"]), + {"my_table", "foo_table"}, ) def test_reflect_system_table(self): @@ -469,7 +466,7 @@ class DontReflectIOTTest(fixtures.TestBase): def test_reflect_all(self, connection): m = MetaData() m.reflect(connection) - eq_(set(t.name for t in m.tables.values()), set(["admin_docindex"])) + eq_({t.name for t in m.tables.values()}, {"admin_docindex"}) def all_tables_compression_missing(): @@ -924,12 +921,10 @@ class RoundTripIndexTest(fixtures.TestBase): # make a dictionary of the reflected objects: - reflected = dict( - [ - (obj_definition(i), i) - for i in reflectedtable.indexes | reflectedtable.constraints - ] - ) + reflected = { + obj_definition(i): i + for i in reflectedtable.indexes | reflectedtable.constraints + } # assert we got primary key constraint and its name, Error # if not in dict diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 8c78fe85b1..2ba42f5840 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -1,6 +1,3 @@ -# coding: utf-8 - - import datetime import decimal import os diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 338d0da4ea..431cd7ded1 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1,4 +1,3 @@ -# coding: utf-8 from sqlalchemy import and_ from sqlalchemy import BigInteger from sqlalchemy import bindparam @@ -2871,7 +2870,7 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): i = i.on_conflict_do_update( constraint=self.excl_constr_anon, set_=dict(name=i.excluded.name), - where=((self.table1.c.name != i.excluded.name)), + where=(self.table1.c.name != i.excluded.name), ) self.assert_compile( i, @@ -2913,7 +2912,7 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): i.on_conflict_do_update( constraint=self.excl_constr_anon, set_=dict(name=i.excluded.name), - where=((self.table1.c.name != i.excluded.name)), + where=(self.table1.c.name != i.excluded.name), ) .returning(literal_column("1")) .cte("i_upsert") diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 27d4a4cf99..e8d9a8eb6f 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -1,4 +1,3 @@ -# coding: utf-8 import dataclasses import datetime import logging diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index 9c1aaf78e4..3cdad78f0e 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -1,5 +1,3 @@ -# coding: utf-8 - from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import Integer diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 6afc2f7c1b..42ec20743d 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -1,5 +1,3 @@ -# coding: utf-8 - import datetime from sqlalchemy import and_ diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index f0893d822b..88b0b73cca 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -1,5 +1,3 @@ -# coding: utf-8 - import itertools from operator import itemgetter import re @@ -120,7 +118,7 @@ class ForeignTableReflectionTest( table = Table("test_foreigntable", metadata, autoload_with=connection) eq_( set(table.columns.keys()), - set(["id", "data"]), + {"id", "data"}, "Columns of reflected foreign table didn't equal expected columns", ) @@ -286,7 +284,7 @@ class MaterializedViewReflectionTest( table = Table("test_mview", metadata, autoload_with=connection) eq_( set(table.columns.keys()), - set(["id", "data"]), + {"id", "data"}, "Columns of reflected mview didn't equal expected columns", ) @@ -297,24 +295,24 @@ class MaterializedViewReflectionTest( def test_get_view_names(self, inspect_fixture): insp, conn = inspect_fixture - eq_(set(insp.get_view_names()), set(["test_regview"])) + eq_(set(insp.get_view_names()), {"test_regview"}) def test_get_materialized_view_names(self, inspect_fixture): insp, conn = inspect_fixture - eq_(set(insp.get_materialized_view_names()), set(["test_mview"])) + eq_(set(insp.get_materialized_view_names()), {"test_mview"}) def test_get_view_names_reflection_cache_ok(self, connection): insp = inspect(connection) - eq_(set(insp.get_view_names()), set(["test_regview"])) + eq_(set(insp.get_view_names()), {"test_regview"}) eq_( set(insp.get_materialized_view_names()), - set(["test_mview"]), + {"test_mview"}, ) eq_( set(insp.get_view_names()).union( insp.get_materialized_view_names() ), - set(["test_regview", "test_mview"]), + {"test_regview", "test_mview"}, ) def test_get_view_definition(self, connection): @@ -481,7 +479,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): table = Table("testtable", metadata, autoload_with=connection) eq_( set(table.columns.keys()), - set(["question", "answer"]), + {"question", "answer"}, "Columns of reflected table didn't equal expected columns", ) assert isinstance(table.c.answer.type, Integer) @@ -532,7 +530,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) eq_( set(table.columns.keys()), - set(["question", "answer", "anything"]), + {"question", "answer", "anything"}, "Columns of reflected table didn't equal expected columns", ) assert isinstance(table.c.anything.type, Integer) @@ -1081,7 +1079,7 @@ class ReflectionTest( eq_( set(meta2.tables), - set(["test_schema_2.some_other_table", "some_table"]), + {"test_schema_2.some_other_table", "some_table"}, ) meta3 = MetaData() @@ -1093,12 +1091,10 @@ class ReflectionTest( eq_( set(meta3.tables), - set( - [ - "test_schema_2.some_other_table", - "test_schema.some_table", - ] - ), + { + "test_schema_2.some_other_table", + "test_schema.some_table", + }, ) def test_cross_schema_reflection_metadata_uses_schema( @@ -1125,7 +1121,7 @@ class ReflectionTest( eq_( set(meta2.tables), - set(["some_other_table", "test_schema.some_table"]), + {"some_other_table", "test_schema.some_table"}, ) def test_uppercase_lowercase_table(self, metadata, connection): @@ -1881,10 +1877,10 @@ class ReflectionTest( # PostgreSQL will create an implicit index for a unique # constraint. Separately we get both - indexes = set(i["name"] for i in insp.get_indexes("pgsql_uc")) - constraints = set( + indexes = {i["name"] for i in insp.get_indexes("pgsql_uc")} + constraints = { i["name"] for i in insp.get_unique_constraints("pgsql_uc") - ) + } self.assert_("uc_a" in indexes) self.assert_("uc_a" in constraints) @@ -1892,8 +1888,8 @@ class ReflectionTest( # reflection corrects for the dupe reflected = Table("pgsql_uc", MetaData(), autoload_with=connection) - indexes = set(i.name for i in reflected.indexes) - constraints = set(uc.name for uc in reflected.constraints) + indexes = {i.name for i in reflected.indexes} + constraints = {uc.name for uc in reflected.constraints} self.assert_("uc_a" not in indexes) self.assert_("uc_a" in constraints) @@ -1951,10 +1947,10 @@ class ReflectionTest( uc_table.create(connection) - indexes = dict((i["name"], i) for i in insp.get_indexes("pgsql_uc")) - constraints = set( + indexes = {i["name"]: i for i in insp.get_indexes("pgsql_uc")} + constraints = { i["name"] for i in insp.get_unique_constraints("pgsql_uc") - ) + } self.assert_("ix_a" in indexes) assert indexes["ix_a"]["unique"] @@ -1962,8 +1958,8 @@ class ReflectionTest( reflected = Table("pgsql_uc", MetaData(), autoload_with=connection) - indexes = dict((i.name, i) for i in reflected.indexes) - constraints = set(uc.name for uc in reflected.constraints) + indexes = {i.name: i for i in reflected.indexes} + constraints = {uc.name for uc in reflected.constraints} self.assert_("ix_a" in indexes) assert indexes["ix_a"].unique @@ -2005,11 +2001,11 @@ class ReflectionTest( reflected = Table("pgsql_cc", MetaData(), autoload_with=connection) - check_constraints = dict( - (uc.name, uc.sqltext.text) + check_constraints = { + uc.name: uc.sqltext.text for uc in reflected.constraints if isinstance(uc, CheckConstraint) - ) + } eq_( check_constraints, diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 39e7d73172..61de57ed44 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1,4 +1,3 @@ -# coding: utf-8 import datetime import decimal from enum import Enum as _PY_Enum @@ -2235,8 +2234,8 @@ class ArrayRoundTripTest: ) # hashable eq_( - set(row[1] for row in r), - set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]), + {row[1] for row in r}, + {("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))}, ) def test_array_plus_native_enum_create(self, metadata, connection): @@ -2261,8 +2260,8 @@ class ArrayRoundTripTest: t.create(connection) eq_( - set(e["name"] for e in inspect(connection).get_enums()), - set(["my_enum_1", "my_enum_2", "my_enum_3"]), + {e["name"] for e in inspect(connection).get_enums()}, + {"my_enum_1", "my_enum_2", "my_enum_3"}, ) t.drop(connection) eq_(inspect(connection).get_enums(), []) @@ -2686,7 +2685,7 @@ class _ArrayOfEnum(TypeDecorator): return sa.cast(bindvalue, self) def result_processor(self, dialect, coltype): - super_rp = super(_ArrayOfEnum, self).result_processor(dialect, coltype) + super_rp = super().result_processor(dialect, coltype) def handle_raw_string(value): inner = re.match(r"^{(.*)}$", value).group(1) @@ -5253,7 +5252,7 @@ class JSONBTest(JSONTest): ), ) def test_where(self, whereclause_fn, expected): - super(JSONBTest, self).test_where(whereclause_fn, expected) + super().test_where(whereclause_fn, expected) class JSONBRoundTripTest(JSONRoundTripTest): @@ -5263,7 +5262,7 @@ class JSONBRoundTripTest(JSONRoundTripTest): @testing.requires.postgresql_utf8_server_encoding def test_unicode_round_trip(self, connection): - super(JSONBRoundTripTest, self).test_unicode_round_trip(connection) + super().test_unicode_round_trip(connection) @testing.only_on("postgresql >= 12") def test_cast_jsonpath(self, connection): diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 643a56c1b6..5bda6577fe 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -1,5 +1,3 @@ -#!coding: utf-8 - """SQLite-specific tests.""" import datetime import json diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index c1fe3140e4..e99448a26b 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -1,5 +1,3 @@ -# coding: utf-8 - import collections.abc as collections_abc from contextlib import contextmanager from contextlib import nullcontext @@ -507,7 +505,7 @@ class ExecuteTest(fixtures.TablesTest): def test_stmt_exception_bytestring_utf8(self): # uncommon case for Py3K, bytestring object passed # as the error message - message = "some message méil".encode("utf-8") + message = "some message méil".encode() err = tsa.exc.SQLAlchemyError(message) eq_(str(err), "some message méil") @@ -537,7 +535,7 @@ class ExecuteTest(fixtures.TablesTest): eq_(str(err), "('some message', 206)") def test_stmt_exception_str_multi_args_bytestring(self): - message = "some message méil".encode("utf-8") + message = "some message méil".encode() err = tsa.exc.SQLAlchemyError(message, 206) eq_(str(err), str((message, 206))) @@ -2500,60 +2498,52 @@ class EngineEventsTest(fixtures.TestBase): eq_( canary, [ - ("begin", set(["conn"])), + ("begin", {"conn"}), ( "execute", - set( - [ - "conn", - "clauseelement", - "multiparams", - "params", - "execution_options", - ] - ), + { + "conn", + "clauseelement", + "multiparams", + "params", + "execution_options", + }, ), ( "cursor_execute", - set( - [ - "conn", - "cursor", - "executemany", - "statement", - "parameters", - "context", - ] - ), + { + "conn", + "cursor", + "executemany", + "statement", + "parameters", + "context", + }, ), - ("rollback", set(["conn"])), - ("begin", set(["conn"])), + ("rollback", {"conn"}), + ("begin", {"conn"}), ( "execute", - set( - [ - "conn", - "clauseelement", - "multiparams", - "params", - "execution_options", - ] - ), + { + "conn", + "clauseelement", + "multiparams", + "params", + "execution_options", + }, ), ( "cursor_execute", - set( - [ - "conn", - "cursor", - "executemany", - "statement", - "parameters", - "context", - ] - ), + { + "conn", + "cursor", + "executemany", + "statement", + "parameters", + "context", + }, ), - ("commit", set(["conn"])), + ("commit", {"conn"}), ], ) @@ -3383,11 +3373,11 @@ class OnConnectTest(fixtures.TestBase): class SomeDialect(cls_): def initialize(self, connection): - super(SomeDialect, self).initialize(connection) + super().initialize(connection) m1.initialize(connection) def on_connect(self): - oc = super(SomeDialect, self).on_connect() + oc = super().on_connect() def my_on_connect(conn): if oc: @@ -3456,11 +3446,11 @@ class OnConnectTest(fixtures.TestBase): supports_statement_cache = True def initialize(self, connection): - super(SomeDialect, self).initialize(connection) + super().initialize(connection) m1.append("initialize") def on_connect(self): - oc = super(SomeDialect, self).on_connect() + oc = super().on_connect() def my_on_connect(conn): if oc: diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 9d9c3a429d..f267eac779 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -1410,8 +1410,8 @@ class QueuePoolTest(PoolTestBase): # two pooled connections unclosed. eq_( - set([c.close.call_count for c in strong_refs]), - set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0]), + {c.close.call_count for c in strong_refs}, + {1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0}, ) def test_recycle(self): @@ -1693,7 +1693,7 @@ class QueuePoolTest(PoolTestBase): class TrackQueuePool(pool.QueuePool): def __init__(self, *arg, **kw): pools.append(self) - super(TrackQueuePool, self).__init__(*arg, **kw) + super().__init__(*arg, **kw) def creator(): return slow_closing_connection.connect() diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 2aefecaeff..81b85df08d 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -136,10 +136,10 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): ) meta.create_all(connection) meta2 = MetaData() - t1r, t2r, t3r = [ + t1r, t2r, t3r = ( Table(x, meta2, autoload_with=connection) for x in ("t1", "t2", "t3") - ] + ) assert t1r.c.t2id.references(t2r.c.id) assert t1r.c.t3id.references(t3r.c.id) @@ -283,7 +283,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): extend_existing=True, autoload_with=connection, ) - eq_(set(t2.columns.keys()), set(["x", "y", "z", "q", "id"])) + eq_(set(t2.columns.keys()), {"x", "y", "z", "q", "id"}) # this has been the actual behavior, the cols are added together, # however the test wasn't checking this correctly @@ -302,7 +302,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): extend_existing=False, autoload_with=connection, ) - eq_(set(t3.columns.keys()), set(["z"])) + eq_(set(t3.columns.keys()), {"z"}) m4 = MetaData() old_z = Column("z", String, primary_key=True) @@ -318,7 +318,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): autoload_replace=False, autoload_with=connection, ) - eq_(set(t4.columns.keys()), set(["x", "y", "z", "q", "id"])) + eq_(set(t4.columns.keys()), {"x", "y", "z", "q", "id"}) eq_(list(t4.primary_key.columns), [t4.c.z, t4.c.id]) assert t4.c.z is old_z assert t4.c.y is old_y @@ -1117,11 +1117,11 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2 = MetaData() m2.reflect(connection, only=["rt_a", "rt_b"]) - eq_(set(m2.tables.keys()), set(["rt_a", "rt_b"])) + eq_(set(m2.tables.keys()), {"rt_a", "rt_b"}) m3 = MetaData() m3.reflect(connection, only=lambda name, meta: name == "rt_c") - eq_(set(m3.tables.keys()), set(["rt_c"])) + eq_(set(m3.tables.keys()), {"rt_c"}) m4 = MetaData() @@ -1155,7 +1155,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m8_e2 = MetaData() rt_c = Table("rt_c", m8_e2) m8_e2.reflect(connection, extend_existing=True, only=["rt_a", "rt_c"]) - eq_(set(m8_e2.tables.keys()), set(["rt_a", "rt_c"])) + eq_(set(m8_e2.tables.keys()), {"rt_a", "rt_c"}) eq_(rt_c.c.keys(), ["id"]) baseline.drop_all(connection) @@ -1212,16 +1212,16 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): # Make sure indexes are in the order we expect them in tmp = [(idx.name, idx) for idx in t2.indexes] tmp.sort() - r1, r2, r3 = [idx[1] for idx in tmp] + r1, r2, r3 = (idx[1] for idx in tmp) assert r1.name == "idx1" assert r2.name == "idx2" assert r1.unique == True # noqa assert r2.unique == False # noqa assert r3.unique == False # noqa - assert set([t2.c.id]) == set(r1.columns) - assert set([t2.c.name, t2.c.id]) == set(r2.columns) - assert set([t2.c.name]) == set(r3.columns) + assert {t2.c.id} == set(r1.columns) + assert {t2.c.name, t2.c.id} == set(r2.columns) + assert {t2.c.name} == set(r3.columns) @testing.requires.comment_reflection def test_comment_reflection(self, connection, metadata): @@ -1350,23 +1350,19 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2 = MetaData() m2.reflect(connection, views=False) - eq_( - set(m2.tables), set(["users", "email_addresses", "dingalings"]) - ) + eq_(set(m2.tables), {"users", "email_addresses", "dingalings"}) m2 = MetaData() m2.reflect(connection, views=True) eq_( set(m2.tables), - set( - [ - "email_addresses_v", - "users_v", - "users", - "dingalings", - "email_addresses", - ] - ), + { + "email_addresses_v", + "users_v", + "users", + "dingalings", + "email_addresses", + }, ) finally: _drop_views(connection) @@ -1537,7 +1533,7 @@ class SchemaManipulationTest(fixtures.TestBase): addresses.append_constraint(fk) addresses.append_constraint(fk) assert len(addresses.c.user_id.foreign_keys) == 1 - assert addresses.constraints == set([addresses.primary_key, fk]) + assert addresses.constraints == {addresses.primary_key, fk} class UnicodeReflectionTest(fixtures.TablesTest): @@ -1546,7 +1542,7 @@ class UnicodeReflectionTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): - no_multibyte_period = set([("plain", "col_plain", "ix_plain")]) + no_multibyte_period = {("plain", "col_plain", "ix_plain")} no_has_table = [ ( "no_has_table_1", @@ -1628,7 +1624,7 @@ class UnicodeReflectionTest(fixtures.TablesTest): # (others?) expect non-unicode strings in result sets/bind # params - names = set([rec[0] for rec in self.names]) + names = {rec[0] for rec in self.names} reflected = set(inspect(connection).get_table_names()) @@ -1639,7 +1635,7 @@ class UnicodeReflectionTest(fixtures.TablesTest): # explicitly NFC). Maybe this database normalizes NFD # on reflection. - nfc = set([unicodedata.normalize("NFC", n) for n in names]) + nfc = {unicodedata.normalize("NFC", n) for n in names} self.assert_(nfc == names) # Yep. But still ensure that bulk reflection and @@ -1653,9 +1649,7 @@ class UnicodeReflectionTest(fixtures.TablesTest): @testing.requires.unicode_connections def test_get_names(self, connection): inspector = inspect(connection) - names = dict( - (tname, (cname, ixname)) for tname, cname, ixname in self.names - ) + names = {tname: (cname, ixname) for tname, cname, ixname in self.names} for tname in inspector.get_table_names(): assert tname in names eq_( @@ -1710,12 +1704,10 @@ class SchemaTest(fixtures.TestBase): eq_( set(meta2.tables), - set( - [ - "some_other_table", - "%s.some_table" % testing.config.test_schema, - ] - ), + { + "some_other_table", + "%s.some_table" % testing.config.test_schema, + }, ) @testing.requires.schemas @@ -1806,13 +1798,11 @@ class SchemaTest(fixtures.TestBase): m2.reflect(connection) eq_( set(m2.tables), - set( - [ - "%s.dingalings" % testing.config.test_schema, - "%s.users" % testing.config.test_schema, - "%s.email_addresses" % testing.config.test_schema, - ] - ), + { + "%s.dingalings" % testing.config.test_schema, + "%s.users" % testing.config.test_schema, + "%s.email_addresses" % testing.config.test_schema, + }, ) @testing.requires.schemas @@ -1837,8 +1827,8 @@ class SchemaTest(fixtures.TestBase): m3.reflect(connection, schema=testing.config.test_schema) eq_( - set((t.name, t.schema) for t in m2.tables.values()), - set((t.name, t.schema) for t in m3.tables.values()), + {(t.name, t.schema) for t in m2.tables.values()}, + {(t.name, t.schema) for t in m3.tables.values()}, ) @@ -1993,7 +1983,7 @@ class CaseSensitiveTest(fixtures.TablesTest): @testing.fails_if(testing.requires._has_mysql_on_windows) def test_table_names(self, connection): x = inspect(connection).get_table_names() - assert set(["SomeTable", "SomeOtherTable"]).issubset(x) + assert {"SomeTable", "SomeOtherTable"}.issubset(x) def test_reflect_exact_name(self, connection): m = MetaData() @@ -2068,7 +2058,7 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TablesTest): def test_override_key(self, connection): def assertions(table): eq_(table.c.YXZ.name, "x") - eq_(set(table.primary_key), set([table.c.YXZ])) + eq_(set(table.primary_key), {table.c.YXZ}) self._do_test(connection, "x", {"key": "YXZ"}, assertions) @@ -2357,7 +2347,7 @@ class IncludeColsFksTest(AssertsCompiledSQL, fixtures.TestBase): eq_([c.name for c in b2.c], ["x", "q", "p"]) # no FK, whether or not resolve_fks was called - eq_(b2.constraints, set((b2.primary_key,))) + eq_(b2.constraints, {b2.primary_key}) b2a = b2.alias() eq_([c.name for c in b2a.c], ["x", "q", "p"]) diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index 9efe080296..dcfb3f8502 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -772,25 +772,21 @@ class ConcreteExtensionConfigTest( [ A( data="a1", - collection=set( - [ - B(data="a1b1", b_data="a1b1"), - C(data="a1b2", c_data="a1c1"), - B(data="a1b2", b_data="a1b2"), - C(data="a1c2", c_data="a1c2"), - ] - ), + collection={ + B(data="a1b1", b_data="a1b1"), + C(data="a1b2", c_data="a1c1"), + B(data="a1b2", b_data="a1b2"), + C(data="a1c2", c_data="a1c2"), + }, ), A( data="a2", - collection=set( - [ - B(data="a2b1", b_data="a2b1"), - C(data="a2c1", c_data="a2c1"), - B(data="a2b2", b_data="a2b2"), - C(data="a2c2", c_data="a2c2"), - ] - ), + collection={ + B(data="a2b1", b_data="a2b1"), + C(data="a2c1", c_data="a2c1"), + B(data="a2b2", b_data="a2b2"), + C(data="a2c2", c_data="a2c2"), + }, ), ] ) @@ -802,14 +798,12 @@ class ConcreteExtensionConfigTest( [ A( data="a2", - collection=set( - [ - B(data="a2b1", b_data="a2b1"), - B(data="a2b2", b_data="a2b2"), - C(data="a2c1", c_data="a2c1"), - C(data="a2c2", c_data="a2c2"), - ] - ), + collection={ + B(data="a2b1", b_data="a2b1"), + B(data="a2b2", b_data="a2b2"), + C(data="a2c1", c_data="a2c1"), + C(data="a2c2", c_data="a2c2"), + }, ) ], ) diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index c3e5b586a6..e143ad1277 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -36,7 +36,7 @@ class DeclarativeReflectionBase(fixtures.TablesTest): class DeferredReflectBase(DeclarativeReflectionBase): def teardown_test(self): - super(DeferredReflectBase, self).teardown_test() + super().teardown_test() _DeferredMapperConfig._configs.clear() diff --git a/test/ext/mypy/plain_files/dynamic_rel.py b/test/ext/mypy/plain_files/dynamic_rel.py index 78bf15f5f2..1766c610c8 100644 --- a/test/ext/mypy/plain_files/dynamic_rel.py +++ b/test/ext/mypy/plain_files/dynamic_rel.py @@ -74,7 +74,7 @@ with Session() as session: # EXPECTED_TYPE: AppenderQuery[Address] reveal_type(u.addresses) - u.addresses = set([Address(), Address()]) + u.addresses = {Address(), Address()} if typing.TYPE_CHECKING: # still an AppenderQuery diff --git a/test/ext/mypy/test_mypy_plugin_py3k.py b/test/ext/mypy/test_mypy_plugin_py3k.py index 5d3388ca6e..3df669912e 100644 --- a/test/ext/mypy/test_mypy_plugin_py3k.py +++ b/test/ext/mypy/test_mypy_plugin_py3k.py @@ -53,13 +53,11 @@ class MypyPluginTest(fixtures.TestBase): @testing.fixture(scope="function") def per_func_cachedir(self): - for item in self._cachedir(): - yield item + yield from self._cachedir() @testing.fixture(scope="class") def cachedir(self): - for item in self._cachedir(): - yield item + yield from self._cachedir() def _cachedir(self): # as of mypy 0.971 i think we need to keep mypy_path empty diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index ffaae7db37..3dcb877460 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -559,7 +559,7 @@ class CustomDictTest(_CollectionOperations): self.assert_(len(p1._children) == 3) self.assert_(len(p1.children) == 3) - self.assert_(set(p1.children) == set(["d", "e", "f"])) + self.assert_(set(p1.children) == {"d", "e", "f"}) del ch p1 = self.roundtrip(p1) @@ -641,9 +641,7 @@ class SetTest(_CollectionOperations): self.assert_(len(p1.children) == 2) self.assert_(len(p1._children) == 2) - self.assert_( - set([o.name for o in p1._children]) == set(["regular", "proxied"]) - ) + self.assert_({o.name for o in p1._children} == {"regular", "proxied"}) ch2 = None for o in p1._children: @@ -655,7 +653,7 @@ class SetTest(_CollectionOperations): self.assert_(len(p1._children) == 1) self.assert_(len(p1.children) == 1) - self.assert_(p1._children == set([ch1])) + self.assert_(p1._children == {ch1}) p1.children.remove("regular") @@ -676,7 +674,7 @@ class SetTest(_CollectionOperations): self.assert_("b" in p1.children) self.assert_("d" not in p1.children) - self.assert_(p1.children == set(["a", "b", "c"])) + self.assert_(p1.children == {"a", "b", "c"}) assert_raises(KeyError, p1.children.remove, "d") @@ -695,15 +693,15 @@ class SetTest(_CollectionOperations): p1.children = ["a", "b", "c"] p1 = self.roundtrip(p1) - self.assert_(p1.children == set(["a", "b", "c"])) + self.assert_(p1.children == {"a", "b", "c"}) p1.children.discard("b") p1 = self.roundtrip(p1) - self.assert_(p1.children == set(["a", "c"])) + self.assert_(p1.children == {"a", "c"}) p1.children.remove("a") p1 = self.roundtrip(p1) - self.assert_(p1.children == set(["c"])) + self.assert_(p1.children == {"c"}) p1._children = set() self.assert_(len(p1.children) == 0) @@ -727,15 +725,15 @@ class SetTest(_CollectionOperations): p1 = Parent("P1") p1.children = ["a", "b", "c"] - control = set(["a", "b", "c"]) + control = {"a", "b", "c"} for other in ( - set(["a", "b", "c"]), - set(["a", "b", "c", "d"]), - set(["a"]), - set(["a", "b"]), - set(["c", "d"]), - set(["e", "f", "g"]), + {"a", "b", "c"}, + {"a", "b", "c", "d"}, + {"a"}, + {"a", "b"}, + {"c", "d"}, + {"e", "f", "g"}, set(), ): @@ -795,12 +793,12 @@ class SetTest(_CollectionOperations): ): for base in (["a", "b", "c"], []): for other in ( - set(["a", "b", "c"]), - set(["a", "b", "c", "d"]), - set(["a"]), - set(["a", "b"]), - set(["c", "d"]), - set(["e", "f", "g"]), + {"a", "b", "c"}, + {"a", "b", "c", "d"}, + {"a"}, + {"a", "b"}, + {"c", "d"}, + {"e", "f", "g"}, set(), ): p = Parent("p") @@ -831,12 +829,12 @@ class SetTest(_CollectionOperations): for op in ("|=", "-=", "&=", "^="): for base in (["a", "b", "c"], []): for other in ( - set(["a", "b", "c"]), - set(["a", "b", "c", "d"]), - set(["a"]), - set(["a", "b"]), - set(["c", "d"]), - set(["e", "f", "g"]), + {"a", "b", "c"}, + {"a", "b", "c", "d"}, + {"a"}, + {"a", "b"}, + {"c", "d"}, + {"e", "f", "g"}, frozenset(["e", "f", "g"]), set(), ): @@ -1408,7 +1406,7 @@ class ReconstitutionTest(fixtures.MappedTest): add_child("p1", "c2") session.flush() p = session.query(Parent).filter_by(name="p1").one() - assert set(p.kids) == set(["c1", "c2"]), p.kids + assert set(p.kids) == {"c1", "c2"}, p.kids def test_copy(self): self.mapper_registry.map_imperatively( @@ -1422,7 +1420,7 @@ class ReconstitutionTest(fixtures.MappedTest): p_copy = copy.copy(p) del p gc_collect() - assert set(p_copy.kids) == set(["c1", "c2"]), p_copy.kids + assert set(p_copy.kids) == {"c1", "c2"}, p_copy.kids def test_pickle_list(self): self.mapper_registry.map_imperatively( @@ -1452,7 +1450,7 @@ class ReconstitutionTest(fixtures.MappedTest): p = Parent("p1") p.kids.update(["c1", "c2"]) r1 = pickle.loads(pickle.dumps(p)) - assert r1.kids == set(["c1", "c2"]) + assert r1.kids == {"c1", "c2"} # can't do this without parent having a cycle # r2 = pickle.loads(pickle.dumps(p.kids)) diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py index df60a9a653..b9d07390a2 100644 --- a/test/ext/test_automap.py +++ b/test/ext/test_automap.py @@ -56,7 +56,7 @@ class AutomapTest(fixtures.MappedTest): Address = Base.classes.addresses a1 = Address(email_address="e1") - u1 = User(name="u1", addresses_collection=set([a1])) + u1 = User(name="u1", addresses_collection={a1}) assert a1.user is u1 def test_prepare_w_only(self): @@ -291,7 +291,7 @@ class AutomapTest(fixtures.MappedTest): ) Base.prepare(generate_relationship=_gen_relationship) - assert set(tuple(c[1]) for c in mock.mock_calls).issuperset( + assert {tuple(c[1]) for c in mock.mock_calls}.issuperset( [ (Base, interfaces.MANYTOONE, "nodes"), (Base, interfaces.MANYTOMANY, "keywords_collection"), diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index c9d2dc928f..d502277925 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -263,8 +263,8 @@ class LikeQueryTest(BakedTest): # original query still works eq_( - set([(u.id, u.name) for u in bq(sess).all()]), - set([(8, "ed"), (9, "fred")]), + {(u.id, u.name) for u in bq(sess).all()}, + {(8, "ed"), (9, "fred")}, ) def test_count_with_bindparams(self): diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index 7067d24c16..aa03dabc90 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -40,7 +40,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): inherit_cache = False def __init__(self, arg=None): - super(MyThingy, self).__init__(arg or "MYTHINGY!") + super().__init__(arg or "MYTHINGY!") @compiles(MyThingy) def visit_thingy(thingy, compiler, **kw): @@ -125,7 +125,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): inherit_cache = False def __init__(self): - super(MyThingy, self).__init__("MYTHINGY!") + super().__init__("MYTHINGY!") @compiles(MyThingy) def visit_thingy(thingy, compiler, **kw): diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 667f4bfb08..8913478598 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -473,12 +473,12 @@ class ShardTest: sess = self._fixture_data() eq_( - set(row.temperature for row in sess.query(Report.temperature)), + {row.temperature for row in sess.query(Report.temperature)}, {80.0, 75.0, 85.0}, ) temps = sess.query(Report).all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + eq_({t.temperature for t in temps}, {80.0, 75.0, 85.0}) if legacy: sess.query(Report).filter(Report.temperature >= 80).update( @@ -495,14 +495,14 @@ class ShardTest: # test synchronize session def go(): - eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) + eq_({t.temperature for t in temps}, {86.0, 75.0, 91.0}) self.assert_sql_count( sess._ShardedSession__binds["north_america"], go, 0 ) eq_( - set(row.temperature for row in sess.query(Report.temperature)), + {row.temperature for row in sess.query(Report.temperature)}, {86.0, 75.0, 91.0}, ) @@ -514,7 +514,7 @@ class ShardTest: sess = self._fixture_data() temps = sess.query(Report).all() - eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + eq_({t.temperature for t in temps}, {80.0, 75.0, 85.0}) if legacy: sess.query(Report).filter(Report.temperature >= 80).delete( @@ -537,7 +537,7 @@ class ShardTest: ) eq_( - set(row.temperature for row in sess.query(Report.temperature)), + {row.temperature for row in sess.query(Report.temperature)}, {75.0}, ) diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 252844a487..df69b36af1 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -1260,8 +1260,8 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): from sqlalchemy import literal symbols = ("usd", "gbp", "cad", "eur", "aud") - currency_lookup = dict( - ((currency_from, currency_to), Decimal(str(rate))) + currency_lookup = { + (currency_from, currency_to): Decimal(str(rate)) for currency_to, values in zip( symbols, [ @@ -1273,7 +1273,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): ], ) for currency_from, rate in zip(symbols, values) - ) + } class Amount: def __init__(self, amount, currency): diff --git a/test/ext/test_indexable.py b/test/ext/test_indexable.py index 7a9765efb3..a55a16ac4a 100644 --- a/test/ext/test_indexable.py +++ b/test/ext/test_indexable.py @@ -265,11 +265,11 @@ class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest): class json_property(index_property): def __init__(self, attr_name, index, cast_type): - super(json_property, self).__init__(attr_name, index) + super().__init__(attr_name, index) self.cast_type = cast_type def expr(self, model): - expr = super(json_property, self).expr(model) + expr = super().expr(model) return expr.astext.cast(self.cast_type) class Json(fixtures.ComparableEntity, Base): diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index e117710e75..50dc22e351 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -264,7 +264,7 @@ class _MutableDictTestBase(_MutableDictTestFixture): ValueError, "Attribute 'data' does not accept objects of type", Foo, - data=set([1, 2, 3]), + data={1, 2, 3}, ) def test_in_place_mutation(self): @@ -488,7 +488,7 @@ class _MutableListTestBase(_MutableListTestFixture): ValueError, "Attribute 'data' does not accept objects of type", Foo, - data=set([1, 2, 3]), + data={1, 2, 3}, ) def test_in_place_mutation(self): @@ -780,7 +780,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): def test_clear(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() @@ -792,7 +792,7 @@ class _MutableSetTestBase(_MutableSetTestFixture): def test_pop(self): sess = fixture_session() - f1 = Foo(data=set([1])) + f1 = Foo(data={1}) sess.add(f1) sess.commit() @@ -806,144 +806,144 @@ class _MutableSetTestBase(_MutableSetTestFixture): def test_add(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() f1.data.add(5) sess.commit() - eq_(f1.data, set([1, 2, 5])) + eq_(f1.data, {1, 2, 5}) def test_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data.update(set([2, 5])) + f1.data.update({2, 5}) sess.commit() - eq_(f1.data, set([1, 2, 5])) + eq_(f1.data, {1, 2, 5}) def test_binary_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data |= set([2, 5]) + f1.data |= {2, 5} sess.commit() - eq_(f1.data, set([1, 2, 5])) + eq_(f1.data, {1, 2, 5}) def test_intersection_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data.intersection_update(set([2, 5])) + f1.data.intersection_update({2, 5}) sess.commit() - eq_(f1.data, set([2])) + eq_(f1.data, {2}) def test_binary_intersection_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data &= set([2, 5]) + f1.data &= {2, 5} sess.commit() - eq_(f1.data, set([2])) + eq_(f1.data, {2}) def test_difference_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data.difference_update(set([2, 5])) + f1.data.difference_update({2, 5}) sess.commit() - eq_(f1.data, set([1])) + eq_(f1.data, {1}) def test_operator_difference_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data -= set([2, 5]) + f1.data -= {2, 5} sess.commit() - eq_(f1.data, set([1])) + eq_(f1.data, {1}) def test_symmetric_difference_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data.symmetric_difference_update(set([2, 5])) + f1.data.symmetric_difference_update({2, 5}) sess.commit() - eq_(f1.data, set([1, 5])) + eq_(f1.data, {1, 5}) def test_binary_symmetric_difference_update(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() - f1.data ^= set([2, 5]) + f1.data ^= {2, 5} sess.commit() - eq_(f1.data, set([1, 5])) + eq_(f1.data, {1, 5}) def test_remove(self): sess = fixture_session() - f1 = Foo(data=set([1, 2, 3])) + f1 = Foo(data={1, 2, 3}) sess.add(f1) sess.commit() f1.data.remove(2) sess.commit() - eq_(f1.data, set([1, 3])) + eq_(f1.data, {1, 3}) def test_discard(self): sess = fixture_session() - f1 = Foo(data=set([1, 2, 3])) + f1 = Foo(data={1, 2, 3}) sess.add(f1) sess.commit() f1.data.discard(2) sess.commit() - eq_(f1.data, set([1, 3])) + eq_(f1.data, {1, 3}) f1.data.discard(2) sess.commit() - eq_(f1.data, set([1, 3])) + eq_(f1.data, {1, 3}) def test_pickle_parent(self): sess = fixture_session() - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) sess.add(f1) sess.commit() f1.data @@ -958,24 +958,24 @@ class _MutableSetTestBase(_MutableSetTestFixture): def test_unrelated_flush(self): sess = fixture_session() - f1 = Foo(data=set([1, 2]), unrelated_data="unrelated") + f1 = Foo(data={1, 2}, unrelated_data="unrelated") sess.add(f1) sess.flush() f1.unrelated_data = "unrelated 2" sess.flush() f1.data.add(3) sess.commit() - eq_(f1.data, set([1, 2, 3])) + eq_(f1.data, {1, 2, 3}) def test_copy(self): - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) f1.data = copy.copy(f1.data) - eq_(f1.data, set([1, 2])) + eq_(f1.data, {1, 2}) def test_deepcopy(self): - f1 = Foo(data=set([1, 2])) + f1 = Foo(data={1, 2}) f1.data = copy.deepcopy(f1.data) - eq_(f1.data, set([1, 2])) + eq_(f1.data, {1, 2}) class _MutableNoHashFixture: @@ -1349,9 +1349,7 @@ class CustomMutableAssociationScalarJSONTest( @classmethod def _type_fixture(cls): if not (getattr(cls, "CustomMutableDict")): - MutableDict = super( - CustomMutableAssociationScalarJSONTest, cls - )._type_fixture() + MutableDict = super()._type_fixture() class CustomMutableDict(MutableDict): pass diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index e15ace2eb2..d41e56b82a 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -1,5 +1,3 @@ -# coding: utf-8 - from sqlalchemy import desc from sqlalchemy import ForeignKey from sqlalchemy import func diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 6959d06acd..3dfc598272 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -260,7 +260,7 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): configure_mappers() eq_( Parent.children.property._calculated_foreign_keys, - set([Child.name_upper.property.columns[0]]), + {Child.name_upper.property.columns[0]}, ) def test_class_has_registry_attr(self, registry): @@ -2188,8 +2188,8 @@ class DeclarativeMultiBaseTest( adr_count = Address.id - eq_(set(User.__table__.c.keys()), set(["id", "name"])) - eq_(set(Address.__table__.c.keys()), set(["id", "email", "user_id"])) + eq_(set(User.__table__.c.keys()), {"id", "name"}) + eq_(set(Address.__table__.c.keys()), {"id", "email", "user_id"}) def test_deferred(self): class User(Base, fixtures.ComparableEntity): diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index dcb0a5e749..9829f42333 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -79,7 +79,7 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): } assert class_mapper(Person).version_id_col == "a" - assert class_mapper(Person).include_properties == set(["id", "a", "b"]) + assert class_mapper(Person).include_properties == {"id", "a", "b"} def test_custom_join_condition(self): class Foo(Base): diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 9679588468..95990cea04 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -236,7 +236,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): baz = _column(String(100), nullable=False, index=True) @mapper_registry.mapped - class MyModel(MyMixin, object): + class MyModel(MyMixin): __tablename__ = "test" name = Column(String(100), nullable=False, index=True) @@ -2264,8 +2264,8 @@ class DeclaredAttrTest(DeclarativeTestBase, testing.AssertsCompiledSQL): eq_( asserted, { - "a": set([A.my_attr.property.columns[0]]), - "b": set([B.my_attr.property.columns[0]]), + "a": {A.my_attr.property.columns[0]}, + "b": {B.my_attr.property.columns[0]}, }, ) @@ -2395,7 +2395,7 @@ class AbstractTest(DeclarativeTestBase): __tablename__ = "q" id = Column(Integer, primary_key=True) - eq_(set(Base.metadata.tables), set(["y", "z", "q"])) + eq_(set(Base.metadata.tables), {"y", "z", "q"}) def test_middle_abstract_attributes(self): # test for [ticket:3219] diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py index 0dd769be14..cf9a30c198 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_update_delete_where.py @@ -1950,7 +1950,7 @@ class UpdateDeleteTest(fixtures.MappedTest): class RoutingSession(Session): def get_bind(self, **kw): received.append(type(kw["clause"])) - return super(RoutingSession, self).get_bind(**kw) + return super().get_bind(**kw) stmt = stmt.execution_options(synchronize_session=sync_type) @@ -2181,16 +2181,14 @@ class UpdateDeleteFromTest(fixtures.MappedTest): eq_( set(s.query(Document.id, Document.flag)), - set( - [ - (1, True), - (2, None), - (3, None), - (4, True), - (5, True), - (6, None), - ] - ), + { + (1, True), + (2, None), + (3, None), + (4, True), + (5, True), + (6, None), + }, ) @testing.requires.delete_using @@ -2210,7 +2208,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest): eq_( set(s.query(Document.id, Document.flag)), - set([(2, None), (3, None), (6, None)]), + {(2, None), (3, None), (6, None)}, ) def test_no_eval_against_multi_table_criteria(self): @@ -2271,16 +2269,14 @@ class UpdateDeleteFromTest(fixtures.MappedTest): eq_( set(s.query(Document.id, Document.flag)), - set( - [ - (1, True), - (2, None), - (3, None), - (4, True), - (5, True), - (6, None), - ] - ), + { + (1, True), + (2, None), + (3, None), + (4, True), + (5, True), + (6, None), + }, ) @testing.requires.update_where_target_in_subquery @@ -2305,16 +2301,14 @@ class UpdateDeleteFromTest(fixtures.MappedTest): eq_( set(s.query(Document.id, Document.flag)), - set( - [ - (1, True), - (2, False), - (3, False), - (4, True), - (5, True), - (6, False), - ] - ), + { + (1, True), + (2, False), + (3, False), + (4, True), + (5, True), + (6, False), + }, ) @testing.requires.multi_table_update @@ -2616,7 +2610,7 @@ class InheritTest(fixtures.DeclarativeMappedTest): eq_( set(s.query(Person.name, Engineer.engineer_name)), - set([("e1", "e1"), ("e2", "e5"), ("pp1", "pp1")]), + {("e1", "e1"), ("e2", "e5"), ("pp1", "pp1")}, ) @testing.requires.delete_using @@ -2714,7 +2708,7 @@ class InheritTest(fixtures.DeclarativeMappedTest): # delete actually worked eq_( set(s.query(Person.name, Engineer.engineer_name)), - set([("pp1", "pp1"), ("e1", "e1")]), + {("pp1", "pp1"), ("e1", "e1")}, ) @testing.only_on(["mysql", "mariadb"], "Multi table update") @@ -2733,7 +2727,7 @@ class InheritTest(fixtures.DeclarativeMappedTest): eq_( set(s.query(Person.name, Engineer.engineer_name)), - set([("e1", "e1"), ("e22", "e55"), ("pp1", "pp1")]), + {("e1", "e1"), ("e22", "e55"), ("pp1", "pp1")}, ) diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 3096984f17..71592a22c3 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -1621,8 +1621,8 @@ class MultiLevelTest(fixtures.MappedTest): session.add(b) session.add(c) session.flush() - assert set(session.query(Employee).all()) == set([a, b, c]) - assert set(session.query(Engineer).all()) == set([b, c]) + assert set(session.query(Employee).all()) == {a, b, c} + assert set(session.query(Engineer).all()) == {b, c} assert session.query(Manager).all() == [c] diff --git a/test/orm/inheritance/test_concrete.py b/test/orm/inheritance/test_concrete.py index 05d1fb2b20..6b5f4f6fa1 100644 --- a/test/orm/inheritance/test_concrete.py +++ b/test/orm/inheritance/test_concrete.py @@ -192,19 +192,17 @@ class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest): session.add(Engineer("Karina", "knows how to hack")) session.flush() session.expunge_all() - assert set([repr(x) for x in session.query(Employee)]) == set( - [ - "Engineer Karina knows how to hack", - "Manager Sally knows how to manage things", - ] - ) - - assert set([repr(x) for x in session.query(Manager)]) == set( - ["Manager Sally knows how to manage things"] - ) - assert set([repr(x) for x in session.query(Engineer)]) == set( - ["Engineer Karina knows how to hack"] - ) + assert {repr(x) for x in session.query(Employee)} == { + "Engineer Karina knows how to hack", + "Manager Sally knows how to manage things", + } + + assert {repr(x) for x in session.query(Manager)} == { + "Manager Sally knows how to manage things" + } + assert {repr(x) for x in session.query(Engineer)} == { + "Engineer Karina knows how to hack" + } manager = session.query(Manager).one() session.expire(manager, ["manager_data"]) eq_(manager.manager_data, "knows how to manage things") @@ -320,25 +318,21 @@ class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest): repr(session.query(Manager).filter(Manager.name == "Sally").one()) == "Manager Sally knows how to manage things" ) - assert set([repr(x) for x in session.query(Employee).all()]) == set( - [ - "Engineer Jenn knows how to program", - "Manager Sally knows how to manage things", - "Hacker Karina 'Badass' knows how to hack", - ] - ) - assert set([repr(x) for x in session.query(Manager).all()]) == set( - ["Manager Sally knows how to manage things"] - ) - assert set([repr(x) for x in session.query(Engineer).all()]) == set( - [ - "Engineer Jenn knows how to program", - "Hacker Karina 'Badass' knows how to hack", - ] - ) - assert set([repr(x) for x in session.query(Hacker).all()]) == set( - ["Hacker Karina 'Badass' knows how to hack"] - ) + assert {repr(x) for x in session.query(Employee).all()} == { + "Engineer Jenn knows how to program", + "Manager Sally knows how to manage things", + "Hacker Karina 'Badass' knows how to hack", + } + assert {repr(x) for x in session.query(Manager).all()} == { + "Manager Sally knows how to manage things" + } + assert {repr(x) for x in session.query(Engineer).all()} == { + "Engineer Jenn knows how to program", + "Hacker Karina 'Badass' knows how to hack", + } + assert {repr(x) for x in session.query(Hacker).all()} == { + "Hacker Karina 'Badass' knows how to hack" + } def test_multi_level_no_base_w_hybrid(self): Employee, Engineer, Manager = self.classes( @@ -503,25 +497,21 @@ class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest): ) == 3 ) - assert set([repr(x) for x in session.query(Employee)]) == set( - [ - "Engineer Jenn knows how to program", - "Manager Sally knows how to manage things", - "Hacker Karina 'Badass' knows how to hack", - ] - ) - assert set([repr(x) for x in session.query(Manager)]) == set( - ["Manager Sally knows how to manage things"] - ) - assert set([repr(x) for x in session.query(Engineer)]) == set( - [ - "Engineer Jenn knows how to program", - "Hacker Karina 'Badass' knows how to hack", - ] - ) - assert set([repr(x) for x in session.query(Hacker)]) == set( - ["Hacker Karina 'Badass' knows how to hack"] - ) + assert {repr(x) for x in session.query(Employee)} == { + "Engineer Jenn knows how to program", + "Manager Sally knows how to manage things", + "Hacker Karina 'Badass' knows how to hack", + } + assert {repr(x) for x in session.query(Manager)} == { + "Manager Sally knows how to manage things" + } + assert {repr(x) for x in session.query(Engineer)} == { + "Engineer Jenn knows how to program", + "Hacker Karina 'Badass' knows how to hack", + } + assert {repr(x) for x in session.query(Hacker)} == { + "Hacker Karina 'Badass' knows how to hack" + } @testing.fixture def two_pjoin_fixture(self): @@ -851,12 +841,10 @@ class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest): def go(): c2 = session.get(Company, c.id) - assert set([repr(x) for x in c2.employees]) == set( - [ - "Engineer Karina knows how to hack", - "Manager Sally knows how to manage things", - ] - ) + assert {repr(x) for x in c2.employees} == { + "Engineer Karina knows how to hack", + "Manager Sally knows how to manage things", + } self.assert_sql_count(testing.db, go, 2) session.expunge_all() @@ -865,12 +853,10 @@ class ConcreteTest(AssertsCompiledSQL, fixtures.MappedTest): c2 = session.get( Company, c.id, options=[joinedload(Company.employees)] ) - assert set([repr(x) for x in c2.employees]) == set( - [ - "Engineer Karina knows how to hack", - "Manager Sally knows how to manage things", - ] - ) + assert {repr(x) for x in c2.employees} == { + "Engineer Karina knows how to hack", + "Manager Sally knows how to manage things", + } self.assert_sql_count(testing.db, go, 1) diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index 64b4bd9a88..49c25f6b63 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -37,7 +37,7 @@ class _PolymorphicTestBase: @classmethod def setup_mappers(cls): - super(_PolymorphicTestBase, cls).setup_mappers() + super().setup_mappers() global people, engineers, managers, boss global companies, paperwork, machines people, engineers, managers, boss, companies, paperwork, machines = ( @@ -52,7 +52,7 @@ class _PolymorphicTestBase: @classmethod def insert_data(cls, connection): - super(_PolymorphicTestBase, cls).insert_data(connection) + super().insert_data(connection) global all_employees, c1_employees, c2_employees global c1, c2, e1, e2, e3, b1, m1 diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index ad99c2eec0..cdd1e02c70 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -478,7 +478,7 @@ class SelfReferentialJ2JSelfTest(fixtures.MappedTest): def _five_obj_fixture(self): sess = fixture_session() - e1, e2, e3, e4, e5 = [Engineer(name="e%d" % (i + 1)) for i in range(5)] + e1, e2, e3, e4, e5 = (Engineer(name="e%d" % (i + 1)) for i in range(5)) e3.reports_to = e1 e4.reports_to = e2 sess.add_all([e1, e2, e3, e4, e5]) @@ -801,13 +801,13 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): with _aliased_join_warning(r"Child2\(child2\)"): eq_( set(sess.execute(stmt).scalars().unique()), - set([c11, c12, c13]), + {c11, c12, c13}, ) with _aliased_join_warning(r"Child2\(child2\)"): eq_( set(sess.query(Child1, Child2).join(Child1.left_child2)), - set([(c11, c22), (c12, c22), (c13, c23)]), + {(c11, c22), (c12, c22), (c13, c23)}, ) # manual alias test: @@ -817,12 +817,12 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): eq_( set(sess.execute(stmt).scalars().unique()), - set([c11, c12, c13]), + {c11, c12, c13}, ) eq_( set(sess.query(Child1, c2).join(Child1.left_child2.of_type(c2))), - set([(c11, c22), (c12, c22), (c13, c23)]), + {(c11, c22), (c12, c22), (c13, c23)}, ) # test __eq__() on property is annotating correctly @@ -835,7 +835,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): with _aliased_join_warning(r"Child1\(child1\)"): eq_( set(sess.execute(stmt).scalars().unique()), - set([c22]), + {c22}, ) # manual aliased version @@ -847,7 +847,7 @@ class SelfReferentialM2MTest(fixtures.MappedTest, AssertsCompiledSQL): ) eq_( set(sess.execute(stmt).scalars().unique()), - set([c22]), + {c22}, ) # test the same again diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 2384d7e2da..eb8d0c01a4 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -2044,17 +2044,13 @@ class EagerDefaultEvalTest(fixtures.DeclarativeMappedTest): class EagerDefaultEvalTestSubDefaults(EagerDefaultEvalTest): @classmethod def setup_classes(cls): - super(EagerDefaultEvalTestSubDefaults, cls).setup_classes( - include_sub_defaults=True - ) + super().setup_classes(include_sub_defaults=True) class EagerDefaultEvalTestPolymorphic(EagerDefaultEvalTest): @classmethod def setup_classes(cls): - super(EagerDefaultEvalTestPolymorphic, cls).setup_classes( - with_polymorphic="*" - ) + super().setup_classes(with_polymorphic="*") class ColExprTest(AssertsCompiledSQL, fixtures.TestBase): diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index f0a91cf392..58e1ab97b9 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -2612,7 +2612,7 @@ class HistoryTest(fixtures.TestBase): ) ] ), - (set([hi, there]), set(), set()), + ({hi, there}, set(), set()), ) self._commit_someattr(f) eq_( @@ -2624,7 +2624,7 @@ class HistoryTest(fixtures.TestBase): ) ] ), - (set(), set([hi, there]), set()), + (set(), {hi, there}, set()), ) def test_object_collections_mutate(self): diff --git a/test/orm/test_backref_mutations.py b/test/orm/test_backref_mutations.py index 0f10cff248..82e353439b 100644 --- a/test/orm/test_backref_mutations.py +++ b/test/orm/test_backref_mutations.py @@ -767,7 +767,7 @@ class M2MCollectionMoveTest(_fixtures.FixtureTest): ._pending_mutations["keywords"] .added_items ), - set([k2]), + {k2}, ) # because autoflush is off, k2 is still # coming in from pending diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 541b7676e7..1a44b5d23b 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -415,13 +415,13 @@ class CacheKeyTest(fixtures.CacheKeyFixture, _fixtures.FixtureTest): """test for issue discovered in #7394""" @registry.mapped - class User2(object): + class User2: __table__ = self.tables.users name_syn = synonym("name") @registry.mapped - class Address2(object): + class Address2: __table__ = self.tables.addresses name_syn = synonym("email_address") diff --git a/test/orm/test_cascade.py b/test/orm/test_cascade.py index 8baa52f19f..e5710e90e6 100644 --- a/test/orm/test_cascade.py +++ b/test/orm/test_cascade.py @@ -137,9 +137,9 @@ class CascadeArgTest(fixtures.MappedTest): users, addresses = self.tables.users, self.tables.addresses rel = relationship(Address) - eq_(rel.cascade, set(["save-update", "merge"])) + eq_(rel.cascade, {"save-update", "merge"}) rel.cascade = "save-update, merge, expunge" - eq_(rel.cascade, set(["save-update", "merge", "expunge"])) + eq_(rel.cascade, {"save-update", "merge", "expunge"}) self.mapper_registry.map_imperatively( User, users, properties={"addresses": rel} @@ -147,7 +147,7 @@ class CascadeArgTest(fixtures.MappedTest): am = self.mapper_registry.map_imperatively(Address, addresses) configure_mappers() - eq_(rel.cascade, set(["save-update", "merge", "expunge"])) + eq_(rel.cascade, {"save-update", "merge", "expunge"}) assert ("addresses", User) not in am._delete_orphans rel.cascade = "all, delete, delete-orphan" @@ -155,16 +155,14 @@ class CascadeArgTest(fixtures.MappedTest): eq_( rel.cascade, - set( - [ - "delete", - "delete-orphan", - "expunge", - "merge", - "refresh-expire", - "save-update", - ] - ), + { + "delete", + "delete-orphan", + "expunge", + "merge", + "refresh-expire", + "save-update", + }, ) def test_cascade_unicode(self): @@ -172,7 +170,7 @@ class CascadeArgTest(fixtures.MappedTest): rel = relationship(Address) rel.cascade = "save-update, merge, expunge" - eq_(rel.cascade, set(["save-update", "merge", "expunge"])) + eq_(rel.cascade, {"save-update", "merge", "expunge"}) class O2MCascadeDeleteOrphanTest(fixtures.MappedTest): @@ -4176,11 +4174,11 @@ class SubclassCascadeTest(fixtures.DeclarativeMappedTest): state = inspect(obj) it = inspect(Company).cascade_iterator("save-update", state) - eq_(set([rec[0] for rec in it]), set([eng, maven_build, lang])) + eq_({rec[0] for rec in it}, {eng, maven_build, lang}) state = inspect(eng) it = inspect(Employee).cascade_iterator("save-update", state) - eq_(set([rec[0] for rec in it]), set([maven_build, lang])) + eq_({rec[0] for rec in it}, {maven_build, lang}) def test_delete_orphan_round_trip(self): ( diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 1c8bee00f4..517e9e7972 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -488,7 +488,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): obj.attr = like_me self.assert_(obj.attr is not direct) self.assert_(obj.attr is not like_me) - self.assert_(set(obj.attr) == set([e2])) + self.assert_(set(obj.attr) == {e2}) self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) @@ -496,13 +496,13 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): real_list = [e3] obj.attr = real_list self.assert_(obj.attr is not real_list) - self.assert_(set(obj.attr) == set([e3])) + self.assert_(set(obj.attr) == {e3}) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) e4 = creator() try: - obj.attr = set([e4]) + obj.attr = {e4} self.assert_(False) except TypeError: self.assert_(e4 not in canary.data) @@ -785,7 +785,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): e = creator() addall(e) - values = set([e, creator(), creator()]) + values = {e, creator(), creator()} direct.update(values) control.update(values) @@ -796,14 +796,14 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): e = creator() addall(e) - values = set([e, creator(), creator()]) + values = {e, creator(), creator()} direct |= values control |= values assert_eq() # cover self-assignment short-circuit - values = set([e, creator(), creator()]) + values = {e, creator(), creator()} obj.attr |= values control |= values assert_eq() @@ -837,12 +837,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): zap() e = creator() addall(creator(), creator()) - values = set([creator()]) + values = {creator()} direct.difference_update(values) control.difference_update(values) assert_eq() - values.update(set([e, creator()])) + values.update({e, creator()}) direct.difference_update(values) control.difference_update(values) assert_eq() @@ -851,17 +851,17 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): zap() e = creator() addall(creator(), creator()) - values = set([creator()]) + values = {creator()} direct -= values control -= values assert_eq() - values.update(set([e, creator()])) + values.update({e, creator()}) direct -= values control -= values assert_eq() - values = set([creator()]) + values = {creator()} obj.attr -= values control -= values assert_eq() @@ -887,7 +887,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): control.intersection_update(values) assert_eq() - values.update(set([e, creator()])) + values.update({e, creator()}) direct.intersection_update(values) control.intersection_update(values) assert_eq() @@ -902,12 +902,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): control &= values assert_eq() - values.update(set([e, creator()])) + values.update({e, creator()}) direct &= values control &= values assert_eq() - values.update(set([creator()])) + values.update({creator()}) obj.attr &= values control &= values assert_eq() @@ -923,14 +923,14 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): e = creator() addall(e, creator(), creator()) - values = set([e, creator()]) + values = {e, creator()} direct.symmetric_difference_update(values) control.symmetric_difference_update(values) assert_eq() e = creator() addall(e) - values = set([e]) + values = {e} direct.symmetric_difference_update(values) control.symmetric_difference_update(values) assert_eq() @@ -945,14 +945,14 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): e = creator() addall(e, creator(), creator()) - values = set([e, creator()]) + values = {e, creator()} direct ^= values control ^= values assert_eq() e = creator() addall(e) - values = set([e]) + values = {e} direct ^= values control ^= values assert_eq() @@ -962,7 +962,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): control ^= values assert_eq() - values = set([creator()]) + values = {creator()} obj.attr ^= values control ^= values assert_eq() @@ -1005,15 +1005,15 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): obj.attr = like_me self.assert_(obj.attr is not direct) self.assert_(obj.attr is not like_me) - self.assert_(obj.attr == set([e2])) + self.assert_(obj.attr == {e2}) self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) e3 = creator() - real_set = set([e3]) + real_set = {e3} obj.attr = real_set self.assert_(obj.attr is not real_set) - self.assert_(obj.attr == set([e3])) + self.assert_(obj.attr == {e3}) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) @@ -1291,14 +1291,14 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): if hasattr(direct, "update"): e = creator() - d = dict([(ee.a, ee) for ee in [e, creator(), creator()]]) + d = {ee.a: ee for ee in [e, creator(), creator()]} addall(e, creator()) direct.update(d) control.update(d) assert_eq() - kw = dict([(ee.a, ee) for ee in [e, creator()]]) + kw = {ee.a: ee for ee in [e, creator()]} direct.update(**kw) control.update(**kw) assert_eq() @@ -1335,9 +1335,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): obj.attr = like_me self.assert_(obj.attr is not direct) self.assert_(obj.attr is not like_me) - self.assert_( - set(collections.collection_adapter(obj.attr)) == set([e2]) - ) + self.assert_(set(collections.collection_adapter(obj.attr)) == {e2}) self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) @@ -1349,7 +1347,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): obj.attr = real_dict self.assert_(obj.attr is not real_dict) self.assert_("keyignored1" not in obj.attr) - eq_(set(collections.collection_adapter(obj.attr)), set([e3])) + eq_(set(collections.collection_adapter(obj.attr)), {e3}) self.assert_(e2 in canary.removed) self.assert_(e3 in canary.added) @@ -1405,7 +1403,7 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): def test_dict_subclass2(self): class MyEasyDict(collections.KeyFuncDict): def __init__(self): - super(MyEasyDict, self).__init__(lambda e: e.a) + super().__init__(lambda e: e.a) self._test_adapter( MyEasyDict, self.dictable_entity, to_set=lambda c: set(c.values()) @@ -1878,7 +1876,7 @@ class DictHelpersTest(OrderedDictFixture, fixtures.MappedTest): p = session.get(Parent, pid) - eq_(set(p.children.keys()), set(["foo", "bar"])) + eq_(set(p.children.keys()), {"foo", "bar"}) cid = p.children["foo"].id collections.collection_adapter(p.children).append_with_event( @@ -1890,7 +1888,7 @@ class DictHelpersTest(OrderedDictFixture, fixtures.MappedTest): p = session.get(Parent, pid) - self.assert_(set(p.children.keys()) == set(["foo", "bar"])) + self.assert_(set(p.children.keys()) == {"foo", "bar"}) self.assert_(p.children["foo"].id != cid) self.assert_( @@ -1964,9 +1962,7 @@ class DictHelpersTest(OrderedDictFixture, fixtures.MappedTest): p = session.get(Parent, pid) - self.assert_( - set(p.children.keys()) == set([("foo", "1"), ("foo", "2")]) - ) + self.assert_(set(p.children.keys()) == {("foo", "1"), ("foo", "2")}) cid = p.children[("foo", "1")].id collections.collection_adapter(p.children).append_with_event( @@ -1978,9 +1974,7 @@ class DictHelpersTest(OrderedDictFixture, fixtures.MappedTest): p = session.get(Parent, pid) - self.assert_( - set(p.children.keys()) == set([("foo", "1"), ("foo", "2")]) - ) + self.assert_(set(p.children.keys()) == {("foo", "1"), ("foo", "2")}) self.assert_(p.children[("foo", "1")].id != cid) self.assert_( @@ -2314,7 +2308,7 @@ class CustomCollectionsTest(fixtures.MappedTest): assert len(list(f.bars)) == 2 strongref = list(f.bars.values()) - existing = set([id(b) for b in strongref]) + existing = {id(b) for b in strongref} col = collections.collection_adapter(f.bars) col.append_with_event(Bar("b")) @@ -2324,7 +2318,7 @@ class CustomCollectionsTest(fixtures.MappedTest): f = sess.get(Foo, f.col1) assert len(list(f.bars)) == 2 - replaced = set([id(b) for b in list(f.bars.values())]) + replaced = {id(b) for b in list(f.bars.values())} ne_(existing, replaced) @testing.combinations("direct", "as_callable", argnames="factory_type") diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 8fd22bdd09..5c2f107f45 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -1031,7 +1031,7 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL): users, properties=util.OrderedDict( [ - ("concat", column_property((users.c.id * 2))), + ("concat", column_property(users.c.id * 2)), ( "count", column_property( diff --git a/test/orm/test_deferred.py b/test/orm/test_deferred.py index 5bd70ca7bd..a8317671c7 100644 --- a/test/orm/test_deferred.py +++ b/test/orm/test_deferred.py @@ -1635,7 +1635,7 @@ class InheritanceTest(_Polymorphic): @classmethod def setup_mappers(cls): - super(InheritanceTest, cls).setup_mappers() + super().setup_mappers() from sqlalchemy import inspect inspect(Company).add_property( diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index c816896cf3..859ccf884f 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -875,7 +875,7 @@ class InstrumentationTest(fixtures.ORMTest): class MyDict(collections.KeyFuncDict): def __init__(self): - super(MyDict, self).__init__(lambda value: "k%d" % value) + super().__init__(lambda value: "k%d" % value) @collection.converter def _convert(self, dictlike): diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index 0af61949cd..df335f0f67 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -440,13 +440,11 @@ class DynamicTest(_DynamicFixture, _fixtures.FixtureTest, AssertsCompiledSQL): # test cancellation of None, replacement with nothing eq_( set(u.addresses.order_by(None)), - set( - [ - Address(email_address="ed@bettyboop.com"), - Address(email_address="ed@lala.com"), - Address(email_address="ed@wood.com"), - ] - ), + { + Address(email_address="ed@bettyboop.com"), + Address(email_address="ed@lala.com"), + Address(email_address="ed@wood.com"), + }, ) def test_count(self, user_address_fixture): @@ -865,13 +863,11 @@ class WriteOnlyTest( # test cancellation of None, replacement with nothing eq_( set(sess.scalars(u.addresses.select().order_by(None))), - set( - [ - Address(email_address="ed@bettyboop.com"), - Address(email_address="ed@lala.com"), - Address(email_address="ed@wood.com"), - ] - ), + { + Address(email_address="ed@bettyboop.com"), + Address(email_address="ed@lala.com"), + Address(email_address="ed@wood.com"), + }, ) def test_secondary_as_join(self): @@ -1286,8 +1282,8 @@ class _UOWTests: u.addresses.remove(a) eq_( - set(ad for ad, in sess.query(Address.email_address)), - set(["a", "b", "d"]), + {ad for ad, in sess.query(Address.email_address)}, + {"a", "b", "d"}, ) @testing.combinations(True, False, argnames="autoflush") diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 47609daa06..56d2815fa5 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -757,7 +757,7 @@ class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): @classmethod def define_tables(cls, metadata): - super(MapperEventsTest, cls).define_tables(metadata) + super().define_tables(metadata) metadata.tables["users"].append_column( Column("extra", Integer, default=5, onupdate=10) ) @@ -2118,7 +2118,7 @@ class RefreshTest(_fixtures.FixtureTest): sess.commit() sess.query(User).union_all(sess.query(User)).all() - eq_(canary, [("refresh", set(["id", "name"]))]) + eq_(canary, [("refresh", {"id", "name"})]) def test_via_refresh_state(self): User = self.classes.User @@ -2132,7 +2132,7 @@ class RefreshTest(_fixtures.FixtureTest): sess.commit() u1.name - eq_(canary, [("refresh", set(["id", "name"]))]) + eq_(canary, [("refresh", {"id", "name"})]) def test_was_expired(self): User = self.classes.User @@ -2147,7 +2147,7 @@ class RefreshTest(_fixtures.FixtureTest): sess.expire(u1) sess.query(User).first() - eq_(canary, [("refresh", set(["id", "name"]))]) + eq_(canary, [("refresh", {"id", "name"})]) def test_was_expired_via_commit(self): User = self.classes.User @@ -2161,7 +2161,7 @@ class RefreshTest(_fixtures.FixtureTest): sess.commit() sess.query(User).first() - eq_(canary, [("refresh", set(["id", "name"]))]) + eq_(canary, [("refresh", {"id", "name"})]) def test_was_expired_attrs(self): User = self.classes.User @@ -2176,7 +2176,7 @@ class RefreshTest(_fixtures.FixtureTest): sess.expire(u1, ["name"]) sess.query(User).first() - eq_(canary, [("refresh", set(["name"]))]) + eq_(canary, [("refresh", {"name"})]) def test_populate_existing(self): User = self.classes.User diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 30a4c54dd0..85c950876b 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -3587,7 +3587,7 @@ class ExternalColumnsTest(QueryTest): User, users, properties={ - "concat": column_property((users.c.id * 2)), + "concat": column_property(users.c.id * 2), "count": column_property( select(func.count(addresses.c.id)) .where( @@ -3754,7 +3754,7 @@ class ExternalColumnsTest(QueryTest): "addresses": relationship( Address, backref="user", order_by=addresses.c.id ), - "concat": column_property((users.c.id * 2)), + "concat": column_property(users.c.id * 2), "count": column_property( select(func.count(addresses.c.id)) .where( diff --git a/test/orm/test_generative.py b/test/orm/test_generative.py index 7c9876aec2..b17559a967 100644 --- a/test/orm/test_generative.py +++ b/test/orm/test_generative.py @@ -297,7 +297,7 @@ class RelationshipsTest(_fixtures.FixtureTest): .outerjoin(Order.addresses) .filter(sa.or_(Order.id == None, Address.id == 1)) ) # noqa - eq_(set([User(id=7), User(id=8), User(id=10)]), set(q.all())) + eq_({User(id=7), User(id=8), User(id=10)}, set(q.all())) def test_outer_join_count(self): """test the join and outerjoin functions on Query""" @@ -338,7 +338,7 @@ class RelationshipsTest(_fixtures.FixtureTest): .select_from(sel) .filter(sa.or_(Order.id == None, Address.id == 1)) ) # noqa - eq_(set([User(id=7), User(id=8), User(id=10)]), set(q.all())) + eq_({User(id=7), User(id=8), User(id=10)}, set(q.all())) class CaseSensitiveTest(fixtures.MappedTest): diff --git a/test/orm/test_inspect.py b/test/orm/test_inspect.py index c6fc3ace19..8644b36e55 100644 --- a/test/orm/test_inspect.py +++ b/test/orm/test_inspect.py @@ -169,7 +169,7 @@ class TestORMInspection(_fixtures.FixtureTest): rel = inspect(User).relationships eq_(rel.addresses, User.addresses.property) - eq_(set(rel.keys()), set(["orders", "addresses"])) + eq_(set(rel.keys()), {"orders", "addresses"}) def test_insp_relationship_prop(self): User = self.classes.User @@ -284,10 +284,10 @@ class TestORMInspection(_fixtures.FixtureTest): insp = inspect(SomeSubClass) eq_( - dict( - (k, v.extension_type) + { + k: v.extension_type for k, v in list(insp.all_orm_descriptors.items()) - ), + }, { "id": NotExtension.NOT_EXTENSION, "name": NotExtension.NOT_EXTENSION, @@ -330,7 +330,7 @@ class TestORMInspection(_fixtures.FixtureTest): eq_( set(insp.attrs.keys()), - set(["id", "name", "name_syn", "addresses", "orders"]), + {"id", "name", "name_syn", "addresses", "orders"}, ) eq_(insp.attrs.name.value, "ed") eq_(insp.attrs.name.loaded_value, "ed") @@ -416,10 +416,10 @@ class TestORMInspection(_fixtures.FixtureTest): m = self.mapper_registry.map_imperatively(AnonClass, self.tables.users) - eq_(set(inspect(AnonClass).attrs.keys()), set(["id", "name"])) + eq_(set(inspect(AnonClass).attrs.keys()), {"id", "name"}) eq_( set(inspect(AnonClass).all_orm_descriptors.keys()), - set(["id", "name"]), + {"id", "name"}, ) m.add_property("q", column_property(self.tables.users.c.name)) @@ -429,10 +429,10 @@ class TestORMInspection(_fixtures.FixtureTest): AnonClass.foob = hybrid_property(desc) - eq_(set(inspect(AnonClass).attrs.keys()), set(["id", "name", "q"])) + eq_(set(inspect(AnonClass).attrs.keys()), {"id", "name", "q"}) eq_( set(inspect(AnonClass).all_orm_descriptors.keys()), - set(["id", "name", "q", "foob"]), + {"id", "name", "q", "foob"}, ) def _random_names(self): @@ -456,9 +456,9 @@ class TestORMInspection(_fixtures.FixtureTest): names = self._random_names() if base is supercls: - pk_names = set( + pk_names = { random.choice(names) for i in range(random.randint(1, 3)) - ) + } fk_name = random.choice( [name for name in names if name not in pk_names] ) diff --git a/test/orm/test_instrumentation.py b/test/orm/test_instrumentation.py index 99d5498d68..b4ce5b1f2e 100644 --- a/test/orm/test_instrumentation.py +++ b/test/orm/test_instrumentation.py @@ -104,7 +104,7 @@ class InitTest(fixtures.ORMTest): class B(A): def __init__(self): inits.append((B, "__init__")) - super(B, self).__init__() + super().__init__() self.register(B, inits) @@ -128,7 +128,7 @@ class InitTest(fixtures.ORMTest): class B(A): def __init__(self): inits.append((B, "__init__")) - super(B, self).__init__() + super().__init__() A() eq_(inits, [(A, "init", A), (A, "__init__")]) @@ -150,7 +150,7 @@ class InitTest(fixtures.ORMTest): class B(A): def __init__(self): inits.append((B, "__init__")) - super(B, self).__init__() + super().__init__() self.register(B, inits) @@ -196,14 +196,14 @@ class InitTest(fixtures.ORMTest): class B(A): def __init__(self): inits.append((B, "__init__")) - super(B, self).__init__() + super().__init__() self.register(B, inits) class C(B): def __init__(self): inits.append((C, "__init__")) - super(C, self).__init__() + super().__init__() self.register(C, inits) @@ -239,12 +239,12 @@ class InitTest(fixtures.ORMTest): class B(A): def __init__(self): inits.append((B, "__init__")) - super(B, self).__init__() + super().__init__() class C(B): def __init__(self): inits.append((C, "__init__")) - super(C, self).__init__() + super().__init__() self.register(C, inits) @@ -283,7 +283,7 @@ class InitTest(fixtures.ORMTest): class C(B): def __init__(self): inits.append((C, "__init__")) - super(C, self).__init__() + super().__init__() self.register(C, inits) @@ -316,7 +316,7 @@ class InitTest(fixtures.ORMTest): class C(B): def __init__(self): inits.append((C, "__init__")) - super(C, self).__init__() + super().__init__() self.register(C, inits) @@ -656,7 +656,7 @@ class Py3KFunctionInstTest(fixtures.ORMTest): assert_raises(TypeError, cls, "a", "b", c="c") def _kw_only_fixture(self): - class A(object): + class A: def __init__(self, a, *, b, c): self.a = a self.b = b @@ -665,7 +665,7 @@ class Py3KFunctionInstTest(fixtures.ORMTest): return self._instrument(A) def _kw_plus_posn_fixture(self): - class A(object): + class A: def __init__(self, a, *args, b, c): self.a = a self.b = b @@ -674,7 +674,7 @@ class Py3KFunctionInstTest(fixtures.ORMTest): return self._instrument(A) def _kw_opt_fixture(self): - class A(object): + class A: def __init__(self, a, *, b, c="c"): self.a = a self.b = b diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index e069086218..5e869f6b34 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -1170,12 +1170,12 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): configure_mappers() def assert_props(cls, want): - have = set([n for n in dir(cls) if not n.startswith("_")]) + have = {n for n in dir(cls) if not n.startswith("_")} want = set(want) eq_(have, want) def assert_instrumented(cls, want): - have = set([p.key for p in class_mapper(cls).iterate_properties]) + have = {p.key for p in class_mapper(cls).iterate_properties} want = set(want) eq_(have, want) @@ -1979,7 +1979,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class MyFakeProperty(sa.orm.properties.ColumnProperty): def post_instrument_class(self, mapper): - super(MyFakeProperty, self).post_instrument_class(mapper) + super().post_instrument_class(mapper) configure_mappers() self.mapper( @@ -1992,7 +1992,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class MyFakeProperty(sa.orm.properties.ColumnProperty): def post_instrument_class(self, mapper): - super(MyFakeProperty, self).post_instrument_class(mapper) + super().post_instrument_class(mapper) configure_mappers() self.mapper( @@ -2677,7 +2677,7 @@ class IsUserlandTest(fixtures.MappedTest): self._test({"bar": "bat"}) def test_set(self): - self._test(set([6])) + self._test({6}) def test_column(self): self._test_not(self.tables.foo.c.someprop) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index ecca330629..e06daa9a4f 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -101,8 +101,8 @@ class PathTest: assert_paths = [k[1] for k in attr] eq_( - set([p for p in assert_paths]), - set([self._make_path(p) for p in paths]), + {p for p in assert_paths}, + {self._make_path(p) for p in paths}, ) diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index eb94605800..f937802c32 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -297,7 +297,7 @@ class _JoinFixtures: self.selfref, self.selfref, prop=self.relationship, - remote_side=set([self.selfref.c.id]), + remote_side={self.selfref.c.id}, **kw, ) @@ -318,12 +318,10 @@ class _JoinFixtures: self.composite_selfref, self.composite_selfref, prop=self.relationship, - remote_side=set( - [ - self.composite_selfref.c.id, - self.composite_selfref.c.group_id, - ] - ), + remote_side={ + self.composite_selfref.c.id, + self.composite_selfref.c.group_id, + }, **kw, ) @@ -356,7 +354,7 @@ class _JoinFixtures: self.composite_selfref.c.parent_id == self.composite_selfref.c.id, ), - remote_side=set([self.composite_selfref.c.parent_id]), + remote_side={self.composite_selfref.c.parent_id}, **kw, ) @@ -800,7 +798,7 @@ class ColumnCollectionsTest( def test_determine_remote_columns_compound_1(self): joincond = self._join_fixture_compound_expression_1(support_sync=False) - eq_(joincond.remote_columns, set([self.right.c.x, self.right.c.y])) + eq_(joincond.remote_columns, {self.right.c.x, self.right.c.y}) def test_determine_local_remote_compound_1(self): joincond = self._join_fixture_compound_expression_1(support_sync=False) @@ -848,15 +846,15 @@ class ColumnCollectionsTest( def test_determine_remote_columns_compound_2(self): joincond = self._join_fixture_compound_expression_2(support_sync=False) - eq_(joincond.remote_columns, set([self.right.c.x, self.right.c.y])) + eq_(joincond.remote_columns, {self.right.c.x, self.right.c.y}) def test_determine_remote_columns_o2m(self): joincond = self._join_fixture_o2m() - eq_(joincond.remote_columns, set([self.right.c.lid])) + eq_(joincond.remote_columns, {self.right.c.lid}) def test_determine_remote_columns_o2m_selfref(self): joincond = self._join_fixture_o2m_selfref() - eq_(joincond.remote_columns, set([self.selfref.c.sid])) + eq_(joincond.remote_columns, {self.selfref.c.sid}) def test_determine_local_remote_pairs_o2m_composite_selfref(self): joincond = self._join_fixture_o2m_composite_selfref() @@ -915,17 +913,15 @@ class ColumnCollectionsTest( joincond = self._join_fixture_m2o_composite_selfref() eq_( joincond.remote_columns, - set( - [ - self.composite_selfref.c.id, - self.composite_selfref.c.group_id, - ] - ), + { + self.composite_selfref.c.id, + self.composite_selfref.c.group_id, + }, ) def test_determine_remote_columns_m2o(self): joincond = self._join_fixture_m2o() - eq_(joincond.remote_columns, set([self.left.c.id])) + eq_(joincond.remote_columns, {self.left.c.id}) def test_determine_local_remote_pairs_o2m(self): joincond = self._join_fixture_o2m() @@ -978,23 +974,23 @@ class ColumnCollectionsTest( def test_determine_local_columns_m2m_backref(self): j1, j2 = self._join_fixture_m2m_backref() - eq_(j1.local_columns, set([self.m2mleft.c.id])) - eq_(j2.local_columns, set([self.m2mright.c.id])) + eq_(j1.local_columns, {self.m2mleft.c.id}) + eq_(j2.local_columns, {self.m2mright.c.id}) def test_determine_remote_columns_m2m_backref(self): j1, j2 = self._join_fixture_m2m_backref() eq_( j1.remote_columns, - set([self.m2msecondary.c.lid, self.m2msecondary.c.rid]), + {self.m2msecondary.c.lid, self.m2msecondary.c.rid}, ) eq_( j2.remote_columns, - set([self.m2msecondary.c.lid, self.m2msecondary.c.rid]), + {self.m2msecondary.c.lid, self.m2msecondary.c.rid}, ) def test_determine_remote_columns_m2o_selfref(self): joincond = self._join_fixture_m2o_selfref() - eq_(joincond.remote_columns, set([self.selfref.c.id])) + eq_(joincond.remote_columns, {self.selfref.c.id}) def test_determine_local_remote_cols_three_tab_viewonly(self): joincond = self._join_fixture_overlapping_three_tables() @@ -1004,7 +1000,7 @@ class ColumnCollectionsTest( ) eq_( joincond.remote_columns, - set([self.three_tab_b.c.id, self.three_tab_b.c.aid]), + {self.three_tab_b.c.id, self.three_tab_b.c.aid}, ) def test_determine_local_remote_overlapping_composite_fks(self): @@ -1033,7 +1029,7 @@ class ColumnCollectionsTest( ) eq_( joincond.remote_columns, - set([self.base.c.flag, self.sub_w_sub_rel.c.sub_id]), + {self.base.c.flag, self.sub_w_sub_rel.c.sub_id}, ) diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 8d27742c33..0cbcc01f35 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -1398,25 +1398,21 @@ class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): employee_t = self.tables.employee_t eq_( set(Employee.employees.property.local_remote_pairs), - set( - [ - (employee_t.c.company_id, employee_t.c.company_id), - (employee_t.c.emp_id, employee_t.c.reports_to_id), - ] - ), + { + (employee_t.c.company_id, employee_t.c.company_id), + (employee_t.c.emp_id, employee_t.c.reports_to_id), + }, ) eq_( Employee.employees.property.remote_side, - set([employee_t.c.company_id, employee_t.c.reports_to_id]), + {employee_t.c.company_id, employee_t.c.reports_to_id}, ) eq_( set(Employee.reports_to.property.local_remote_pairs), - set( - [ - (employee_t.c.company_id, employee_t.c.company_id), - (employee_t.c.reports_to_id, employee_t.c.emp_id), - ] - ), + { + (employee_t.c.company_id, employee_t.c.company_id), + (employee_t.c.reports_to_id, employee_t.c.emp_id), + }, ) def _setup_data(self, sess): @@ -3301,8 +3297,8 @@ class ViewOnlyOverlappingNames(fixtures.MappedTest): sess.expunge_all() c1 = sess.get(C1, c1.id) - assert set([x.id for x in c1.t2s]) == set([c2a.id, c2b.id]) - assert set([x.id for x in c1.t2_view]) == set([c2b.id]) + assert {x.id for x in c1.t2s} == {c2a.id, c2b.id} + assert {x.id for x in c1.t2_view} == {c2b.id} class ViewOnlySyncBackref(fixtures.MappedTest): @@ -3565,8 +3561,8 @@ class ViewOnlyUniqueNames(fixtures.MappedTest): sess.expunge_all() c1 = sess.get(C1, c1.t1id) - assert set([x.t2id for x in c1.t2s]) == set([c2a.t2id, c2b.t2id]) - assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id]) + assert {x.t2id for x in c1.t2s} == {c2a.t2id, c2b.t2id} + assert {x.t2id for x in c1.t2_view} == {c2b.t2id} class ViewOnlyLocalRemoteM2M(fixtures.TestBase): @@ -4524,7 +4520,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest): self.mapper_registry.map_imperatively(B, b) sa.orm.configure_mappers() assert A.bs.property.primaryjoin.compare(a.c.id == b.c.aid_1) - eq_(A.bs.property._calculated_foreign_keys, set([b.c.aid_1])) + eq_(A.bs.property._calculated_foreign_keys, {b.c.aid_1}) def test_with_pj_o2m(self): A, B = self.classes.A, self.classes.B @@ -4539,7 +4535,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest): self.mapper_registry.map_imperatively(B, b) sa.orm.configure_mappers() assert A.bs.property.primaryjoin.compare(a.c.id == b.c.aid_1) - eq_(A.bs.property._calculated_foreign_keys, set([b.c.aid_1])) + eq_(A.bs.property._calculated_foreign_keys, {b.c.aid_1}) def test_with_annotated_pj_o2m(self): A, B = self.classes.A, self.classes.B @@ -4554,7 +4550,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest): self.mapper_registry.map_imperatively(B, b) sa.orm.configure_mappers() assert A.bs.property.primaryjoin.compare(a.c.id == b.c.aid_1) - eq_(A.bs.property._calculated_foreign_keys, set([b.c.aid_1])) + eq_(A.bs.property._calculated_foreign_keys, {b.c.aid_1}) def test_no_fks_m2m(self): A, B = self.classes.A, self.classes.B @@ -5456,8 +5452,8 @@ class InvalidRelationshipEscalationTestM2M( ) self.mapper_registry.map_imperatively(Bar, bars) sa.orm.configure_mappers() - eq_(Foo.bars.property._join_condition.local_columns, set([foos.c.id])) - eq_(Bar.foos.property._join_condition.local_columns, set([bars.c.id])) + eq_(Foo.bars.property._join_condition.local_columns, {foos.c.id}) + eq_(Bar.foos.property._join_condition.local_columns, {bars.c.id}) def test_bad_primaryjoin(self): foobars_with_fks, bars, Bar, foobars, Foo, foos = ( @@ -6593,7 +6589,7 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest): s = fixture_session() with assert_engine(testing.db) as asserter_: - eq_(set(id_ for id_, in s.query(A.id).filter(A.bs.any())), {1, 2}) + eq_({id_ for id_, in s.query(A.id).filter(A.bs.any())}, {1, 2}) asserter_.assert_( CompiledSQL( diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index 33e66d52f6..22e1178aa7 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -195,16 +195,12 @@ class ScopedSessionTest(fixtures.MappedTest): elif style == "style3": # py2k style def get_bind(self, mapper=None, *args, **kwargs): - return super(MySession, self).get_bind( - mapper, *args, **kwargs - ) + return super().get_bind(mapper, *args, **kwargs) elif style == "style4": # py2k style def get_bind(self, mapper=None, **kwargs): - return super(MySession, self).get_bind( - mapper=mapper, **kwargs - ) + return super().get_bind(mapper=mapper, **kwargs) s1 = MySession(testing.db) is_(s1.get_bind(), testing.db) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index d13c29ea64..5ba180f5f0 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -2054,9 +2054,7 @@ class DisposedStates(fixtures.MappedTest): class SessionInterface(fixtures.MappedTest): """Bogus args to Session methods produce actionable exceptions.""" - _class_methods = set( - ("connection", "execute", "get_bind", "scalar", "scalars") - ) + _class_methods = {"connection", "execute", "get_bind", "scalar", "scalars"} def _public_session_methods(self): Session = sa.orm.session.Session @@ -2140,13 +2138,11 @@ class SessionInterface(fixtures.MappedTest): instance_methods = ( self._public_session_methods() - self._class_methods - - set( - [ - "bulk_update_mappings", - "bulk_insert_mappings", - "bulk_save_objects", - ] - ) + - { + "bulk_update_mappings", + "bulk_insert_mappings", + "bulk_save_objects", + } ) eq_( diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index c40cbfd579..2f08080ada 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -360,7 +360,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): sess.commit() - eq_(set(sess.query(User).all()), set([u2])) + eq_(set(sess.query(User).all()), {u2}) sess.rollback() sess.begin() @@ -371,7 +371,7 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): n1.commit() # commit the nested transaction sess.rollback() - eq_(set(sess.query(User).all()), set([u2])) + eq_(set(sess.query(User).all()), {u2}) sess.close() diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index f9eff448b1..79d4adacf2 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -1,4 +1,3 @@ -# coding: utf-8 """Tests unitofwork operations.""" import datetime @@ -2565,7 +2564,7 @@ class ManyToManyTest(_fixtures.FixtureTest): session = fixture_session() objects = [] - _keywords = dict([(k.name, k) for k in session.query(Keyword)]) + _keywords = {k.name: k for k in session.query(Keyword)} for elem in data[1:]: item = Item(description=elem["description"]) @@ -2797,7 +2796,7 @@ class ManyToManyTest(_fixtures.FixtureTest): session = fixture_session() def fixture(): - _kw = dict([(k.name, k) for k in session.query(Keyword)]) + _kw = {k.name: k for k in session.query(Keyword)} for n in ( "big", "green", @@ -3232,7 +3231,7 @@ class RowSwitchTest(fixtures.MappedTest): t5t7.select(), ) ), - set([(1, 1), (1, 2)]), + {(1, 1), (1, 2)}, ) eq_( list( @@ -3513,7 +3512,7 @@ class NoRowInsertedTest(fixtures.TestBase): @testing.fixture def null_server_default_fixture(self, registry, connection): @registry.mapped - class MyClass(object): + class MyClass: __tablename__ = "my_table" id = Column(Integer, primary_key=True) @@ -3676,7 +3675,7 @@ class EnsurePKSortableTest(fixtures.MappedTest): class T3(cls.Basic): def __str__(self): - return "T3(id={})".format(self.id) + return f"T3(id={self.id})" @classmethod def setup_mappers(cls): diff --git a/test/orm/test_validators.py b/test/orm/test_validators.py index adfb6cb74d..659ee3ca0c 100644 --- a/test/orm/test_validators.py +++ b/test/orm/test_validators.py @@ -106,7 +106,7 @@ class ValidatorTest(_fixtures.FixtureTest): self.mapper_registry.map_imperatively(Address, addresses) eq_( - dict((k, v[0].__name__) for k, v in list(u_m.validators.items())), + {k: v[0].__name__ for k, v in list(u_m.validators.items())}, {"name": "validate_name", "addresses": "validate_address"}, ) diff --git a/test/perf/orm2010.py b/test/perf/orm2010.py index de467cfd73..61b4e9b89c 100644 --- a/test/perf/orm2010.py +++ b/test/perf/orm2010.py @@ -143,9 +143,7 @@ def run_with_profile(runsnake=False, dump=False): ) stats = pstats.Stats(filename) - counts_by_methname = dict( - (key[2], stats.stats[key][0]) for key in stats.stats - ) + counts_by_methname = {key[2]: stats.stats[key][0] for key in stats.stats} print("SQLA Version: %s" % __version__) print("Total calls %d" % stats.total_calls) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 30ca5c5699..f18c79c7b2 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1261,7 +1261,7 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): also included in the fixtures above. """ - need = set( + need = { cls for cls in class_hierarchy(ClauseElement) if issubclass(cls, (ColumnElement, Selectable, LambdaElement)) @@ -1275,7 +1275,7 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): and "compiler" not in cls.__module__ and "crud" not in cls.__module__ and "dialects" not in cls.__module__ # TODO: dialects? - ).difference({ColumnElement, UnaryExpression}) + }.difference({ColumnElement, UnaryExpression}) for fixture in self.fixtures + self.dont_compare_values_fixtures: case_a = fixture() diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 4eea117957..e5a149c49e 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -1,5 +1,3 @@ -#! coding:utf-8 - """ compiler tests. @@ -2099,10 +2097,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_order_by_clause(self): class CustomCompiler(PGCompiler): def order_by_clause(self, select, **kw): - return ( - super(CustomCompiler, self).order_by_clause(select, **kw) - + " CUSTOMIZED" - ) + return super().order_by_clause(select, **kw) + " CUSTOMIZED" class CustomDialect(PGDialect): name = "custom" @@ -2119,10 +2114,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): def test_custom_group_by_clause(self): class CustomCompiler(PGCompiler): def group_by_clause(self, select, **kw): - return ( - super(CustomCompiler, self).group_by_clause(select, **kw) - + " CUSTOMIZED" - ) + return super().group_by_clause(select, **kw) + " CUSTOMIZED" class CustomDialect(PGDialect): name = "custom" @@ -3777,9 +3769,7 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): class MyCompiler(compiler.SQLCompiler): def bindparam_string(self, name, **kw): kw["escaped_from"] = name - return super(MyCompiler, self).bindparam_string( - '"%s"' % name, **kw - ) + return super().bindparam_string('"%s"' % name, **kw) dialect = default.DefaultDialect() dialect.statement_compiler = MyCompiler @@ -3863,7 +3853,7 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): total_params = 100000 in_clause = [":in%d" % i for i in range(total_params)] - params = dict(("in%d" % i, i) for i in range(total_params)) + params = {"in%d" % i: i for i in range(total_params)} t = text("text clause %s" % ", ".join(in_clause)) eq_(len(t.bindparams), total_params) c = t.compile() @@ -6590,7 +6580,7 @@ class ResultMapTest(fixtures.TestBase): comp = stmt.compile() eq_( set(comp._create_result_map()), - set(["t1_1_b", "t1_1_a", "t1_a", "t1_b"]), + {"t1_1_b", "t1_1_a", "t1_a", "t1_b"}, ) is_(comp._create_result_map()["t1_a"][1][2], t1.c.a) @@ -6643,14 +6633,12 @@ class ResultMapTest(fixtures.TestBase): if stmt is stmt2.element: with self._nested_result() as nested: contexts[stmt2.element] = nested - text = super(MyCompiler, self).visit_select( + text = super().visit_select( stmt2.element, ) self._add_to_result_map("k1", "k1", (1, 2, 3), int_) else: - text = super(MyCompiler, self).visit_select( - stmt, *arg, **kw - ) + text = super().visit_select(stmt, *arg, **kw) self._add_to_result_map("k2", "k2", (3, 4, 5), int_) return text diff --git a/test/sql/test_computed.py b/test/sql/test_computed.py index 886aa13b99..8fcfeff733 100644 --- a/test/sql/test_computed.py +++ b/test/sql/test_computed.py @@ -1,4 +1,3 @@ -# coding: utf-8 from sqlalchemy import Column from sqlalchemy import Computed from sqlalchemy import Integer diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index b1b731d66f..dbdab33077 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -705,15 +705,13 @@ class ConstraintGenTest(fixtures.TestBase, AssertsExecutionResults): Index("idx_winners", events.c.winner) eq_( - set(ix.name for ix in events.indexes), - set( - [ - "ix_events_name", - "ix_events_location", - "sport_announcer", - "idx_winners", - ] - ), + {ix.name for ix in events.indexes}, + { + "ix_events_name", + "ix_events_location", + "sport_announcer", + "idx_winners", + }, ) self.assert_sql_execution( diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 1249529f38..cc7daf4017 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -552,14 +552,14 @@ class DefaultRoundTripTest(fixtures.TablesTest): assert r.lastrow_has_defaults() eq_( set(r.context.postfetch_cols), - set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]), + {t.c.col3, t.c.col5, t.c.col4, t.c.col6}, ) r = connection.execute(t.insert().inline()) assert r.lastrow_has_defaults() eq_( set(r.context.postfetch_cols), - set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]), + {t.c.col3, t.c.col5, t.c.col4, t.c.col6}, ) connection.execute(t.insert()) @@ -599,7 +599,7 @@ class DefaultRoundTripTest(fixtures.TablesTest): eq_( set(r.context.postfetch_cols), - set([t.c.col3, t.c.col5, t.c.col4, t.c.col6]), + {t.c.col3, t.c.col5, t.c.col4, t.c.col6}, ) eq_( diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py index f98e7297d6..5b7e5ebbe3 100644 --- a/test/sql/test_delete.py +++ b/test/sql/test_delete.py @@ -1,5 +1,3 @@ -#! coding:utf-8 - from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import exc diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index ae34b0c0fa..fdfb87f724 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -1,5 +1,3 @@ -#! coding: utf-8 - from sqlalchemy import alias from sqlalchemy import and_ from sqlalchemy import bindparam diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 5e46808b3a..158707c6ae 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -344,7 +344,7 @@ class TraversalTest( foo, bar = CustomObj("foo", String), CustomObj("bar", String) bin_ = foo == bar set(ClauseVisitor().iterate(bin_)) - assert set(ClauseVisitor().iterate(bin_)) == set([foo, bar, bin_]) + assert set(ClauseVisitor().iterate(bin_)) == {foo, bar, bin_} class BinaryEndpointTraversalTest(fixtures.TestBase): @@ -726,7 +726,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): assert c1 == str(clause) assert str(clause2) == c1 + " SOME MODIFIER=:lala" assert list(clause._bindparams.keys()) == ["bar"] - assert set(clause2._bindparams.keys()) == set(["bar", "lala"]) + assert set(clause2._bindparams.keys()) == {"bar", "lala"} def test_select(self): s2 = select(t1) @@ -2209,8 +2209,8 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): e = sql_util.ClauseAdapter( b, - include_fn=lambda x: x in set([a.c.id]), - equivalents={a.c.id: set([a.c.id])}, + include_fn=lambda x: x in {a.c.id}, + equivalents={a.c.id: {a.c.id}}, ).traverse(e) assert str(e) == "a_1.id = a.xxx_id" @@ -2225,7 +2225,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # asking for a nonexistent col. corresponding_column should prevent # endless depth. adapt = sql_util.ClauseAdapter( - b, equivalents={a.c.x: set([c.c.x]), c.c.x: set([a.c.x])} + b, equivalents={a.c.x: {c.c.x}, c.c.x: {a.c.x}} ) assert adapt._corresponding_column(a.c.x, False) is None @@ -2240,7 +2240,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): # two levels of indirection from c.x->b.x->a.x, requires recursive # corresponding_column call adapt = sql_util.ClauseAdapter( - alias, equivalents={b.c.x: set([a.c.x]), c.c.x: set([b.c.x])} + alias, equivalents={b.c.x: {a.c.x}, c.c.x: {b.c.x}} ) assert adapt._corresponding_column(a.c.x, False) is alias.c.x assert adapt._corresponding_column(c.c.x, False) is alias.c.x @@ -2535,9 +2535,9 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL): def _table(name): return table(name, column("col1"), column("col2"), column("col3")) - table1, table2, table3, table4 = [ + table1, table2, table3, table4 = ( _table(name) for name in ("table1", "table2", "table3", "table4") - ] + ) def test_splice(self): t1, t2, t3, t4 = table1, table2, table1.alias(), table2.alias() diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index f92dd8496b..c97c136249 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -388,7 +388,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): def __init__(self, *args): args = args + (3,) - super(MyFunction, self).__init__(*args) + super().__init__(*args) self.assert_compile( func.my_func(1, 2), "my_func(:my_func_1, :my_func_2, :my_func_3)" diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 395fe16d3e..ac9ac4022b 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -1,4 +1,3 @@ -#! coding:utf-8 from __future__ import annotations from typing import Tuple @@ -1609,14 +1608,12 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): stmt = table.insert().values(values) eq_( - dict( - [ - (k, v.type._type_affinity) - for (k, v) in stmt.compile( - dialect=postgresql.dialect() - ).binds.items() - ] - ), + { + k: v.type._type_affinity + for (k, v) in stmt.compile( + dialect=postgresql.dialect() + ).binds.items() + }, { "foo": Integer, "data_m2": String, @@ -1757,14 +1754,12 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): stmt = table.insert().values(values) eq_( - dict( - [ - (k, v.type._type_affinity) - for (k, v) in stmt.compile( - dialect=postgresql.dialect() - ).binds.items() - ] - ), + { + k: v.type._type_affinity + for (k, v) in stmt.compile( + dialect=postgresql.dialect() + ).binds.items() + }, { "foo": Integer, "data_m2": String, diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index a74c5811c3..40ae2a65c8 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -795,7 +795,7 @@ class LabelLengthTest(fixtures.TestBase, AssertsCompiledSQL): compiled = stmt.compile(dialect=dialect) eq_( set(compiled._create_result_map()), - set(["tablename_columnn_1", "tablename_columnn_2"]), + {"tablename_columnn_1", "tablename_columnn_2"}, ) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 6d93cb234e..af8c15f98f 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -124,10 +124,10 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): class MyColumn(schema.Column): def __init__(self, *args, **kw): self.widget = kw.pop("widget", None) - super(MyColumn, self).__init__(*args, **kw) + super().__init__(*args, **kw) def _copy(self, *arg, **kw): - c = super(MyColumn, self)._copy(*arg, **kw) + c = super()._copy(*arg, **kw) c.widget = self.widget return c @@ -160,7 +160,7 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): Table("t2", metadata, Column("x", Integer), schema="bar") Table("t3", metadata, Column("x", Integer)) - eq_(metadata._schemas, set(["foo", "bar"])) + eq_(metadata._schemas, {"foo", "bar"}) eq_(len(metadata.tables), 3) def test_schema_collection_remove(self): @@ -171,11 +171,11 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): t3 = Table("t3", metadata, Column("x", Integer), schema="bar") metadata.remove(t3) - eq_(metadata._schemas, set(["foo", "bar"])) + eq_(metadata._schemas, {"foo", "bar"}) eq_(len(metadata.tables), 2) metadata.remove(t1) - eq_(metadata._schemas, set(["bar"])) + eq_(metadata._schemas, {"bar"}) eq_(len(metadata.tables), 1) def test_schema_collection_remove_all(self): @@ -1778,15 +1778,15 @@ class TableTest(fixtures.TestBase, AssertsCompiledSQL): fk3 = ForeignKeyConstraint(["b", "c"], ["r.x", "r.y"]) t1.append_column(Column("b", Integer, fk1)) - eq_(t1.foreign_key_constraints, set([fk1.constraint])) + eq_(t1.foreign_key_constraints, {fk1.constraint}) t1.append_column(Column("c", Integer, fk2)) - eq_(t1.foreign_key_constraints, set([fk1.constraint, fk2.constraint])) + eq_(t1.foreign_key_constraints, {fk1.constraint, fk2.constraint}) t1.append_constraint(fk3) eq_( t1.foreign_key_constraints, - set([fk1.constraint, fk2.constraint, fk3]), + {fk1.constraint, fk2.constraint, fk3}, ) def test_c_immutable(self): @@ -2167,20 +2167,16 @@ class SchemaTypeTest(fixtures.TestBase): evt_targets = () def _set_table(self, column, table): - super(SchemaTypeTest.TrackEvents, self)._set_table(column, table) + super()._set_table(column, table) self.column = column self.table = table def _on_table_create(self, target, bind, **kw): - super(SchemaTypeTest.TrackEvents, self)._on_table_create( - target, bind, **kw - ) + super()._on_table_create(target, bind, **kw) self.evt_targets += (target,) def _on_metadata_create(self, target, bind, **kw): - super(SchemaTypeTest.TrackEvents, self)._on_metadata_create( - target, bind, **kw - ) + super()._on_metadata_create(target, bind, **kw) self.evt_targets += (target,) # TODO: Enum and Boolean put TypeEngine first. Changing that here @@ -2951,7 +2947,7 @@ class ConstraintTest(fixtures.TestBase): return t1, t2, t3 def _assert_index_col_x(self, t, i, columns=True): - eq_(t.indexes, set([i])) + eq_(t.indexes, {i}) if columns: eq_(list(i.columns), [t.c.x]) else: @@ -3075,7 +3071,7 @@ class ConstraintTest(fixtures.TestBase): idx = Index("bar", MyThing(), t.c.y) - eq_(set(t.indexes), set([idx])) + eq_(set(t.indexes), {idx}) def test_clauseelement_extraction_three(self): t = Table("t", MetaData(), Column("x", Integer), Column("y", Integer)) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 79ca00e143..e00cacad89 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -702,7 +702,7 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): - super(MyInteger.comparator_factory, self).__init__(expr) + super().__init__(expr) def __add__(self, other): return self.expr.op("goofy")(other) @@ -721,7 +721,7 @@ class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase): class comparator_factory(TypeDecorator.Comparator): def __init__(self, expr): - super(MyInteger.comparator_factory, self).__init__(expr) + super().__init__(expr) def __add__(self, other): return self.expr.op("goofy")(other) @@ -742,7 +742,7 @@ class TypeDecoratorTypeDecoratorComparatorTest( class comparator_factory(TypeDecorator.Comparator): def __init__(self, expr): - super(MyIntegerOne.comparator_factory, self).__init__(expr) + super().__init__(expr) def __add__(self, other): return self.expr.op("goofy")(other) @@ -764,9 +764,7 @@ class TypeDecoratorWVariantComparatorTest( class SomeOtherInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): - super(SomeOtherInteger.comparator_factory, self).__init__( - expr - ) + super().__init__(expr) def __add__(self, other): return self.expr.op("not goofy")(other) @@ -780,7 +778,7 @@ class TypeDecoratorWVariantComparatorTest( class comparator_factory(TypeDecorator.Comparator): def __init__(self, expr): - super(MyInteger.comparator_factory, self).__init__(expr) + super().__init__(expr) def __add__(self, other): return self.expr.op("goofy")(other) @@ -798,7 +796,7 @@ class CustomEmbeddedinTypeDecoratorTest( class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): - super(MyInteger.comparator_factory, self).__init__(expr) + super().__init__(expr) def __add__(self, other): return self.expr.op("goofy")(other) @@ -818,7 +816,7 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): class MyInteger(Integer): class comparator_factory(TypeEngine.Comparator): def __init__(self, expr): - super(MyInteger.comparator_factory, self).__init__(expr) + super().__init__(expr) def foob(self, other): return self.expr.op("foob")(other) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index ef94cc089e..54943897e1 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1586,10 +1586,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(t1.c.name == "t1 #10") .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1599,10 +1597,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(t1.c.t1_id < 12) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1617,10 +1613,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(t2.c.name == "t2 #20") .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1630,10 +1624,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(t2.c.t2_id < 29) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1648,10 +1640,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(t3.c.name == "t3 #30") .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1661,10 +1651,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(t3.c.t3_id < 39) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1692,10 +1680,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(and_(t1.c.t1_id < 19, t3.c.t3_id < 39)) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1711,10 +1697,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(and_(t1.c.name == "t1 #10", t2.c.name == "t2 #20")) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1724,10 +1708,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(and_(t1.c.t1_id < 12, t2.c.t2_id < 39)) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1748,10 +1730,8 @@ class JoinTest(fixtures.TablesTest): ) ) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1761,10 +1741,8 @@ class JoinTest(fixtures.TablesTest): select(t1.c.t1_id, t2.c.t2_id, t3.c.t3_id) .where(and_(t1.c.t1_id < 19, t2.c.t2_id < 29, t3.c.t3_id < 39)) .select_from( - ( - t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( - t3, criteria - ) + t1.outerjoin(t2, t1.c.t1_id == t2.c.t1_id).outerjoin( + t3, criteria ) ) ) @@ -1791,7 +1769,7 @@ class JoinTest(fixtures.TablesTest): .where( t1.c.name == "t1 #10", ) - .select_from((t1.join(t2).outerjoin(t3, criteria))) + .select_from(t1.join(t2).outerjoin(t3, criteria)) ) self.assertRows(expr, [(10, 20, 30)]) @@ -1800,7 +1778,7 @@ class JoinTest(fixtures.TablesTest): .where( t2.c.name == "t2 #20", ) - .select_from((t1.join(t2).outerjoin(t3, criteria))) + .select_from(t1.join(t2).outerjoin(t3, criteria)) ) self.assertRows(expr, [(10, 20, 30)]) @@ -1809,7 +1787,7 @@ class JoinTest(fixtures.TablesTest): .where( t3.c.name == "t3 #30", ) - .select_from((t1.join(t2).outerjoin(t3, criteria))) + .select_from(t1.join(t2).outerjoin(t3, criteria)) ) self.assertRows(expr, [(10, 20, 30)]) @@ -1818,7 +1796,7 @@ class JoinTest(fixtures.TablesTest): .where( and_(t1.c.name == "t1 #10", t2.c.name == "t2 #20"), ) - .select_from((t1.join(t2).outerjoin(t3, criteria))) + .select_from(t1.join(t2).outerjoin(t3, criteria)) ) self.assertRows(expr, [(10, 20, 30)]) @@ -1827,7 +1805,7 @@ class JoinTest(fixtures.TablesTest): .where( and_(t2.c.name == "t2 #20", t3.c.name == "t3 #30"), ) - .select_from((t1.join(t2).outerjoin(t3, criteria))) + .select_from(t1.join(t2).outerjoin(t3, criteria)) ) self.assertRows(expr, [(10, 20, 30)]) @@ -1840,7 +1818,7 @@ class JoinTest(fixtures.TablesTest): t3.c.name == "t3 #30", ), ) - .select_from((t1.join(t2).outerjoin(t3, criteria))) + .select_from(t1.join(t2).outerjoin(t3, criteria)) ) self.assertRows(expr, [(10, 20, 30)]) diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 7d90bc67b1..62ec007503 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -1,5 +1,3 @@ -#!coding: utf-8 - from sqlalchemy import CheckConstraint from sqlalchemy import Column from sqlalchemy import column @@ -886,9 +884,7 @@ class PreparerTest(fixtures.TestBase): def test_unformat_custom(self): class Custom(compiler.IdentifierPreparer): def __init__(self, dialect): - super(Custom, self).__init__( - dialect, initial_quote="`", final_quote="`" - ) + super().__init__(dialect, initial_quote="`", final_quote="`") def _escape_identifier(self, value): return value.replace("`", "``") @@ -1003,13 +999,13 @@ class QuotedIdentTest(fixtures.TestBase): def test_apply_map_quoted(self): q1 = _anonymous_label(quoted_name("x%s", True)) - q2 = q1.apply_map(("bar")) + q2 = q1.apply_map("bar") eq_(q2, "xbar") eq_(q2.quote, True) def test_apply_map_plain(self): q1 = _anonymous_label(quoted_name("x%s", None)) - q2 = q1.apply_map(("bar")) + q2 = q1.apply_map("bar") eq_(q2, "xbar") self._assert_not_quoted(q2) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index fa86d75ee8..b856acfd32 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1195,13 +1195,11 @@ class CursorResultTest(fixtures.TablesTest): row = result.first() eq_( - set( - [ - users.c.user_id in row._mapping, - addresses.c.user_id in row._mapping, - ] - ), - set([True]), + { + users.c.user_id in row._mapping, + addresses.c.user_id in row._mapping, + }, + {True}, ) @testing.combinations( @@ -3357,7 +3355,7 @@ class AlternateCursorResultTest(fixtures.TablesTest): def test_handle_error_in_fetch(self, strategy_cls, method_name): class cursor: def raise_(self): - raise IOError("random non-DBAPI error during cursor operation") + raise OSError("random non-DBAPI error during cursor operation") def fetchone(self): self.raise_() @@ -3390,7 +3388,7 @@ class AlternateCursorResultTest(fixtures.TablesTest): def test_buffered_row_close_error_during_fetchone(self): def raise_(**kw): - raise IOError("random non-DBAPI error during cursor operation") + raise OSError("random non-DBAPI error during cursor operation") with self._proxy_fixture(_cursor.BufferedRowCursorFetchStrategy): with self.engine.connect() as conn: diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index a467bb0a3b..baa8d89611 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -269,15 +269,11 @@ class SelectableTest( eq_( s1.selected_columns.foo.proxy_set, - set( - [s1.selected_columns.foo, scalar_select, scalar_select.element] - ), + {s1.selected_columns.foo, scalar_select, scalar_select.element}, ) eq_( s2.selected_columns.foo.proxy_set, - set( - [s2.selected_columns.foo, scalar_select, scalar_select.element] - ), + {s2.selected_columns.foo, scalar_select, scalar_select.element}, ) assert ( @@ -295,11 +291,11 @@ class SelectableTest( eq_( s1.c.foo.proxy_set, - set([s1.c.foo, scalar_select, scalar_select.element]), + {s1.c.foo, scalar_select, scalar_select.element}, ) eq_( s2.c.foo.proxy_set, - set([s2.c.foo, scalar_select, scalar_select.element]), + {s2.c.foo, scalar_select, scalar_select.element}, ) assert s1.corresponding_column(scalar_select) is s1.c.foo @@ -616,7 +612,7 @@ class SelectableTest( s2c1 = s2._clone() s3c1 = s3._clone() - eq_(base._cloned_intersection([s1c1, s3c1], [s2c1, s1c2]), set([s1c1])) + eq_(base._cloned_intersection([s1c1, s3c1], [s2c1, s1c2]), {s1c1}) def test_cloned_difference(self): t1 = table("t1", column("x")) @@ -633,7 +629,7 @@ class SelectableTest( eq_( base._cloned_difference([s1c1, s2c1, s3c1], [s2c1, s1c2]), - set([s3c1]), + {s3c1}, ) def test_distance_on_aliases(self): @@ -1940,13 +1936,13 @@ class RefreshForNewColTest(fixtures.TestBase): q = Column("q", Integer) a.append_column(q) a._refresh_for_new_column(q) - eq_(a.foreign_keys, set([fk])) + eq_(a.foreign_keys, {fk}) fk2 = ForeignKey("g.id") p = Column("p", Integer, fk2) a.append_column(p) a._refresh_for_new_column(p) - eq_(a.foreign_keys, set([fk, fk2])) + eq_(a.foreign_keys, {fk, fk2}) def test_fk_join(self): m = MetaData() @@ -1960,13 +1956,13 @@ class RefreshForNewColTest(fixtures.TestBase): q = Column("q", Integer) b.append_column(q) j._refresh_for_new_column(q) - eq_(j.foreign_keys, set([fk])) + eq_(j.foreign_keys, {fk}) fk2 = ForeignKey("g.id") p = Column("p", Integer, fk2) b.append_column(p) j._refresh_for_new_column(p) - eq_(j.foreign_keys, set([fk, fk2])) + eq_(j.foreign_keys, {fk, fk2}) class AnonLabelTest(fixtures.TestBase): @@ -2641,10 +2637,10 @@ class ReduceTest(fixtures.TestBase, AssertsExecutionResults): ) s1 = select(t1, t2) s2 = s1.reduce_columns(only_synonyms=False) - eq_(set(s2.selected_columns), set([t1.c.x, t1.c.y, t2.c.q])) + eq_(set(s2.selected_columns), {t1.c.x, t1.c.y, t2.c.q}) s2 = s1.reduce_columns() - eq_(set(s2.selected_columns), set([t1.c.x, t1.c.y, t2.c.z, t2.c.q])) + eq_(set(s2.selected_columns), {t1.c.x, t1.c.y, t2.c.z, t2.c.q}) def test_reduce_only_synonym_fk(self): m = MetaData() @@ -2664,13 +2660,11 @@ class ReduceTest(fixtures.TestBase, AssertsExecutionResults): s1 = s1.reduce_columns(only_synonyms=True) eq_( set(s1.selected_columns), - set( - [ - s1.selected_columns.x, - s1.selected_columns.y, - s1.selected_columns.q, - ] - ), + { + s1.selected_columns.x, + s1.selected_columns.y, + s1.selected_columns.q, + }, ) def test_reduce_only_synonym_lineage(self): @@ -2688,7 +2682,7 @@ class ReduceTest(fixtures.TestBase, AssertsExecutionResults): s2 = select(t1, s1).where(t1.c.x == s1.c.x).where(s1.c.y == t1.c.z) eq_( set(s2.reduce_columns().selected_columns), - set([t1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z]), + {t1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z}, ) # reverse order, s1.c.x wins @@ -2696,7 +2690,7 @@ class ReduceTest(fixtures.TestBase, AssertsExecutionResults): s2 = select(s1, t1).where(t1.c.x == s1.c.x).where(s1.c.y == t1.c.z) eq_( set(s2.reduce_columns().selected_columns), - set([s1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z]), + {s1.c.x, t1.c.y, t1.c.z, s1.c.y, s1.c.z}, ) def test_reduce_aliased_join(self): @@ -2994,7 +2988,7 @@ class AnnotationsTest(fixtures.TestBase): for obj in [t, t.c.x, a, t.c.x > 1, (t.c.x > 1).label(None)]: annot = obj._annotate({}) - eq_(set([obj]), set([annot])) + eq_({obj}, {annot}) def test_clone_annotations_dont_hash(self): t = table("t", column("x")) @@ -3005,7 +2999,7 @@ class AnnotationsTest(fixtures.TestBase): for obj in [s, s2]: annot = obj._annotate({}) - ne_(set([obj]), set([annot])) + ne_({obj}, {annot}) def test_replacement_traverse_preserve(self): """test that replacement traverse that hits an unannotated column @@ -3802,11 +3796,11 @@ class ResultMapTest(fixtures.TestBase): def _mapping(self, stmt): compiled = stmt.compile() - return dict( - (elem, key) + return { + elem: key for key, elements in compiled._create_result_map().items() for elem in elements[1] - ) + } def test_select_label_alt_name(self): t = self._fixture() diff --git a/test/sql/test_text.py b/test/sql/test_text.py index 805eacec63..b5d9ae7407 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -374,7 +374,7 @@ class BindParamTest(fixtures.TestBase, AssertsCompiledSQL): ) def _assert_type_map(self, t, compare): - map_ = dict((b.key, b.type) for b in t._bindparams.values()) + map_ = {b.key: b.type for b in t._bindparams.values()} for k in compare: assert compare[k]._type_affinity is map_[k]._type_affinity @@ -642,11 +642,11 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): def _mapping(self, stmt): compiled = stmt.compile() - return dict( - (elem, key) + return { + elem: key for key, elements in compiled._create_result_map().items() for elem in elements[1] - ) + } def test_select_label_alt_name(self): t = self._xy_table_fixture() @@ -815,7 +815,7 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): t = t.bindparams(bar=String) t = t.bindparams(bindparam("bat", value="bat")) - eq_(set(t.element._bindparams), set(["bat", "foo", "bar"])) + eq_(set(t.element._bindparams), {"bat", "foo", "bar"}) class TextErrorsTest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 3b1df34987..d1b32186e9 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1,4 +1,3 @@ -# coding: utf-8 import datetime import decimal import importlib @@ -511,9 +510,9 @@ class _UserDefinedTypeFixture: cache_ok = True def bind_processor(self, dialect): - impl_processor = super(MyDecoratedType, self).bind_processor( - dialect - ) or (lambda value: value) + impl_processor = super().bind_processor(dialect) or ( + lambda value: value + ) def process(value): if value is None: @@ -523,7 +522,7 @@ class _UserDefinedTypeFixture: return process def result_processor(self, dialect, coltype): - impl_processor = super(MyDecoratedType, self).result_processor( + impl_processor = super().result_processor( dialect, coltype ) or (lambda value: value) @@ -577,9 +576,9 @@ class _UserDefinedTypeFixture: cache_ok = True def bind_processor(self, dialect): - impl_processor = super(MyUnicodeType, self).bind_processor( - dialect - ) or (lambda value: value) + impl_processor = super().bind_processor(dialect) or ( + lambda value: value + ) def process(value): if value is None: @@ -590,7 +589,7 @@ class _UserDefinedTypeFixture: return process def result_processor(self, dialect, coltype): - impl_processor = super(MyUnicodeType, self).result_processor( + impl_processor = super().result_processor( dialect, coltype ) or (lambda value: value) @@ -1070,7 +1069,7 @@ class UserDefinedTest( if dialect.name == "sqlite": return String(50) else: - return super(MyType, self).load_dialect_impl(dialect) + return super().load_dialect_impl(dialect) sl = dialects.sqlite.dialect() pg = dialects.postgresql.dialect() @@ -1143,7 +1142,7 @@ class UserDefinedTest( def test_user_defined_dialect_specific_args(self): class MyType(types.UserDefinedType): def __init__(self, foo="foo", **kwargs): - super(MyType, self).__init__() + super().__init__() self.foo = foo self.dialect_specific_args = kwargs diff --git a/test/sql/test_update.py b/test/sql/test_update.py index e93900bbda..cd7f992e22 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -1525,7 +1525,7 @@ class UpdateFromMultiTableUpdateDefaultsTest( .where(users.c.name == "ed") ) - eq_(set(ret.prefetch_cols()), set([users.c.some_update])) + eq_(set(ret.prefetch_cols()), {users.c.some_update}) expected = [ (2, 8, "updated"), @@ -1552,7 +1552,7 @@ class UpdateFromMultiTableUpdateDefaultsTest( eq_( set(ret.prefetch_cols()), - set([users.c.some_update, foobar.c.some_update]), + {users.c.some_update, foobar.c.some_update}, ) expected = [ diff --git a/tools/format_docs_code.py b/tools/format_docs_code.py index 05e5e01f10..8f0b6e54d6 100644 --- a/tools/format_docs_code.py +++ b/tools/format_docs_code.py @@ -372,11 +372,11 @@ Use --report-doctest to ignore errors on plain code blocks. config = parse_pyproject_toml(home / "pyproject.toml") BLACK_MODE = Mode( - target_versions=set( + target_versions={ TargetVersion[val.upper()] for val in config.get("target_version", []) if val != "py27" - ), + }, line_length=config.get("line_length", DEFAULT_LINE_LENGTH) if args.project_line_length else DEFAULT_LINE_LENGTH,