def get_alembic_option(
self, name: str, default: Optional[str] = None
- ) -> Union[None, str, list[str], dict[str, str]]:
+ ) -> Union[None, str, list[str], dict[str, str], list[dict[str, str]]]:
"""Return an option from the "[alembic]" or "[tool.alembic]" section
of the configparser-parsed .ini file (e.g. ``alembic.ini``) or
toml-parsed ``pyproject.toml`` file.
if self.file_config.has_option(self.config_ini_section, name):
return self.file_config.get(self.config_ini_section, name)
else:
- USE_DEFAULT = object()
- value: Union[None, str, list[str], dict[str, str]] = (
- self.toml_alembic_config.get(name, USE_DEFAULT)
- )
- if value is USE_DEFAULT:
- return default
- if value is not None:
- if isinstance(value, str):
- value = value % (self.toml_args)
- elif isinstance(value, list):
+ return self._get_toml_config_value(name, default=default)
+
+ def _get_toml_config_value(
+ self, name: str, default: Optional[Any] = None
+ ) -> Union[None, str, list[str], dict[str, str], list[dict[str, str]]]:
+ USE_DEFAULT = object()
+ value: Union[None, str, list[str], dict[str, str]] = (
+ self.toml_alembic_config.get(name, USE_DEFAULT)
+ )
+ if value is USE_DEFAULT:
+ return default
+ if value is not None:
+ if isinstance(value, str):
+ value = value % (self.toml_args)
+ elif isinstance(value, list):
+ if value and isinstance(value[0], dict):
+ value = [
+ {k: v % (self.toml_args) for k, v in dv.items()}
+ for dv in value
+ ]
+ else:
value = cast(
"list[str]", [v % (self.toml_args) for v in value]
)
- elif isinstance(value, dict):
- value = cast(
- "dict[str, str]",
- {k: v % (self.toml_args) for k, v in value.items()},
- )
- else:
- raise util.CommandError("unsupported TOML value type")
- return value
+ elif isinstance(value, dict):
+ value = cast(
+ "dict[str, str]",
+ {k: v % (self.toml_args) for k, v in value.items()},
+ )
+ else:
+ raise util.CommandError("unsupported TOML value type")
+ return value
@util.memoized_property
def messaging_opts(self) -> MessagingOptions:
if x
]
else:
- return self.toml_alembic_config.get("version_locations", None)
+ return cast(
+ "list[str]",
+ self._get_toml_config_value("version_locations", None),
+ )
def get_prepend_sys_paths_list(self) -> Optional[list[str]]:
prepend_sys_path_str = self.file_config.get(
if x
]
else:
- return self.toml_alembic_config.get("prepend_sys_path", None)
+ return cast(
+ "list[str]",
+ self._get_toml_config_value("prepend_sys_path", None),
+ )
def get_hooks_list(self) -> list[PostWriteHookConfig]:
hooks: list[PostWriteHookConfig] = []
if not self.file_config.has_section("post_write_hooks"):
- hook_config = self.toml_alembic_config.get("post_write_hooks", {})
- for cfg in hook_config:
+ toml_hook_config = cast(
+ "list[dict[str, str]]",
+ self._get_toml_config_value("post_write_hooks", []),
+ )
+ for cfg in toml_hook_config:
opts = dict(cfg)
opts["_hook_name"] = opts.pop("name")
hooks.append(opts)
else:
_split_on_space_comma = re.compile(r", *|(?: +)")
- hook_config = self.get_section("post_write_hooks", {})
- names = _split_on_space_comma.split(hook_config.get("hooks", ""))
+ ini_hook_config = self.get_section("post_write_hooks", {})
+ names = _split_on_space_comma.split(
+ ini_hook_config.get("hooks", "")
+ )
for name in names:
if not name:
continue
opts = {
- key[len(name) + 1 :]: hook_config[key]
- for key in hook_config
+ key[len(name) + 1 :]: ini_hook_config[key]
+ for key in ini_hook_config
if key.startswith(name + ".")
}
eq_(cfg.get_main_option("asdf"), "back_at_ya")
+ def test_script_location(self, pyproject_only_env):
+ cfg = pyproject_only_env
+ with cfg._toml_file_path.open("wb") as file_:
+ file_.write(
+ rb"""
+
+[tool.alembic]
+script_location = "%(here)s/scripts"
+
+"""
+ )
+
+ new_cfg = config.Config(
+ file_=cfg.config_file_name, toml_file=cfg._toml_file_path
+ )
+ sd = ScriptDirectory.from_config(new_cfg)
+ eq_(
+ pathlib.Path(sd.dir),
+ pathlib.Path(_get_staging_directory(), "scripts").absolute(),
+ )
+
+ def test_version_locations(self, pyproject_only_env):
+
+ cfg = pyproject_only_env
+ with cfg._toml_file_path.open("ba") as file_:
+ file_.write(
+ b"""
+version_locations = [
+ "%(here)s/foo/bar"
+]
+"""
+ )
+
+ if "toml_alembic_config" in cfg.__dict__:
+ cfg.__dict__.pop("toml_alembic_config")
+
+ eq_(
+ cfg.get_version_locations_list(),
+ [
+ pathlib.Path(_get_staging_directory(), "foo/bar")
+ .absolute()
+ .as_posix()
+ ],
+ )
+
+ def test_prepend_sys_path(self, pyproject_only_env):
+
+ cfg = pyproject_only_env
+ with cfg._toml_file_path.open("wb") as file_:
+ file_.write(
+ rb"""
+
+[tool.alembic]
+script_location = "%(here)s/scripts"
+
+prepend_sys_path = [
+ ".",
+ "%(here)s/path/to/python",
+ "c:\\some\\path"
+]
+"""
+ )
+
+ if "toml_alembic_config" in cfg.__dict__:
+ cfg.__dict__.pop("toml_alembic_config")
+
+ eq_(
+ cfg.get_prepend_sys_paths_list(),
+ [
+ ".",
+ pathlib.Path(_get_staging_directory(), "path/to/python")
+ .absolute()
+ .as_posix(),
+ r"c:\some\path",
+ ],
+ )
+
+ def test_write_hooks(self, pyproject_only_env):
+
+ cfg = pyproject_only_env
+ with cfg._toml_file_path.open("wb") as file_:
+ file_.write(
+ rb"""
+
+[tool.alembic]
+script_location = "%(here)s/scripts"
+
+[[tool.alembic.post_write_hooks]]
+name = "myhook"
+type = "exec"
+executable = "%(here)s/.venv/bin/ruff"
+options = "-l 79 REVISION_SCRIPT_FILENAME"
+
+"""
+ )
+
+ if "toml_alembic_config" in cfg.__dict__:
+ cfg.__dict__.pop("toml_alembic_config")
+
+ eq_(
+ cfg.get_hooks_list(),
+ [
+ {
+ "type": "exec",
+ "executable": (
+ cfg._toml_file_path.absolute().parent
+ / ".venv/bin/ruff"
+ ).as_posix(),
+ "options": "-l 79 REVISION_SCRIPT_FILENAME",
+ "_hook_name": "myhook",
+ }
+ ],
+ )
+
+ def test_string_list(self, pyproject_only_env):
+
+ cfg = pyproject_only_env
+ with cfg._toml_file_path.open("wb") as file_:
+ file_.write(
+ rb"""
+
+[tool.alembic]
+script_location = "%(here)s/scripts"
+
+my_list = [
+ "one",
+ "two %(here)s three"
+]
+
+"""
+ )
+ if "toml_alembic_config" in cfg.__dict__:
+ cfg.__dict__.pop("toml_alembic_config")
+
+ eq_(
+ cfg.get_alembic_option("my_list"),
+ [
+ "one",
+ f"two {cfg._toml_file_path.absolute().parent.as_posix()} "
+ "three",
+ ],
+ )
+
class StdoutOutputEncodingTest(TestBase):
def test_plain(self):