from .selectable import SelectState
from .type_api import _BindProcessorType
from ..engine.cursor import CursorResultMetaData
- from ..engine.default import DefaultDialect
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _DBAPIAnyExecuteParams
from ..engine.interfaces import _DBAPIMultiExecuteParams
self.string = self.process(self.statement, **compile_kwargs)
if render_schema_translate:
+ assert schema_translate_map is not None
self.string = self.preparer._render_schema_translates(
- self.string, schema_translate_map # type: ignore[arg-type]
+ self.string, schema_translate_map
)
self.state = CompilerState.STRING_APPLIED
def get_select_hint_text(self, byfroms):
return None
- def get_from_hint_text(self, table: Any, text: str | None) -> str | None:
+ def get_from_hint_text(
+ self, table: FromClause, text: str | None
+ ) -> str | None:
return None
def get_crud_hint_text(self, table, text):
def visit_update(
self, update_stmt: "Update", visiting_cte: CTE | None = None, **kw: Any
) -> str:
- compile_state = update_stmt._compile_state_factory( # type: ignore
- update_stmt, self, **kw # type: ignore
+ compile_state = update_stmt._compile_state_factory( # type: ignore[call-arg] # NOQA: E501
+ update_stmt, self, **kw # type: ignore[arg-type]
)
compile_state = cast("UpdateDMLState", compile_state)
update_stmt = compile_state.statement # type: ignore[assignment]
): ...
@util.memoized_property
- def sql_compiler(self):
+ def sql_compiler(self) -> SQLCompiler:
return self.dialect.statement_compiler(
self.dialect, None, schema_translate_map=self.schema_translate_map
)
def render_default_string(self, default: Visitable | str) -> str:
if isinstance(default, str):
- return self.sql_compiler.render_literal_value( # type: ignore[no-any-return] # NOQA: E501
+ return self.sql_compiler.render_literal_value(
default, sqltypes.STRINGTYPE
)
else:
- return self.sql_compiler.process(default, literal_binds=True) # type: ignore[no-any-return] # NOQA: E501
+ return self.sql_compiler.process(default, literal_binds=True)
def visit_table_or_column_check_constraint(self, constraint, **kw):
if constraint.is_column_level:
class GenericTypeCompiler(TypeCompiler):
- def visit_FLOAT(
- self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
- ) -> str:
+ def visit_FLOAT(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "FLOAT"
- def visit_DOUBLE(
- self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
- ) -> str:
+ def visit_DOUBLE(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "DOUBLE"
def visit_DOUBLE_PRECISION(
self,
- type_: "sqltypes.DOUBLE_PRECISION[decimal.Decimal| float]",
+ type_: TypeEngine[Any],
**kw: Any,
) -> str:
return "DOUBLE PRECISION"
- def visit_REAL(
- self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
- ) -> str:
+ def visit_REAL(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "REAL"
- def visit_NUMERIC(
- self, type_: "sqltypes.Numeric[decimal.Decimal| float]", **kw: Any
- ) -> str:
+ def visit_NUMERIC(self, type_: "sqltypes.Numeric[Any]", **kw: Any) -> str:
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
"scale": type_.scale,
}
- def visit_INTEGER(self, type_: "sqltypes.Integer", **kw: Any) -> str:
+ def visit_INTEGER(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "INTEGER"
- def visit_SMALLINT(self, type_: "sqltypes.SmallInteger", **kw: Any) -> str:
+ def visit_SMALLINT(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "SMALLINT"
- def visit_BIGINT(self, type_: "sqltypes.BigInteger", **kw: Any) -> str:
+ def visit_BIGINT(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "BIGINT"
- def visit_TIMESTAMP(self, type_: "sqltypes.TIMESTAMP", **kw: Any) -> str:
+ def visit_TIMESTAMP(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "TIMESTAMP"
- def visit_DATETIME(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
+ def visit_DATETIME(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "DATETIME"
- def visit_DATE(self, type_: "sqltypes.Date", **kw: Any) -> str:
+ def visit_DATE(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "DATE"
- def visit_TIME(self, type_: "sqltypes.Time", **kw: Any) -> str:
+ def visit_TIME(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "TIME"
- def visit_CLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
+ def visit_CLOB(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "CLOB"
- def visit_NCLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
+ def visit_NCLOB(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "NCLOB"
def _render_string_type(
text += ' COLLATE "%s"' % type_.collation
return text
- def visit_CHAR(self, type_: "sqltypes.CHAR", **kw: Any) -> str:
+ def visit_CHAR(self, type_: sqltypes.String, **kw: Any) -> str:
return self._render_string_type(type_, "CHAR")
- def visit_NCHAR(self, type_: "sqltypes.NCHAR", **kw: Any) -> str:
+ def visit_NCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
return self._render_string_type(type_, "NCHAR")
- def visit_VARCHAR(self, type_: "sqltypes.String", **kw: Any) -> str:
+ def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
return self._render_string_type(type_, "VARCHAR")
- def visit_NVARCHAR(self, type_: "sqltypes.NVARCHAR", **kw: Any) -> str:
+ def visit_NVARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
return self._render_string_type(type_, "NVARCHAR")
- def visit_TEXT(self, type_: "sqltypes.Text", **kw: Any) -> str:
+ def visit_TEXT(self, type_: sqltypes.String, **kw: Any) -> str:
return self._render_string_type(type_, "TEXT")
- def visit_UUID(
- self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any
- ) -> str:
+ def visit_UUID(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "UUID"
- def visit_BLOB(self, type_: "sqltypes.LargeBinary", **kw: Any) -> str:
+ def visit_BLOB(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "BLOB"
- def visit_BINARY(self, type_: "sqltypes.BINARY", **kw: Any) -> str:
+ def visit_BINARY(self, type_: "sqltypes._Binary", **kw: Any) -> str:
return "BINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_VARBINARY(self, type_: "sqltypes.VARBINARY", **kw: Any) -> str:
+ def visit_VARBINARY(self, type_: "sqltypes._Binary", **kw: Any) -> str:
return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_BOOLEAN(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
+ def visit_BOOLEAN(self, type_: TypeEngine[Any], **kw: Any) -> str:
return "BOOLEAN"
def visit_uuid(
else:
return self.visit_UUID(type_, **kw)
- def visit_large_binary(
- self, type_: "sqltypes.LargeBinary", **kw: Any
- ) -> str:
+ def visit_large_binary(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_BLOB(type_, **kw)
- def visit_boolean(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
+ def visit_boolean(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_BOOLEAN(type_, **kw)
- def visit_time(self, type_: "sqltypes.Time", **kw: Any) -> str:
+ def visit_time(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_TIME(type_, **kw)
- def visit_datetime(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
+ def visit_datetime(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_DATETIME(type_, **kw)
- def visit_date(self, type_: "sqltypes.Date", **kw: Any) -> str:
+ def visit_date(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_DATE(type_, **kw)
- def visit_big_integer(
- self, type_: "sqltypes.BigInteger", **kw: Any
- ) -> str:
+ def visit_big_integer(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_BIGINT(type_, **kw)
- def visit_small_integer(
- self, type_: "sqltypes.SmallInteger", **kw: Any
- ) -> str:
+ def visit_small_integer(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_SMALLINT(type_, **kw)
- def visit_integer(self, type_: "sqltypes.Integer", **kw: Any) -> str:
+ def visit_integer(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_INTEGER(type_, **kw)
- def visit_real(
- self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
- ) -> str:
+ def visit_real(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_REAL(type_, **kw)
- def visit_float(
- self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
- ) -> str:
+ def visit_float(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_FLOAT(type_, **kw)
- def visit_double(
- self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
- ) -> str:
+ def visit_double(self, type_: TypeEngine[Any], **kw: Any) -> str:
return self.visit_DOUBLE(type_, **kw)
- def visit_numeric(
- self, type_: "sqltypes.Numeric[decimal.Decimal | float]", **kw: Any
- ) -> str:
+ def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
return self.visit_NUMERIC(type_, **kw)
- def visit_string(self, type_: "sqltypes.String", **kw: Any) -> str:
+ def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
- def visit_unicode(self, type_: "sqltypes.Unicode", **kw: Any) -> str:
+ def visit_unicode(self, type_: sqltypes.String, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
- def visit_text(self, type_: "sqltypes.Text", **kw: Any) -> str:
+ def visit_text(self, type_: sqltypes.String, **kw: Any) -> str:
return self.visit_TEXT(type_, **kw)
- def visit_unicode_text(
- self, type_: "sqltypes.UnicodeText", **kw: Any
- ) -> str:
+ def visit_unicode_text(self, type_: sqltypes.String, **kw: Any) -> str:
return self.visit_TEXT(type_, **kw)
- def visit_enum(self, type_: "sqltypes.Enum", **kw: Any) -> str:
+ def visit_enum(self, type_: sqltypes.String, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
def __init__(
self,
- dialect: DefaultDialect,
+ dialect: Dialect,
initial_quote: str = '"',
final_quote: str | None = None,
escape_quote: str = '"',
"schema_translate_map dictionaries."
)
- d["_none"] = d[None]
+ d["_none"] = d[None] # type: ignore[index]
def replace(m):
name = m.group(2)
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
max_ = (
- self.dialect.max_index_name_length
+ self.dialect.max_index_name_length # type: ignore[attr-defined]
or self.dialect.max_identifier_length
)
return self._truncate_and_render_maxlen_name(
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
max_ = (
- self.dialect.max_constraint_name_length
+ self.dialect.max_constraint_name_length # type: ignore[attr-defined] # NOQA: E501
or self.dialect.max_identifier_length
)
return self._truncate_and_render_maxlen_name(
if len(name) > max_:
name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
else:
- self.dialect.validate_identifier(name)
+ self.dialect.validate_identifier(name) # type: ignore[attr-defined] # NOQA: E501
if not _alembic_quote:
return name
@overload
def format_table(
self,
- table: "Table | None",
+ table: "FromClause | None",
use_schema: bool,
name: str,
) -> str: ...
@overload
def format_table(
self,
- table: "Table",
+ table: "NamedFromClause",
use_schema: bool = True,
name: None = None,
) -> str: ...
def format_table(
self,
- table: "Table | None",
+ table: "FromClause | None",
use_schema: bool = True,
name: str | None = None,
) -> str:
"""Prepare a quoted table and schema name."""
if name is None:
assert table is not None
+ table = cast("NamedFromClause", table)
name = table.name
result = self.quote(name)
def format_column(
self,
- column: "Column[Any]",
+ column: ColumnElement[Any],
use_table: bool = False,
name: str | None = None,
table_name: str | None = None,
if name is None:
name = column.name
+ name = cast(str, name)
if anon_map is not None and isinstance(
name, elements._truncated_label
import typing
from typing import Any
from typing import Callable
+from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Protocol
from typing import Sequence as typing_Sequence
from typing import Tuple
+from typing import TypeVar
from . import roles
from .base import _generative
from .base import Executable
from .base import SchemaVisitor
from .elements import ClauseElement
+from .schema import Column
from .. import exc
from .. import util
from ..util import topological
from .compiler import Compiled
from .compiler import DDLCompiler
from .elements import BindParameter
- from .schema import Column
from .schema import Constraint
from .schema import ForeignKeyConstraint
from .schema import Index
from ..engine.interfaces import Dialect
from ..engine.interfaces import SchemaTranslateMapType
+T = TypeVar("T", bound="SchemaItem")
+
class BaseDDLElement(ClauseElement):
"""The root of DDL constructs, including those that are sub-elements
)
-class _CreateDropBase(ExecutableDDLElement):
+class _CreateDropBase(ExecutableDDLElement, Generic[T]):
"""Base class for DDL constructs that represent CREATE and DROP or
equivalents.
def __init__(
self,
- element,
+ element: T,
):
self.element = self.target = element
self._ddl_if = getattr(element, "_ddl_if", None)
return False
-class _CreateBase(_CreateDropBase):
+class _CreateBase(_CreateDropBase[Any]):
def __init__(self, element, if_not_exists=False):
super().__init__(element)
self.if_not_exists = if_not_exists
-class _DropBase(_CreateDropBase):
+class _DropBase(_CreateDropBase[Any]):
def __init__(self, element, if_exists=False):
super().__init__(element)
self.if_exists = if_exists
"""Represent a CREATE INDEX statement."""
__visit_name__ = "create_index"
- element: "Index"
+ element: Index
- def __init__(self, element: "Index", if_not_exists: bool = False):
+ def __init__(self, element: Index, if_not_exists: bool = False):
"""Create a :class:`.Createindex` construct.
:param element: a :class:`_schema.Index` that's the subject
__visit_name__ = "drop_index"
- element: "Index"
+ element: Index
- def __init__(self, element: "Index", if_exists: bool = False):
+ def __init__(self, element: Index, if_exists: bool = False):
"""Create a :class:`.DropIndex` construct.
:param element: a :class:`_schema.Index` that's the subject
)
-class SetTableComment(_CreateDropBase):
+class SetTableComment(_CreateDropBase["Table"]):
"""Represent a COMMENT ON TABLE IS statement."""
__visit_name__ = "set_table_comment"
-class DropTableComment(_CreateDropBase):
+class DropTableComment(_CreateDropBase["Table"]):
"""Represent a COMMENT ON TABLE '' statement.
Note this varies a lot across database backends.
__visit_name__ = "drop_table_comment"
-class SetColumnComment(_CreateDropBase):
+class SetColumnComment(_CreateDropBase[Column[Any]]):
"""Represent a COMMENT ON COLUMN IS statement."""
__visit_name__ = "set_column_comment"
- element: "Column[Any]"
-class DropColumnComment(_CreateDropBase):
+class DropColumnComment(_CreateDropBase[Column[Any]]):
"""Represent a COMMENT ON COLUMN IS NULL statement."""
__visit_name__ = "drop_column_comment"
-class SetConstraintComment(_CreateDropBase):
+class SetConstraintComment(_CreateDropBase["Constraint"]):
"""Represent a COMMENT ON CONSTRAINT IS statement."""
__visit_name__ = "set_constraint_comment"
-class DropConstraintComment(_CreateDropBase):
+class DropConstraintComment(_CreateDropBase["Constraint"]):
"""Represent a COMMENT ON CONSTRAINT IS NULL statement."""
__visit_name__ = "drop_constraint_comment"
return process
def bind_processor(
- self, dialect: "Dialect"
+ self, dialect: Dialect
) -> _BindProcessorType[str] | None:
return None
__visit_name__ = "enum"
- enum_class: None | str | type[enum.StrEnum]
+ enum_class: None | str | type[enum.Enum]
- def __init__(self, *enums: object, **kw: Any):
+ def __init__(self, *enums: Union[str, type[enum.Enum]], **kw: Any):
r"""Construct an enum.
Keyword arguments which don't apply to a specific backend are ignored
.. versionchanged:: 2.0 This parameter now defaults to True.
"""
- self._enum_init(enums, kw) # type: ignore[arg-type]
+ self._enum_init(enums, kw)
@property
def _enums_argument(self):
return self.enums
def _enum_init(
- self, enums: Sequence[str | type[enum.StrEnum]], kw: dict[str, Any]
+ self, enums: Sequence[Union[str, type[enum.Enum]]], kw: dict[str, Any]
) -> None:
"""internal init for :class:`.Enum` and subclasses.
self.native_enum = kw.pop("native_enum", True)
self.create_constraint = kw.pop("create_constraint", False)
self.values_callable: (
- Callable[[type[enum.StrEnum]], Sequence[str]] | None
+ Callable[[type[enum.Enum]], Sequence[str]] | None
) = kw.pop("values_callable", None)
self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
length_arg = kw.pop("length", NO_ARG)
)
length = length_arg
- self._valid_lookup[None] = self._object_lookup[None] = None # type: ignore # noqa: E501
+ self._valid_lookup[None] = self._object_lookup[None] = None
super().__init__(length=length)
)
def _parse_into_values(
- self, enums: Sequence[str | type[enum.StrEnum]], kw: Any
- ) -> tuple[Sequence[str], Sequence[enum.StrEnum] | Sequence[str]]:
+ self, enums: Sequence[str | type[enum.Enum]], kw: Any
+ ) -> tuple[Sequence[str], Sequence[enum.Enum] | Sequence[str]]:
if not enums and "_enums" in kw:
enums = kw.pop("_enums")
def _setup_for_values(
self,
values: Sequence[str],
- objects: Sequence[enum.StrEnum] | Sequence[str],
+ objects: Sequence[enum.Enum] | Sequence[str],
kw: Any,
) -> None:
self.enums = list(values)
- self._valid_lookup: dict[str, str] = dict(
+ self._valid_lookup: dict[enum.Enum | str | None, str | None] = dict(
zip(reversed(objects), reversed(values))
)
- self._object_lookup: dict[str, str] = dict(zip(values, objects))
+ self._object_lookup: dict[str | None, enum.Enum | str | None] = dict(
+ zip(values, objects)
+ )
self._valid_lookup.update(
[
comparator_factory = Comparator
- def _object_value_for_elem(self, elem: str) -> str:
+ def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]:
try:
- return self._object_lookup[elem]
+ # Value will not be None beacuse key is not None
+ return self._object_lookup[elem] # type: ignore[return-value]
except KeyError as err:
raise LookupError(
"'%s' is not among the defined enum values. "
class VARBINARY(_Binary):
"""The SQL VARBINARY type."""
- length: int
+ length: Optional[int]
__visit_name__ = "VARBINARY"
if character_based_uuid:
if self.as_uuid:
- def process(value: Any) -> str:
+ def process(value):
if value is not None:
value = value.hex
- return value # type: ignore[no-any-return]
+ return value
return process
else:
- def process(value: Any) -> str:
+ def process(value):
if value is not None:
value = value.replace("-", "")
- return value # type: ignore[no-any-return]
+ return value
return process
else: