from .base import _from_objects
from .base import _NONE_NAME
from .base import _SentinelDefaultCharacterization
-from .base import Executable
from .base import NO_ARG
-from .elements import ClauseElement
from .elements import quoted_name
-from .schema import Column
from .sqltypes import TupleType
-from .type_api import TypeEngine
from .visitors import prefix_anon_map
-from .visitors import Visitable
from .. import exc
from .. import util
from ..util import FastIntFlag
from ..util.typing import Literal
+from ..util.typing import Self
from ..util.typing import TupleAny
from ..util.typing import Unpack
from .annotation import _AnnotationDict
from .base import _AmbiguousTableNameMap
from .base import CompileState
+ from .base import Executable
from .cache_key import CacheKey
from .ddl import ExecutableDDLElement
from .dml import Insert
+ from .dml import Update
from .dml import UpdateBase
+ from .dml import UpdateDMLState
from .dml import ValuesBase
from .elements import _truncated_label
+ from .elements import BinaryExpression
from .elements import BindParameter
+ from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import False_
from .elements import Label
+ from .elements import Null
+ from .elements import True_
from .functions import Function
+ from .schema import Column
+ from .schema import Constraint
+ from .schema import ForeignKeyConstraint
+ from .schema import Index
+ from .schema import PrimaryKeyConstraint
from .schema import Table
+ from .schema import UniqueConstraint
+ from .selectable import _ColumnsClauseElement
from .selectable import AliasedReturnsRows
from .selectable import CompoundSelectState
from .selectable import CTE
from .selectable import Select
from .selectable import SelectState
from .type_api import _BindProcessorType
+ from .type_api import TypeDecorator
+ from .type_api import TypeEngine
+ from .type_api import UserDefinedType
+ from .visitors import Visitable
from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _DBAPIAnyExecuteParams
from ..engine.interfaces import Dialect
from ..engine.interfaces import SchemaTranslateMapType
+
_FromHintsType = Dict["FromClause", str]
RESERVED_WORDS = {
self.string = self.process(self.statement, **compile_kwargs)
if render_schema_translate:
+ assert schema_translate_map is not None
self.string = self.preparer._render_schema_translates(
self.string, schema_translate_map
)
raise exc.UnsupportedCompilationError(self, type(element)) from err
@property
- def sql_compiler(self):
+ def sql_compiler(self) -> SQLCompiler:
"""Return a Compiled that is capable of processing SQL expressions.
If this compiler is one, it would likely just return 'self'.
return len(self.stack) > 1
@property
- def sql_compiler(self):
+ def sql_compiler(self) -> Self:
return self
def construct_expanded_state(
return get
- def default_from(self):
+ def default_from(self) -> str:
"""Called when a SELECT statement has no froms, and no FROM clause is
to be appended.
return text
- def visit_null(self, expr, **kw):
+ def visit_null(self, expr: Null, **kw: Any) -> str:
return "NULL"
- def visit_true(self, expr, **kw):
+ def visit_true(self, expr: True_, **kw: Any) -> str:
if self.dialect.supports_native_boolean:
return "true"
else:
return "1"
- def visit_false(self, expr, **kw):
+ def visit_false(self, expr: False_, **kw: Any) -> str:
if self.dialect.supports_native_boolean:
return "false"
else:
% self.dialect.name
)
- def function_argspec(self, func, **kwargs):
+ def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
return func.clause_expr._compiler_dispatch(self, **kwargs)
def visit_compound_select(
)
def _generate_generic_binary(
- self, binary, opstring, eager_grouping=False, **kw
- ):
+ self,
+ binary: BinaryExpression[Any],
+ opstring: str,
+ eager_grouping: bool = False,
+ **kw: Any,
+ ) -> str:
_in_operator_expression = kw.get("_in_operator_expression", False)
kw["_in_operator_expression"] = True
**kw,
)
- def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ def visit_regexp_match_op_binary(
+ self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
raise exc.CompileError(
"%s dialect does not support regular expressions"
% self.dialect.name
)
- def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ def visit_not_regexp_match_op_binary(
+ self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
raise exc.CompileError(
"%s dialect does not support regular expressions"
% self.dialect.name
)
- def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ def visit_regexp_replace_op_binary(
+ self, binary: BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
raise exc.CompileError(
"%s dialect does not support regular expression replacements"
% self.dialect.name
else:
return self.render_literal_value(value, bindparam.type)
- def render_literal_value(self, value, type_):
+ def render_literal_value(
+ self, value: Any, type_: sqltypes.TypeEngine[Any]
+ ) -> str:
"""Render the value of a bind parameter as a quoted literal.
This is used for statement sections that do not accept bind parameters
def get_select_hint_text(self, byfroms):
return None
- def get_from_hint_text(self, table, text):
+ def get_from_hint_text(
+ self, table: FromClause, text: Optional[str]
+ ) -> Optional[str]:
return None
def get_crud_hint_text(self, table, text):
else:
return "WITH"
- def get_select_precolumns(self, select, **kw):
+ def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
"""Called when building a ``SELECT`` statement, position is just
before column list.
def returning_clause(
self,
stmt: UpdateBase,
- returning_cols: Sequence[ColumnElement[Any]],
+ returning_cols: Sequence[_ColumnsClauseElement],
*,
populate_result_map: bool,
**kw: Any,
else:
return None
- def visit_update(self, update_stmt, visiting_cte=None, **kw):
- compile_state = update_stmt._compile_state_factory(
- update_stmt, self, **kw
+ def visit_update(
+ self,
+ update_stmt: Update,
+ visiting_cte: Optional[CTE] = None,
+ **kw: Any,
+ ) -> str:
+ compile_state = update_stmt._compile_state_factory( # type: ignore[call-arg] # noqa: E501
+ update_stmt, self, **kw # type: ignore[arg-type]
)
- update_stmt = compile_state.statement
+ if TYPE_CHECKING:
+ assert isinstance(compile_state, UpdateDMLState)
+ update_stmt = compile_state.statement # type: ignore[assignment]
if visiting_cte is not None:
kw["visiting_cte"] = visiting_cte
return text
def delete_extra_from_clause(
- self, update_stmt, from_table, extra_froms, from_hints, **kw
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
):
"""Provide a hook to override the generation of an
DELETE..FROM clause.
def returning_clause(
self,
stmt: UpdateBase,
- returning_cols: Sequence[ColumnElement[Any]],
+ returning_cols: Sequence[_ColumnsClauseElement],
*,
populate_result_map: bool,
**kw: Any,
)
def delete_extra_from_clause(
- self, update_stmt, from_table, extra_froms, from_hints, **kw
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
):
kw["asfrom"] = True
return ", " + ", ".join(
compile_kwargs: Mapping[str, Any] = ...,
): ...
- @util.memoized_property
- def sql_compiler(self):
+ @util.ro_memoized_property
+ def sql_compiler(self) -> SQLCompiler:
return self.dialect.statement_compiler(
self.dialect, None, schema_translate_map=self.schema_translate_map
)
def visit_drop_view(self, drop, **kw):
return "\nDROP VIEW " + self.preparer.format_table(drop.element)
- def _verify_index_table(self, index):
+ def _verify_index_table(self, index: Index) -> None:
if index.table is None:
raise exc.CompileError(
"Index '%s' is not associated with any table." % index.name
return text + self._prepared_index_name(index, include_schema=True)
- def _prepared_index_name(self, index, include_schema=False):
+ def _prepared_index_name(
+ self, index: Index, include_schema: bool = False
+ ) -> str:
if index.table is not None:
effective_schema = self.preparer.schema_for_object(index.table)
else:
def post_create_table(self, table):
return ""
- def get_column_default_string(self, column):
+ def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
if isinstance(column.server_default, schema.DefaultClause):
return self.render_default_string(column.server_default.arg)
else:
return None
- def render_default_string(self, default):
+ def render_default_string(self, default: Union[Visitable, str]) -> str:
if isinstance(default, str):
return self.sql_compiler.render_literal_value(
default, sqltypes.STRINGTYPE
text += self.define_constraint_deferrability(constraint)
return text
- def visit_primary_key_constraint(self, constraint, **kw):
+ def visit_primary_key_constraint(
+ self, constraint: PrimaryKeyConstraint, **kw: Any
+ ) -> str:
if len(constraint) == 0:
return ""
text = ""
return preparer.format_table(table)
- def visit_unique_constraint(self, constraint, **kw):
+ def visit_unique_constraint(
+ self, constraint: UniqueConstraint, **kw: Any
+ ) -> str:
if len(constraint) == 0:
return ""
text = ""
text += self.define_constraint_deferrability(constraint)
return text
- def define_unique_constraint_distinct(self, constraint, **kw):
+ def define_unique_constraint_distinct(
+ self, constraint: UniqueConstraint, **kw: Any
+ ) -> str:
return ""
- def define_constraint_cascades(self, constraint):
+ def define_constraint_cascades(
+ self, constraint: ForeignKeyConstraint
+ ) -> str:
text = ""
if constraint.ondelete is not None:
text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
)
return text
- def define_constraint_deferrability(self, constraint):
+ def define_constraint_deferrability(self, constraint: Constraint) -> str:
text = ""
if constraint.deferrable is not None:
if constraint.deferrable:
class GenericTypeCompiler(TypeCompiler):
- def visit_FLOAT(self, type_, **kw):
+ def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
return "FLOAT"
- def visit_DOUBLE(self, type_, **kw):
+ def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
return "DOUBLE"
- def visit_DOUBLE_PRECISION(self, type_, **kw):
+ def visit_DOUBLE_PRECISION(
+ self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
+ ) -> str:
return "DOUBLE PRECISION"
- def visit_REAL(self, type_, **kw):
+ def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
return "REAL"
- def visit_NUMERIC(self, type_, **kw):
+ def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
"scale": type_.scale,
}
- def visit_DECIMAL(self, type_, **kw):
+ def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
"scale": type_.scale,
}
- def visit_INTEGER(self, type_, **kw):
+ def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
return "INTEGER"
- def visit_SMALLINT(self, type_, **kw):
+ def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
return "SMALLINT"
- def visit_BIGINT(self, type_, **kw):
+ def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
return "BIGINT"
- def visit_TIMESTAMP(self, type_, **kw):
+ def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
return "TIMESTAMP"
- def visit_DATETIME(self, type_, **kw):
+ def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
return "DATETIME"
- def visit_DATE(self, type_, **kw):
+ def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
return "DATE"
- def visit_TIME(self, type_, **kw):
+ def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
return "TIME"
- def visit_CLOB(self, type_, **kw):
+ def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
return "CLOB"
- def visit_NCLOB(self, type_, **kw):
+ def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
return "NCLOB"
- def _render_string_type(self, type_, name, length_override=None):
+ def _render_string_type(
+ self, name: str, length: Optional[int], collation: Optional[str]
+ ) -> str:
text = name
- if length_override:
- text += "(%d)" % length_override
- elif type_.length:
- text += "(%d)" % type_.length
- if type_.collation:
- text += ' COLLATE "%s"' % type_.collation
+ if length:
+ text += f"({length})"
+ if collation:
+ text += f' COLLATE "{collation}"'
return text
- def visit_CHAR(self, type_, **kw):
- return self._render_string_type(type_, "CHAR")
+ def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
+ return self._render_string_type("CHAR", type_.length, type_.collation)
- def visit_NCHAR(self, type_, **kw):
- return self._render_string_type(type_, "NCHAR")
+ def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
+ return self._render_string_type("NCHAR", type_.length, type_.collation)
- def visit_VARCHAR(self, type_, **kw):
- return self._render_string_type(type_, "VARCHAR")
+ def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
+ return self._render_string_type(
+ "VARCHAR", type_.length, type_.collation
+ )
- def visit_NVARCHAR(self, type_, **kw):
- return self._render_string_type(type_, "NVARCHAR")
+ def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
+ return self._render_string_type(
+ "NVARCHAR", type_.length, type_.collation
+ )
- def visit_TEXT(self, type_, **kw):
- return self._render_string_type(type_, "TEXT")
+ def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
+ return self._render_string_type("TEXT", type_.length, type_.collation)
- def visit_UUID(self, type_, **kw):
+ def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
return "UUID"
- def visit_BLOB(self, type_, **kw):
+ def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
return "BLOB"
- def visit_BINARY(self, type_, **kw):
+ def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
return "BINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_VARBINARY(self, type_, **kw):
+ def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_BOOLEAN(self, type_, **kw):
+ def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
return "BOOLEAN"
- def visit_uuid(self, type_, **kw):
+ def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
if not type_.native_uuid or not self.dialect.supports_native_uuid:
- return self._render_string_type(type_, "CHAR", length_override=32)
+ return self._render_string_type("CHAR", length=32, collation=None)
else:
return self.visit_UUID(type_, **kw)
- def visit_large_binary(self, type_, **kw):
+ def visit_large_binary(
+ self, type_: sqltypes.LargeBinary, **kw: Any
+ ) -> str:
return self.visit_BLOB(type_, **kw)
- def visit_boolean(self, type_, **kw):
+ def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
return self.visit_BOOLEAN(type_, **kw)
- def visit_time(self, type_, **kw):
+ def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
return self.visit_TIME(type_, **kw)
- def visit_datetime(self, type_, **kw):
+ def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
return self.visit_DATETIME(type_, **kw)
- def visit_date(self, type_, **kw):
+ def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
return self.visit_DATE(type_, **kw)
- def visit_big_integer(self, type_, **kw):
+ def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
return self.visit_BIGINT(type_, **kw)
- def visit_small_integer(self, type_, **kw):
+ def visit_small_integer(
+ self, type_: sqltypes.SmallInteger, **kw: Any
+ ) -> str:
return self.visit_SMALLINT(type_, **kw)
- def visit_integer(self, type_, **kw):
+ def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
return self.visit_INTEGER(type_, **kw)
- def visit_real(self, type_, **kw):
+ def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
return self.visit_REAL(type_, **kw)
- def visit_float(self, type_, **kw):
+ def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
return self.visit_FLOAT(type_, **kw)
- def visit_double(self, type_, **kw):
+ def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
return self.visit_DOUBLE(type_, **kw)
- def visit_numeric(self, type_, **kw):
+ def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
return self.visit_NUMERIC(type_, **kw)
- def visit_string(self, type_, **kw):
+ def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
- def visit_unicode(self, type_, **kw):
+ def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
- def visit_text(self, type_, **kw):
+ def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
return self.visit_TEXT(type_, **kw)
- def visit_unicode_text(self, type_, **kw):
+ def visit_unicode_text(
+ self, type_: sqltypes.UnicodeText, **kw: Any
+ ) -> str:
return self.visit_TEXT(type_, **kw)
- def visit_enum(self, type_, **kw):
+ def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
"type on this Column?" % type_
)
- def visit_type_decorator(self, type_, **kw):
+ def visit_type_decorator(
+ self, type_: TypeDecorator[Any], **kw: Any
+ ) -> str:
return self.process(type_.type_engine(self.dialect), **kw)
- def visit_user_defined(self, type_, **kw):
+ def visit_user_defined(
+ self, type_: UserDefinedType[Any], **kw: Any
+ ) -> str:
return type_.get_col_spec(**kw)
def __init__(
self,
- dialect,
- initial_quote='"',
- final_quote=None,
- escape_quote='"',
- quote_case_sensitive_collations=True,
- omit_schema=False,
+ dialect: Dialect,
+ initial_quote: str = '"',
+ final_quote: Optional[str] = None,
+ escape_quote: str = '"',
+ quote_case_sensitive_collations: bool = True,
+ omit_schema: bool = False,
):
"""Construct a new ``IdentifierPreparer`` object.
prep._includes_none_schema_translate = includes_none
return prep
- def _render_schema_translates(self, statement, schema_translate_map):
+ def _render_schema_translates(
+ self, statement: str, schema_translate_map: SchemaTranslateMapType
+ ) -> str:
d = schema_translate_map
if None in d:
if not self._includes_none_schema_translate:
"schema_translate_map dictionaries."
)
- d["_none"] = d[None]
+ d["_none"] = d[None] # type: ignore[index]
def replace(m):
name = m.group(2)
else:
return collation_name
- def format_sequence(self, sequence, use_schema=True):
+ def format_sequence(
+ self, sequence: schema.Sequence, use_schema: bool = True
+ ) -> str:
name = self.quote(sequence.name)
effective_schema = self.schema_for_object(sequence)
return ident
@util.preload_module("sqlalchemy.sql.naming")
- def format_constraint(self, constraint, _alembic_quote=True):
+ def format_constraint(
+ self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
+ ) -> Optional[str]:
naming = util.preloaded.sql_naming
if constraint.name is _NONE_NAME:
else:
name = constraint.name
+ assert name is not None
if constraint.__visit_name__ == "index":
return self.truncate_and_render_index_name(
name, _alembic_quote=_alembic_quote
name, _alembic_quote=_alembic_quote
)
- def truncate_and_render_index_name(self, name, _alembic_quote=True):
+ def truncate_and_render_index_name(
+ self, name: str, _alembic_quote: bool = True
+ ) -> str:
# calculate these at format time so that ad-hoc changes
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
name, max_, _alembic_quote
)
- def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
+ def truncate_and_render_constraint_name(
+ self, name: str, _alembic_quote: bool = True
+ ) -> str:
# calculate these at format time so that ad-hoc changes
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
name, max_, _alembic_quote
)
- def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
+ def _truncate_and_render_maxlen_name(
+ self, name: str, max_: int, _alembic_quote: bool
+ ) -> str:
if isinstance(name, elements._truncated_label):
if len(name) > max_:
name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
else:
return self.quote(name)
- def format_index(self, index):
- return self.format_constraint(index)
+ def format_index(self, index: Index) -> str:
+ name = self.format_constraint(index)
+ assert name is not None
+ return name
- def format_table(self, table, use_schema=True, name=None):
+ def format_table(
+ self,
+ table: FromClause,
+ use_schema: bool = True,
+ name: Optional[str] = None,
+ ) -> str:
"""Prepare a quoted table and schema name."""
-
if name is None:
+ if TYPE_CHECKING:
+ assert isinstance(table, NamedFromClause)
name = table.name
result = self.quote(name)
def format_column(
self,
- column,
- use_table=False,
- name=None,
- table_name=None,
- use_schema=False,
- anon_map=None,
- ):
+ column: ColumnElement[Any],
+ use_table: bool = False,
+ name: Optional[str] = None,
+ table_name: Optional[str] = None,
+ use_schema: bool = False,
+ anon_map: Optional[Mapping[str, Any]] = None,
+ ) -> str:
"""Prepare a quoted column name."""
if name is None:
name = column.name
+ assert name is not None
if anon_map is not None and isinstance(
name, elements._truncated_label
)
return r
- def unformat_identifiers(self, identifiers):
+ def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
"""Unpack 'schema.table.column'-like strings into components."""
r = self._r_identifiers
import typing
from typing import Any
from typing import Callable
+from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Protocol
from typing import Sequence as typing_Sequence
from typing import Tuple
+from typing import TypeVar
+from typing import Union
from . import roles
from .base import _generative
from .compiler import Compiled
from .compiler import DDLCompiler
from .elements import BindParameter
+ from .schema import Column
from .schema import Constraint
from .schema import ForeignKeyConstraint
+ from .schema import Index
from .schema import SchemaItem
- from .schema import Sequence
+ from .schema import Sequence as Sequence # noqa: F401
from .schema import Table
from .selectable import TableClause
from ..engine.base import Connection
from ..engine.interfaces import Dialect
from ..engine.interfaces import SchemaTranslateMapType
+_SI = TypeVar("_SI", bound=Union["SchemaItem", str])
+
class BaseDDLElement(ClauseElement):
"""The root of DDL constructs, including those that are sub-elements
def __call__(
self,
ddl: BaseDDLElement,
- target: SchemaItem,
+ target: Union[SchemaItem, str],
bind: Optional[Connection],
tables: Optional[List[Table]] = None,
state: Optional[Any] = None,
def _should_execute(
self,
ddl: BaseDDLElement,
- target: SchemaItem,
+ target: Union[SchemaItem, str],
bind: Optional[Connection],
compiler: Optional[DDLCompiler] = None,
**kw: Any,
"""
_ddl_if: Optional[DDLIf] = None
- target: Optional[SchemaItem] = None
+ target: Union[SchemaItem, str, None] = None
def _execute_on_connection(
self, connection, distilled_params, execution_options
)
-class _CreateDropBase(ExecutableDDLElement):
+class _CreateDropBase(ExecutableDDLElement, Generic[_SI]):
"""Base class for DDL constructs that represent CREATE and DROP or
equivalents.
"""
- def __init__(
- self,
- element,
- ):
+ def __init__(self, element: _SI) -> None:
self.element = self.target = element
self._ddl_if = getattr(element, "_ddl_if", None)
@property
def stringify_dialect(self):
+ assert not isinstance(self.element, str)
return self.element.create_drop_stringify_dialect
def _create_rule_disable(self, compiler):
return False
-class _CreateBase(_CreateDropBase):
- def __init__(self, element, if_not_exists=False):
+class _CreateBase(_CreateDropBase[_SI]):
+ def __init__(self, element: _SI, if_not_exists: bool = False) -> None:
super().__init__(element)
self.if_not_exists = if_not_exists
-class _DropBase(_CreateDropBase):
- def __init__(self, element, if_exists=False):
+class _DropBase(_CreateDropBase[_SI]):
+ def __init__(self, element: _SI, if_exists: bool = False) -> None:
super().__init__(element)
self.if_exists = if_exists
-class CreateSchema(_CreateBase):
+class CreateSchema(_CreateBase[str]):
"""Represent a CREATE SCHEMA statement.
The argument here is the string name of the schema.
self,
name: str,
if_not_exists: bool = False,
- ):
+ ) -> None:
"""Create a new :class:`.CreateSchema` construct."""
super().__init__(element=name, if_not_exists=if_not_exists)
-class DropSchema(_DropBase):
+class DropSchema(_DropBase[str]):
"""Represent a DROP SCHEMA statement.
The argument here is the string name of the schema.
name: str,
cascade: bool = False,
if_exists: bool = False,
- ):
+ ) -> None:
"""Create a new :class:`.DropSchema` construct."""
super().__init__(element=name, if_exists=if_exists)
self.cascade = cascade
-class CreateTable(_CreateBase):
+class CreateTable(_CreateBase["Table"]):
"""Represent a CREATE TABLE statement."""
__visit_name__ = "create_table"
typing_Sequence[ForeignKeyConstraint]
] = None,
if_not_exists: bool = False,
- ):
+ ) -> None:
"""Create a :class:`.CreateTable` construct.
:param element: a :class:`_schema.Table` that's the subject
self.include_foreign_key_constraints = include_foreign_key_constraints
-class _DropView(_DropBase):
+class _DropView(_DropBase["Table"]):
"""Semi-public 'DROP VIEW' construct.
Used by the test suite for dialect-agnostic drops of views.
class CreateConstraint(BaseDDLElement):
- def __init__(self, element: Constraint):
+ element: Constraint
+
+ def __init__(self, element: Constraint) -> None:
self.element = element
__visit_name__ = "create_column"
- def __init__(self, element):
+ element: Column[Any]
+
+ def __init__(self, element: Column[Any]) -> None:
self.element = element
-class DropTable(_DropBase):
+class DropTable(_DropBase["Table"]):
"""Represent a DROP TABLE statement."""
__visit_name__ = "drop_table"
- def __init__(self, element: Table, if_exists: bool = False):
+ def __init__(self, element: Table, if_exists: bool = False) -> None:
"""Create a :class:`.DropTable` construct.
:param element: a :class:`_schema.Table` that's the subject
super().__init__(element, if_exists=if_exists)
-class CreateSequence(_CreateBase):
+class CreateSequence(_CreateBase["Sequence"]):
"""Represent a CREATE SEQUENCE statement."""
__visit_name__ = "create_sequence"
- def __init__(self, element: Sequence, if_not_exists: bool = False):
- super().__init__(element, if_not_exists=if_not_exists)
-
-class DropSequence(_DropBase):
+class DropSequence(_DropBase["Sequence"]):
"""Represent a DROP SEQUENCE statement."""
__visit_name__ = "drop_sequence"
- def __init__(self, element: Sequence, if_exists: bool = False):
- super().__init__(element, if_exists=if_exists)
-
-class CreateIndex(_CreateBase):
+class CreateIndex(_CreateBase["Index"]):
"""Represent a CREATE INDEX statement."""
__visit_name__ = "create_index"
- def __init__(self, element, if_not_exists=False):
+ def __init__(self, element: Index, if_not_exists: bool = False) -> None:
"""Create a :class:`.Createindex` construct.
:param element: a :class:`_schema.Index` that's the subject
super().__init__(element, if_not_exists=if_not_exists)
-class DropIndex(_DropBase):
+class DropIndex(_DropBase["Index"]):
"""Represent a DROP INDEX statement."""
__visit_name__ = "drop_index"
- def __init__(self, element, if_exists=False):
+ def __init__(self, element: Index, if_exists: bool = False) -> None:
"""Create a :class:`.DropIndex` construct.
:param element: a :class:`_schema.Index` that's the subject
super().__init__(element, if_exists=if_exists)
-class AddConstraint(_CreateBase):
+class AddConstraint(_CreateBase["Constraint"]):
"""Represent an ALTER TABLE ADD CONSTRAINT statement."""
__visit_name__ = "add_constraint"
element: Constraint,
*,
isolate_from_table: bool = True,
- ):
+ ) -> None:
"""Construct a new :class:`.AddConstraint` construct.
:param element: a :class:`.Constraint` object
)
-class DropConstraint(_DropBase):
+class DropConstraint(_DropBase["Constraint"]):
"""Represent an ALTER TABLE DROP CONSTRAINT statement."""
__visit_name__ = "drop_constraint"
if_exists: bool = False,
isolate_from_table: bool = True,
**kw: Any,
- ):
+ ) -> None:
"""Construct a new :class:`.DropConstraint` construct.
:param element: a :class:`.Constraint` object
)
-class SetTableComment(_CreateDropBase):
+class SetTableComment(_CreateDropBase["Table"]):
"""Represent a COMMENT ON TABLE IS statement."""
__visit_name__ = "set_table_comment"
-class DropTableComment(_CreateDropBase):
+class DropTableComment(_CreateDropBase["Table"]):
"""Represent a COMMENT ON TABLE '' statement.
Note this varies a lot across database backends.
__visit_name__ = "drop_table_comment"
-class SetColumnComment(_CreateDropBase):
+class SetColumnComment(_CreateDropBase["Column[Any]"]):
"""Represent a COMMENT ON COLUMN IS statement."""
__visit_name__ = "set_column_comment"
-class DropColumnComment(_CreateDropBase):
+class DropColumnComment(_CreateDropBase["Column[Any]"]):
"""Represent a COMMENT ON COLUMN IS NULL statement."""
__visit_name__ = "drop_column_comment"
-class SetConstraintComment(_CreateDropBase):
+class SetConstraintComment(_CreateDropBase["Constraint"]):
"""Represent a COMMENT ON CONSTRAINT IS statement."""
__visit_name__ = "set_constraint_comment"
-class DropConstraintComment(_CreateDropBase):
+class DropConstraintComment(_CreateDropBase["Constraint"]):
"""Represent a COMMENT ON CONSTRAINT IS NULL statement."""
__visit_name__ = "drop_constraint_comment"