# Copyright (C) 2020 The Psycopg Team
-import re
import codecs
from datetime import date, datetime
+from typing import cast
from ..adapt import Dumper, Loader
from ..proto import AdaptContext
def __init__(self, oid: int, context: AdaptContext):
super().__init__(oid, context)
+ self._date_format = self._format_from_context()
- ds = self._get_datestyle()
- if ds == b"ISO":
- pass # Default: YMD
- elif ds == b"German":
- self.load = self.load_dmy # type: ignore
- elif ds == b"SQL" or ds == b"Postgres":
- self.load = self.load_mdy # type: ignore
-
- def load_ymd(self, data: bytes) -> date:
- try:
- return date(int(data[:4]), int(data[5:7]), int(data[8:]))
- except ValueError as e:
- exc = e
-
- return self._raise_error(data, exc)
-
- load = load_ymd
-
- def load_dmy(self, data: bytes) -> date:
+ def load(self, data: bytes) -> date:
try:
- return date(int(data[6:]), int(data[3:5]), int(data[:2]))
+ return datetime.strptime(
+ self._decode(data)[0], self._date_format
+ ).date()
except ValueError as e:
- exc = e
+ return self._raise_error(data, e)
- return self._raise_error(data, exc)
-
- def load_mdy(self, data: bytes) -> date:
- try:
- return date(int(data[6:]), int(data[:2]), int(data[3:5]))
- except ValueError as e:
- exc = e
-
- return self._raise_error(data, exc)
+ def _format_from_context(self) -> str:
+ ds = self._get_datestyle()
+ if ds.startswith(b"I"): # ISO
+ return "%Y-%m-%d"
+ elif ds.startswith(b"G"): # German
+ return "%d.%m.%Y"
+ elif ds.startswith(b"S"): # SQL
+ return "%d/%m/%Y" if ds.endswith(b"DMY") else "%m/%d/%Y"
+ elif ds.startswith(b"P"): # Postgres
+ return "%d-%m-%Y" if ds.endswith(b"DMY") else "%m-%d-%Y"
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
def _get_datestyle(self) -> bytes:
- """Return the PostgreSQL output datestyle of the connection."""
+ rv = b"ISO, DMY"
if self.connection:
ds = self.connection.pgconn.parameter_status(b"DateStyle")
if ds:
- return ds.split(b",", 1)[0]
+ rv = ds
- return b"ISO"
+ return rv
def _raise_error(self, data: bytes, exc: ValueError) -> date:
# Most likely we received a BC date, which Python doesn't support
# Otherwise the unexpected value is displayed in the exception.
if data.endswith(b"BC"):
- raise InterfaceError(
+ raise ValueError(
"Python doesn't support BC date:"
f" got {data.decode('utf8', 'replace')}"
)
- # Find the year from the date. This is not the fast path so we don't
- # need crazy speed.
+ # Find the year from the date. We check if >= Y10K only in ISO format,
+ # others are too silly to bother being polite.
ds = self._get_datestyle()
- if ds == b"ISO":
+ if ds.startswith(b"ISO"):
year = int(data.split(b"-", 1)[0])
- else:
- year = int(re.split(rb"[-/\.]", data)[-1])
-
- if year > 9999:
- raise InterfaceError(
- "Python date doesn't support years after 9999:"
- f" got {data.decode('utf8', 'replace')}"
- )
+ if year > 9999:
+ raise ValueError(
+ "Python date doesn't support years after 9999:"
+ f" got {data.decode('utf8', 'replace')}"
+ )
# We genuinely received something we cannot parse
raise exc
+
+
+@Loader.text(builtins["timestamp"].oid)
+class TimestampLoader(DateLoader):
+ def __init__(self, oid: int, context: AdaptContext):
+ super().__init__(oid, context)
+ self._no_micro_format = self._date_format.replace(".%f", "")
+
+ def load(self, data: bytes) -> datetime:
+ # check if the data contains microseconds
+ fmt = self._date_format if b"." in data[19:] else self._no_micro_format
+ try:
+ return datetime.strptime(self._decode(data)[0], fmt)
+ except ValueError as e:
+ return self._raise_error(data, e)
+
+ def _format_from_context(self) -> str:
+ ds = self._get_datestyle()
+ if ds.startswith(b"I"): # ISO
+ return "%Y-%m-%d %H:%M:%S.%f"
+ elif ds.startswith(b"G"): # German
+ return "%d.%m.%Y %H:%M:%S.%f"
+ elif ds.startswith(b"S"): # SQL
+ return (
+ "%d/%m/%Y %H:%M:%S.%f"
+ if ds.endswith(b"DMY")
+ else "%m/%d/%Y %H:%M:%S.%f"
+ )
+ elif ds.startswith(b"P"): # Postgres
+ return (
+ "%a %d %b %H:%M:%S.%f %Y"
+ if ds.endswith(b"DMY")
+ else "%a %b %d %H:%M:%S.%f %Y"
+ )
+ else:
+ raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
+
+ def _raise_error(self, data: bytes, exc: ValueError) -> datetime:
+ return cast(datetime, super()._raise_error(data, exc))
import datetime as dt
import pytest
-import psycopg3
from psycopg3.adapt import Format
# date tests
#
+
@pytest.mark.parametrize(
"val, expr",
[
@pytest.mark.xfail # TODO: binary load
@pytest.mark.parametrize(
- "val, expr", [(dt.date(2000, 1, 1), "'2000-01-01'::date")],
+ "val, expr", [(dt.date(2000, 1, 1), "'2000-01-01'::date")]
)
def test_load_date_binary(conn, val, expr):
cur = conn.cursor(format=Format.BINARY)
cur = conn.cursor()
cur.execute(f"set datestyle = {datestyle_out}, YMD")
cur.execute("select %s - 1", (dt.date.min,))
- with pytest.raises(psycopg3.InterfaceError):
+ with pytest.raises(ValueError):
cur.fetchone()[0]
cur = conn.cursor()
cur.execute(f"set datestyle = {datestyle_out}, YMD")
cur.execute("select %s + 1", (dt.date.max,))
- with pytest.raises(psycopg3.InterfaceError):
+ with pytest.raises(ValueError):
cur.fetchone()[0]
# datetime tests
#
+
@pytest.mark.parametrize(
"val, expr",
[
assert cur.fetchone()[0] is True
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("min", "'0001-01-01'"),
+ ("1000,1,1", "'1000-01-01'"),
+ ("2000,1,1", "'2000-01-01'"),
+ ("2000,1,2,3,4,5,6", "'2000-01-02 03:04:05.000006'"),
+ ("2000,1,2,3,4,5,678", "'2000-01-02 03:04:05.000678'"),
+ ("2000,1,2,3,0,0,456789", "'2000-01-02 03:00:00.456789'"),
+ ("2000,12,31", "'2000-12-31'"),
+ ("3000,1,1", "'3000-01-01'"),
+ ("max", "'9999-12-31 23:59:59.999999'"),
+ ],
+)
+@pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
+@pytest.mark.parametrize("datestyle_in", ["DMY", "MDY", "YMD"])
+def test_load_datetime(conn, val, expr, datestyle_in, datestyle_out):
+ cur = conn.cursor()
+ cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
+ val = (
+ dt.datetime(*map(int, val.split(",")))
+ if "," in val
+ else getattr(dt.datetime, val)
+ )
+ cur.execute("set timezone to '+02:00'")
+ cur.execute(f"select {expr}::timestamp")
+ assert cur.fetchone()[0] == val
+
+
#
# datetime+tz tests
#
+
@pytest.mark.parametrize(
"val, expr",
[