s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse interval {s!r}")
+ ye: int | bytes
+ mo: int | bytes
ye, mo, da, sgn, ho, mi, se = m.groups()
days = 0
seconds = 0.0
+ all_months = 0
if ye:
- days += 365 * int(ye)
+ ye = int(ye)
+ days += 365 * ye
+ all_months += 12 * ye
if mo:
- days += 30 * int(mo)
+ mo = int(mo)
+ days += 30 * mo
+ all_months += mo
if da:
days += int(da)
if sgn == b"-":
seconds = -seconds
+ # Postgres adds 0.25 days every 12 months to approximate leap years
+ if all_months >= 12:
+ seconds += (6 * 60 * 60) * (all_months // 12)
+ elif all_months <= -12:
+ seconds -= (6 * 60 * 60) * (all_months // -12)
+
try:
return timedelta(days=days, seconds=seconds)
except OverflowError as e:
def load(self, data: Buffer) -> timedelta:
micros, days, months = _unpack_interval(data)
+ hours = 0
if months > 0:
years, months = divmod(months, 12)
days = days + 30 * months + 365 * years
+ # Postgres adds 0.25 days every 12 months to approximate leap years
+ hours = 6 * years
elif months < 0:
years, months = divmod(-months, 12)
days = days - 30 * months - 365 * years
+ hours = -6 * years
try:
- return timedelta(days=days, microseconds=micros)
+ return timedelta(days=days, hours=hours, microseconds=micros)
except OverflowError as e:
raise DataError(f"can't parse interval: {e}") from None
("3723s,400000m", "1:2:3.4"),
("86399s,999999m", "23:59:59.999999"),
("30d", "30 day"),
- ("365d", "1 year"),
- ("-365d", "-1 year"),
- ("-730d", "-2 years"),
- ("1460d", "4 year"),
+ ("365d,6h", "1 year"),
+ ("-365d,-6h", "-1 year"),
+ ("-730d,-12h", "-2 years"),
+ ("1461d", "4 year"),
("30d", "1 month"),
("-30d", "-1 month"),
("60d", "2 month"),
("-90d", "-3 month"),
("186d", "6 mons 6 days"),
("174d", "6 mons -6 days"),
- ("736d", "2 years 6 days"),
- ("724d", "2 years -6 days"),
+ ("736d,12h", "2 years 6 days"),
+ ("724d,12h", "2 years -6 days"),
("330d", "1 years -1 month"),
("83063d,81640s,447000m", "1993534:40:40.447"),
("-1d,64800s", "41 days -990:00:00"),
@pytest.mark.parametrize("fmt_out", pq.Format)
def test_load_interval(self, conn, val, expr, fmt_out):
cur = conn.cursor(binary=fmt_out)
- cur.execute(f"select '{expr}'::interval")
- assert cur.fetchone()[0] == as_td(val)
+ cur.execute(
+ "select %(i)s::interval, extract('epoch' from %(i)s::interval)::float8",
+ {"i": expr},
+ )
+ got, nsecs = cur.fetchone()
+ assert got == as_td(val)
+ assert nsecs == as_td(val).total_seconds()
+
+ @pytest.mark.parametrize("fmt_out", pq.Format)
+ def test_load_interval_leap_fraction(self, conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out)
+ for y in (-5, -4, -3, 3, 4, 5):
+ for m in [-13, -12, -11, 11, 12, 13]:
+ cur.execute(
+ "select extract('epoch' from %s::interval)::float8",
+ [f"{y} year {m} month"],
+ )
+ got = cur.fetchone()[0]
+
+ m = y * 12 + m
+ if m >= 0:
+ y, m = divmod(m, 12)
+ d = 365 * y + 30 * m
+ h = 6 * y
+ want = dt.timedelta(days=d, hours=h)
+ else:
+ y, m = divmod(-m, 12)
+ d = 365 * y + 30 * m
+ h = 6 * y
+ want = -dt.timedelta(days=d, hours=h)
+
+ assert got == want.total_seconds()
@crdb_skip_datestyle
@pytest.mark.xfail # weird interval outputs
if s in ("min", "max"):
return getattr(dt.timedelta, s)
- suffixes = {"d": "days", "s": "seconds", "m": "microseconds"}
+ suffixes = {"d": "days", "s": "seconds", "h": "hours", "m": "microseconds"}
kwargs = {}
for part in s.split(","):
kwargs[suffixes[part[-1]]] = int(part[:-1])
+ if "hours" in kwargs:
+ kwargs["seconds"] = kwargs.get("seconds", 0) + kwargs.pop("hours") * 60 * 60
+
return dt.timedelta(**kwargs)