import collections
import collections.abc as collections_abc
import contextlib
+import decimal
from enum import IntEnum
import functools
import itertools
from typing import NamedTuple
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Pattern
from typing import Protocol
from typing import Sequence
from .base import Executable
from .base import NO_ARG
from .elements import ClauseElement
+from .elements import False_
+from .elements import Null
from .elements import quoted_name
+from .elements import True_
from .schema import Column
+from .schema import ForeignKeyConstraint
+from .schema import UniqueConstraint
+from .sqltypes import _UUID_RETURN
from .sqltypes import TupleType
+from .type_api import TypeDecorator
from .type_api import TypeEngine
+from .type_api import UserDefinedType
from .visitors import prefix_anon_map
from .visitors import Visitable
from .. import exc
from .cache_key import CacheKey
from .ddl import ExecutableDDLElement
from .dml import Insert
+ from .dml import Update
from .dml import UpdateBase
+ from .dml import UpdateDMLState
from .dml import ValuesBase
from .elements import _truncated_label
from .elements import BindParameter
from .elements import ColumnElement
from .elements import Label
from .functions import Function
+ from .schema import Constraint
+ from .schema import Index
+ from .schema import PrimaryKeyConstraint
from .schema import Table
from .selectable import AliasedReturnsRows
from .selectable import CompoundSelectState
from .selectable import SelectState
from .type_api import _BindProcessorType
from ..engine.cursor import CursorResultMetaData
+ from ..engine.default import DefaultDialect
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _DBAPIAnyExecuteParams
from ..engine.interfaces import _DBAPIMultiExecuteParams
from ..engine.interfaces import Dialect
from ..engine.interfaces import SchemaTranslateMapType
+
_FromHintsType = Dict["FromClause", str]
RESERVED_WORDS = {
if render_schema_translate:
self.string = self.preparer._render_schema_translates(
- self.string, schema_translate_map
+ self.string, schema_translate_map # type: ignore[arg-type]
)
self.state = CompilerState.STRING_APPLIED
return get
- def default_from(self):
+ def default_from(self) -> str:
"""Called when a SELECT statement has no froms, and no FROM clause is
to be appended.
return text
- def visit_null(self, expr, **kw):
+ def visit_null(self, expr: Null, **kw: Any) -> str:
return "NULL"
- def visit_true(self, expr, **kw):
+ def visit_true(self, expr: True_, **kw: Any) -> str:
if self.dialect.supports_native_boolean:
return "true"
else:
return "1"
- def visit_false(self, expr, **kw):
+ def visit_false(self, expr: False_, **kw: Any) -> str:
if self.dialect.supports_native_boolean:
return "false"
else:
% self.dialect.name
)
- def function_argspec(self, func, **kwargs):
+ def function_argspec(
+ self, func: functions.Function[Any], **kwargs: Any
+ ) -> str:
return func.clause_expr._compiler_dispatch(self, **kwargs)
def visit_compound_select(
)
def _generate_generic_binary(
- self, binary, opstring, eager_grouping=False, **kw
- ):
+ self,
+ binary: elements.BinaryExpression[Any],
+ opstring: str,
+ eager_grouping: bool = False,
+ **kw: Any,
+ ) -> str:
_in_operator_expression = kw.get("_in_operator_expression", False)
kw["_in_operator_expression"] = True
**kw,
)
- def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ def visit_regexp_match_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
raise exc.CompileError(
"%s dialect does not support regular expressions"
% self.dialect.name
)
- def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ def visit_not_regexp_match_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
raise exc.CompileError(
"%s dialect does not support regular expressions"
% self.dialect.name
)
- def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ def visit_regexp_replace_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
raise exc.CompileError(
"%s dialect does not support regular expression replacements"
% self.dialect.name
else:
return self.render_literal_value(value, bindparam.type)
- def render_literal_value(self, value, type_):
+ def render_literal_value(
+ self, value: str | None, type_: sqltypes.String
+ ) -> str:
"""Render the value of a bind parameter as a quoted literal.
This is used for statement sections that do not accept bind parameters
def get_select_hint_text(self, byfroms):
return None
- def get_from_hint_text(self, table, text):
+ def get_from_hint_text(self, table: Any, text: str | None) -> str | None:
return None
def get_crud_hint_text(self, table, text):
else:
return "WITH"
- def get_select_precolumns(self, select, **kw):
+ def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
"""Called when building a ``SELECT`` statement, position is just
before column list.
return text
- def update_limit_clause(self, update_stmt):
+ def update_limit_clause(self, update_stmt: "Update") -> str | None:
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
"criteria within UPDATE"
)
- def visit_update(self, update_stmt, visiting_cte=None, **kw):
- compile_state = update_stmt._compile_state_factory(
- update_stmt, self, **kw
+ def visit_update(
+ self, update_stmt: "Update", visiting_cte: CTE | None = None, **kw: Any
+ ) -> str:
+ compile_state = update_stmt._compile_state_factory( # type: ignore
+ update_stmt, self, **kw # type: ignore
)
- update_stmt = compile_state.statement
+ compile_state = cast("UpdateDMLState", compile_state)
+ update_stmt = compile_state.statement # type: ignore[assignment]
if visiting_cte is not None:
kw["visiting_cte"] = visiting_cte
if self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt,
- self.implicit_returning or update_stmt._returning,
+ self.implicit_returning or update_stmt._returning, # type: ignore[arg-type] # NOQA: E501
populate_result_map=toplevel,
)
) and not self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt,
- self.implicit_returning or update_stmt._returning,
+ self.implicit_returning or update_stmt._returning, # type: ignore[arg-type] # noqa: E501
populate_result_map=toplevel,
)
return text
def delete_extra_from_clause(
- self, update_stmt, from_table, extra_froms, from_hints, **kw
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
):
"""Provide a hook to override the generation of an
DELETE..FROM clause.
)
def delete_extra_from_clause(
- self, update_stmt, from_table, extra_froms, from_hints, **kw
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
):
kw["asfrom"] = True
return ", " + ", ".join(
def visit_drop_view(self, drop, **kw):
return "\nDROP VIEW " + self.preparer.format_table(drop.element)
- def _verify_index_table(self, index):
+ def _verify_index_table(self, index: "Index") -> None:
if index.table is None:
raise exc.CompileError(
"Index '%s' is not associated with any table." % index.name
return text + self._prepared_index_name(index, include_schema=True)
- def _prepared_index_name(self, index, include_schema=False):
+ def _prepared_index_name(
+ self, index: "Index", include_schema: bool = False
+ ) -> str:
if index.table is not None:
effective_schema = self.preparer.schema_for_object(index.table)
else:
else:
schema_name = None
- index_name = self.preparer.format_index(index)
+ index_name: str = self.preparer.format_index(index)
if schema_name:
index_name = schema_name + "." + index_name
def post_create_table(self, table):
return ""
- def get_column_default_string(self, column):
+ def get_column_default_string(self, column: Column[Any]) -> str | None:
if isinstance(column.server_default, schema.DefaultClause):
return self.render_default_string(column.server_default.arg)
else:
return None
- def render_default_string(self, default):
+ def render_default_string(self, default: Visitable | str) -> str:
if isinstance(default, str):
- return self.sql_compiler.render_literal_value(
+ return self.sql_compiler.render_literal_value( # type: ignore[no-any-return] # NOQA: E501
default, sqltypes.STRINGTYPE
)
else:
- return self.sql_compiler.process(default, literal_binds=True)
+ return self.sql_compiler.process(default, literal_binds=True) # type: ignore[no-any-return] # NOQA: E501
def visit_table_or_column_check_constraint(self, constraint, **kw):
if constraint.is_column_level:
text += self.define_constraint_deferrability(constraint)
return text
- def visit_primary_key_constraint(self, constraint, **kw):
+ def visit_primary_key_constraint(
+ self, constraint: "PrimaryKeyConstraint", **kw: Any
+ ) -> str:
if len(constraint) == 0:
return ""
text = ""
return preparer.format_table(table)
- def visit_unique_constraint(self, constraint, **kw):
+ def visit_unique_constraint(
+ self, constraint: UniqueConstraint, **kw: Any
+ ) -> str:
if len(constraint) == 0:
return ""
text = ""
text += self.define_constraint_deferrability(constraint)
return text
- def define_unique_constraint_distinct(self, constraint, **kw):
+ def define_unique_constraint_distinct(
+ self, constraint: UniqueConstraint, **kw: Any
+ ) -> str:
return ""
- def define_constraint_cascades(self, constraint):
+ def define_constraint_cascades(
+ self, constraint: ForeignKeyConstraint
+ ) -> str:
text = ""
if constraint.ondelete is not None:
text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
)
return text
- def define_constraint_deferrability(self, constraint):
+ def define_constraint_deferrability(self, constraint: Constraint) -> str:
text = ""
if constraint.deferrable is not None:
if constraint.deferrable:
class GenericTypeCompiler(TypeCompiler):
- def visit_FLOAT(self, type_, **kw):
+ def visit_FLOAT(
+ self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
+ ) -> str:
return "FLOAT"
- def visit_DOUBLE(self, type_, **kw):
+ def visit_DOUBLE(
+ self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
+ ) -> str:
return "DOUBLE"
- def visit_DOUBLE_PRECISION(self, type_, **kw):
+ def visit_DOUBLE_PRECISION(
+ self,
+ type_: "sqltypes.DOUBLE_PRECISION[decimal.Decimal| float]",
+ **kw: Any,
+ ) -> str:
return "DOUBLE PRECISION"
- def visit_REAL(self, type_, **kw):
+ def visit_REAL(
+ self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
+ ) -> str:
return "REAL"
- def visit_NUMERIC(self, type_, **kw):
+ def visit_NUMERIC(
+ self, type_: "sqltypes.Numeric[decimal.Decimal| float]", **kw: Any
+ ) -> str:
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
"scale": type_.scale,
}
- def visit_DECIMAL(self, type_, **kw):
+ def visit_DECIMAL(
+ self, type_: "sqltypes.DECIMAL[decimal.Decimal| float]", **kw: Any
+ ) -> str:
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
"scale": type_.scale,
}
- def visit_INTEGER(self, type_, **kw):
+ def visit_INTEGER(self, type_: "sqltypes.Integer", **kw: Any) -> str:
return "INTEGER"
- def visit_SMALLINT(self, type_, **kw):
+ def visit_SMALLINT(self, type_: "sqltypes.SmallInteger", **kw: Any) -> str:
return "SMALLINT"
- def visit_BIGINT(self, type_, **kw):
+ def visit_BIGINT(self, type_: "sqltypes.BigInteger", **kw: Any) -> str:
return "BIGINT"
- def visit_TIMESTAMP(self, type_, **kw):
+ def visit_TIMESTAMP(self, type_: "sqltypes.TIMESTAMP", **kw: Any) -> str:
return "TIMESTAMP"
- def visit_DATETIME(self, type_, **kw):
+ def visit_DATETIME(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
return "DATETIME"
- def visit_DATE(self, type_, **kw):
+ def visit_DATE(self, type_: "sqltypes.Date", **kw: Any) -> str:
return "DATE"
- def visit_TIME(self, type_, **kw):
+ def visit_TIME(self, type_: "sqltypes.Time", **kw: Any) -> str:
return "TIME"
- def visit_CLOB(self, type_, **kw):
+ def visit_CLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
return "CLOB"
- def visit_NCLOB(self, type_, **kw):
+ def visit_NCLOB(self, type_: "sqltypes.Text", **kw: Any) -> str:
return "NCLOB"
- def _render_string_type(self, type_, name, length_override=None):
+ def _render_string_type(
+ self,
+ type_: "sqltypes.String | sqltypes.Uuid[_UUID_RETURN]",
+ name: str,
+ length_override: int | None = None,
+ ) -> str:
text = name
if length_override:
text += "(%d)" % length_override
- elif type_.length:
- text += "(%d)" % type_.length
+ elif type_.length: # type: ignore[union-attr]
+ text += "(%d)" % type_.length # type: ignore[union-attr]
if type_.collation:
text += ' COLLATE "%s"' % type_.collation
return text
- def visit_CHAR(self, type_, **kw):
+ def visit_CHAR(self, type_: "sqltypes.CHAR", **kw: Any) -> str:
return self._render_string_type(type_, "CHAR")
- def visit_NCHAR(self, type_, **kw):
+ def visit_NCHAR(self, type_: "sqltypes.NCHAR", **kw: Any) -> str:
return self._render_string_type(type_, "NCHAR")
- def visit_VARCHAR(self, type_, **kw):
+ def visit_VARCHAR(self, type_: "sqltypes.String", **kw: Any) -> str:
return self._render_string_type(type_, "VARCHAR")
- def visit_NVARCHAR(self, type_, **kw):
+ def visit_NVARCHAR(self, type_: "sqltypes.NVARCHAR", **kw: Any) -> str:
return self._render_string_type(type_, "NVARCHAR")
- def visit_TEXT(self, type_, **kw):
+ def visit_TEXT(self, type_: "sqltypes.Text", **kw: Any) -> str:
return self._render_string_type(type_, "TEXT")
- def visit_UUID(self, type_, **kw):
+ def visit_UUID(
+ self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any
+ ) -> str:
return "UUID"
- def visit_BLOB(self, type_, **kw):
+ def visit_BLOB(self, type_: "sqltypes.LargeBinary", **kw: Any) -> str:
return "BLOB"
- def visit_BINARY(self, type_, **kw):
+ def visit_BINARY(self, type_: "sqltypes.BINARY", **kw: Any) -> str:
return "BINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_VARBINARY(self, type_, **kw):
+ def visit_VARBINARY(self, type_: "sqltypes.VARBINARY", **kw: Any) -> str:
return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
- def visit_BOOLEAN(self, type_, **kw):
+ def visit_BOOLEAN(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
return "BOOLEAN"
- def visit_uuid(self, type_, **kw):
+ def visit_uuid(
+ self, type_: "sqltypes.Uuid[_UUID_RETURN]", **kw: Any
+ ) -> str:
if not type_.native_uuid or not self.dialect.supports_native_uuid:
return self._render_string_type(type_, "CHAR", length_override=32)
else:
return self.visit_UUID(type_, **kw)
- def visit_large_binary(self, type_, **kw):
+ def visit_large_binary(
+ self, type_: "sqltypes.LargeBinary", **kw: Any
+ ) -> str:
return self.visit_BLOB(type_, **kw)
- def visit_boolean(self, type_, **kw):
+ def visit_boolean(self, type_: "sqltypes.Boolean", **kw: Any) -> str:
return self.visit_BOOLEAN(type_, **kw)
- def visit_time(self, type_, **kw):
+ def visit_time(self, type_: "sqltypes.Time", **kw: Any) -> str:
return self.visit_TIME(type_, **kw)
- def visit_datetime(self, type_, **kw):
+ def visit_datetime(self, type_: "sqltypes.DateTime", **kw: Any) -> str:
return self.visit_DATETIME(type_, **kw)
- def visit_date(self, type_, **kw):
+ def visit_date(self, type_: "sqltypes.Date", **kw: Any) -> str:
return self.visit_DATE(type_, **kw)
- def visit_big_integer(self, type_, **kw):
+ def visit_big_integer(
+ self, type_: "sqltypes.BigInteger", **kw: Any
+ ) -> str:
return self.visit_BIGINT(type_, **kw)
- def visit_small_integer(self, type_, **kw):
+ def visit_small_integer(
+ self, type_: "sqltypes.SmallInteger", **kw: Any
+ ) -> str:
return self.visit_SMALLINT(type_, **kw)
- def visit_integer(self, type_, **kw):
+ def visit_integer(self, type_: "sqltypes.Integer", **kw: Any) -> str:
return self.visit_INTEGER(type_, **kw)
- def visit_real(self, type_, **kw):
+ def visit_real(
+ self, type_: "sqltypes.REAL[decimal.Decimal| float]", **kw: Any
+ ) -> str:
return self.visit_REAL(type_, **kw)
- def visit_float(self, type_, **kw):
+ def visit_float(
+ self, type_: "sqltypes.Float[decimal.Decimal| float]", **kw: Any
+ ) -> str:
return self.visit_FLOAT(type_, **kw)
- def visit_double(self, type_, **kw):
+ def visit_double(
+ self, type_: "sqltypes.Double[decimal.Decimal | float]", **kw: Any
+ ) -> str:
return self.visit_DOUBLE(type_, **kw)
- def visit_numeric(self, type_, **kw):
+ def visit_numeric(
+ self, type_: "sqltypes.Numeric[decimal.Decimal | float]", **kw: Any
+ ) -> str:
return self.visit_NUMERIC(type_, **kw)
- def visit_string(self, type_, **kw):
+ def visit_string(self, type_: "sqltypes.String", **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
- def visit_unicode(self, type_, **kw):
+ def visit_unicode(self, type_: "sqltypes.Unicode", **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
- def visit_text(self, type_, **kw):
+ def visit_text(self, type_: "sqltypes.Text", **kw: Any) -> str:
return self.visit_TEXT(type_, **kw)
- def visit_unicode_text(self, type_, **kw):
+ def visit_unicode_text(
+ self, type_: "sqltypes.UnicodeText", **kw: Any
+ ) -> str:
return self.visit_TEXT(type_, **kw)
- def visit_enum(self, type_, **kw):
+ def visit_enum(self, type_: "sqltypes.Enum", **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
"type on this Column?" % type_
)
- def visit_type_decorator(self, type_, **kw):
+ def visit_type_decorator(
+ self, type_: TypeDecorator[Any], **kw: Any
+ ) -> str:
return self.process(type_.type_engine(self.dialect), **kw)
- def visit_user_defined(self, type_, **kw):
+ def visit_user_defined(
+ self, type_: UserDefinedType[Any], **kw: Any
+ ) -> str:
return type_.get_col_spec(**kw)
def __init__(
self,
- dialect,
- initial_quote='"',
- final_quote=None,
- escape_quote='"',
- quote_case_sensitive_collations=True,
- omit_schema=False,
+ dialect: DefaultDialect,
+ initial_quote: str = '"',
+ final_quote: str | None = None,
+ escape_quote: str = '"',
+ quote_case_sensitive_collations: bool = True,
+ omit_schema: bool = False,
):
"""Construct a new ``IdentifierPreparer`` object.
prep._includes_none_schema_translate = includes_none
return prep
- def _render_schema_translates(self, statement, schema_translate_map):
+ def _render_schema_translates(
+ self, statement: str, schema_translate_map: SchemaTranslateMapType
+ ) -> str:
d = schema_translate_map
if None in d:
if not self._includes_none_schema_translate:
else:
return collation_name
- def format_sequence(self, sequence, use_schema=True):
+ def format_sequence(
+ self, sequence: schema.Sequence, use_schema: bool = True
+ ) -> str:
name = self.quote(sequence.name)
effective_schema = self.schema_for_object(sequence)
return ident
@util.preload_module("sqlalchemy.sql.naming")
- def format_constraint(self, constraint, _alembic_quote=True):
+ def format_constraint(
+ self, constraint: Constraint, _alembic_quote: bool = True
+ ) -> str | None:
naming = util.preloaded.sql_naming
if constraint.name is _NONE_NAME:
name, _alembic_quote=_alembic_quote
)
- def truncate_and_render_index_name(self, name, _alembic_quote=True):
+ def truncate_and_render_index_name(
+ self, name: str, _alembic_quote: bool = True
+ ) -> str | None:
# calculate these at format time so that ad-hoc changes
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
name, max_, _alembic_quote
)
- def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
+ def truncate_and_render_constraint_name(
+ self, name: str, _alembic_quote: bool = True
+ ) -> str | None:
# calculate these at format time so that ad-hoc changes
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
name, max_, _alembic_quote
)
- def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
+ def _truncate_and_render_maxlen_name(
+ self, name: str, max_: int, _alembic_quote: bool
+ ) -> str | None:
if isinstance(name, elements._truncated_label):
if len(name) > max_:
name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
def format_index(self, index):
return self.format_constraint(index)
- def format_table(self, table, use_schema=True, name=None):
- """Prepare a quoted table and schema name."""
+ @overload
+ def format_table(
+ self,
+ table: "Table | None",
+ use_schema: bool,
+ name: str,
+ ) -> str: ...
+
+ @overload
+ def format_table(
+ self,
+ table: "Table",
+ use_schema: bool = True,
+ name: None = None,
+ ) -> str: ...
+ def format_table(
+ self,
+ table: "Table | None",
+ use_schema: bool = True,
+ name: str | None = None,
+ ) -> str:
+ """Prepare a quoted table and schema name."""
if name is None:
+ assert table is not None
name = table.name
result = self.quote(name)
def format_column(
self,
- column,
- use_table=False,
- name=None,
- table_name=None,
- use_schema=False,
- anon_map=None,
- ):
+ column: "Column[Any]",
+ use_table: bool = False,
+ name: str | None = None,
+ table_name: str | None = None,
+ use_schema: bool = False,
+ anon_map: Mapping[str, Any] | None = None,
+ ) -> str:
"""Prepare a quoted column name."""
if name is None:
)
return r
- def unformat_identifiers(self, identifiers):
+ def unformat_identifiers(self, identifiers: str) -> list[str]:
"""Unpack 'schema.table.column'-like strings into components."""
r = self._r_identifiers