"""
Helper object to transform values between Python and PostgreSQL

Python implementation of the object. Use the `_transformer` module to import
the right implementation (Python or C). The public place where the object
is exported is `psycopg.adapt` (which we may not use to avoid circular
dependencies problems).
"""
# Copyright (C) 2020 The Psycopg Team
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Sequence, DefaultDict, TYPE_CHECKING
|
|
from collections import defaultdict
|
|
|
|
from . import pq
|
|
from . import abc
|
|
from . import errors as e
|
|
from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, NoneType
|
|
from .rows import Row, RowMaker
|
|
from ._oids import INVALID_OID, TEXT_OID
|
|
from ._compat import TypeAlias
|
|
from ._encodings import conn_encoding
|
|
|
|
if TYPE_CHECKING:
|
|
from .abc import DumperKey # noqa: F401
|
|
from .adapt import AdaptersMap
|
|
from .pq.abc import PGresult
|
|
from ._connection_base import BaseConnection
|
|
|
|
DumperCache: TypeAlias = "dict[DumperKey, abc.Dumper]"
|
|
OidDumperCache: TypeAlias = "dict[int, abc.Dumper]"
|
|
LoaderCache: TypeAlias = "dict[int, abc.Loader]"
|
|
|
|
TEXT = pq.Format.TEXT
|
|
PY_TEXT = PyFormat.TEXT
|
|
|
|
|
|
class Transformer(AdaptContext):
|
|
"""
|
|
An object that can adapt efficiently between Python and PostgreSQL.
|
|
|
|
The life cycle of the object is the query, so it is assumed that attributes
|
|
such as the server version or the connection encoding will not change. The
|
|
object have its state so adapting several values of the same type can be
|
|
optimised.
|
|
|
|
"""
|
|
|
|
__module__ = "psycopg.adapt"
|
|
|
|
__slots__ = """
|
|
types formats
|
|
_conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
|
|
_oid_dumpers _oid_types _row_dumpers _row_loaders
|
|
""".split()
|
|
|
|
types: tuple[int, ...] | None
|
|
formats: list[pq.Format] | None
|
|
|
|
_adapters: AdaptersMap
|
|
_pgresult: PGresult | None
|
|
_none_oid: int
|
|
|
|
def __init__(self, context: AdaptContext | None = None):
|
|
self._pgresult = self.types = self.formats = None
|
|
|
|
# WARNING: don't store context, or you'll create a loop with the Cursor
|
|
if context:
|
|
self._adapters = context.adapters
|
|
self._conn = context.connection
|
|
else:
|
|
from . import postgres
|
|
|
|
self._adapters = postgres.adapters
|
|
self._conn = None
|
|
|
|
# mapping fmt, class -> Dumper instance
|
|
self._dumpers: DefaultDict[PyFormat, DumperCache]
|
|
self._dumpers = defaultdict(dict)
|
|
|
|
# mapping fmt, oid -> Dumper instance
|
|
# Not often used, so create it only if needed.
|
|
self._oid_dumpers: tuple[OidDumperCache, OidDumperCache] | None
|
|
self._oid_dumpers = None
|
|
|
|
# mapping fmt, oid -> Loader instance
|
|
self._loaders: tuple[LoaderCache, LoaderCache] = ({}, {})
|
|
|
|
self._row_dumpers: list[abc.Dumper] | None = None
|
|
|
|
# sequence of load functions from value to python
|
|
# the length of the result columns
|
|
self._row_loaders: list[LoadFunc] = []
|
|
|
|
# mapping oid -> type sql representation
|
|
self._oid_types: dict[int, bytes] = {}
|
|
|
|
self._encoding = ""
|
|
|
|
@classmethod
|
|
def from_context(cls, context: AdaptContext | None) -> Transformer:
|
|
"""
|
|
Return a Transformer from an AdaptContext.
|
|
|
|
If the context is a Transformer instance, just return it.
|
|
"""
|
|
if isinstance(context, Transformer):
|
|
return context
|
|
else:
|
|
return cls(context)
|
|
|
|
@property
|
|
def connection(self) -> BaseConnection[Any] | None:
|
|
return self._conn
|
|
|
|
@property
|
|
def encoding(self) -> str:
|
|
if not self._encoding:
|
|
self._encoding = conn_encoding(self.connection)
|
|
return self._encoding
|
|
|
|
@property
|
|
def adapters(self) -> AdaptersMap:
|
|
return self._adapters
|
|
|
|
@property
|
|
def pgresult(self) -> PGresult | None:
|
|
return self._pgresult
|
|
|
|
def set_pgresult(
|
|
self,
|
|
result: PGresult | None,
|
|
*,
|
|
set_loaders: bool = True,
|
|
format: pq.Format | None = None,
|
|
) -> None:
|
|
self._pgresult = result
|
|
|
|
if not result:
|
|
self._nfields = self._ntuples = 0
|
|
if set_loaders:
|
|
self._row_loaders = []
|
|
return
|
|
|
|
self._ntuples = result.ntuples
|
|
nf = self._nfields = result.nfields
|
|
|
|
if not set_loaders:
|
|
return
|
|
|
|
if not nf:
|
|
self._row_loaders = []
|
|
return
|
|
|
|
fmt: pq.Format
|
|
fmt = result.fformat(0) if format is None else format # type: ignore
|
|
self._row_loaders = [
|
|
self.get_loader(result.ftype(i), fmt).load for i in range(nf)
|
|
]
|
|
|
|
def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
|
|
self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
|
|
self.types = tuple(types)
|
|
self.formats = [format] * len(types)
|
|
|
|
def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
|
|
self._row_loaders = [self.get_loader(oid, format).load for oid in types]
|
|
|
|
def dump_sequence(
|
|
self, params: Sequence[Any], formats: Sequence[PyFormat]
|
|
) -> Sequence[Buffer | None]:
|
|
nparams = len(params)
|
|
out: list[Buffer | None] = [None] * nparams
|
|
|
|
# If we have dumpers, it means set_dumper_types had been called, in
|
|
# which case self.types and self.formats are set to sequences of the
|
|
# right size.
|
|
if self._row_dumpers:
|
|
for i in range(nparams):
|
|
param = params[i]
|
|
if param is not None:
|
|
out[i] = self._row_dumpers[i].dump(param)
|
|
return out
|
|
|
|
types = [self._get_none_oid()] * nparams
|
|
pqformats = [TEXT] * nparams
|
|
|
|
for i in range(nparams):
|
|
param = params[i]
|
|
if param is None:
|
|
continue
|
|
dumper = self.get_dumper(param, formats[i])
|
|
out[i] = dumper.dump(param)
|
|
types[i] = dumper.oid
|
|
pqformats[i] = dumper.format
|
|
|
|
self.types = tuple(types)
|
|
self.formats = pqformats
|
|
|
|
return out
|
|
|
|
def as_literal(self, obj: Any) -> bytes:
|
|
dumper = self.get_dumper(obj, PY_TEXT)
|
|
rv = dumper.quote(obj)
|
|
# If the result is quoted, and the oid not unknown or text,
|
|
# add an explicit type cast.
|
|
# Check the last char because the first one might be 'E'.
|
|
oid = dumper.oid
|
|
if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
|
|
try:
|
|
type_sql = self._oid_types[oid]
|
|
except KeyError:
|
|
ti = self.adapters.types.get(oid)
|
|
if ti:
|
|
if oid < 8192:
|
|
# builtin: prefer "timestamptz" to "timestamp with time zone"
|
|
type_sql = ti.name.encode(self.encoding)
|
|
else:
|
|
type_sql = ti.regtype.encode(self.encoding)
|
|
if oid == ti.array_oid:
|
|
type_sql += b"[]"
|
|
else:
|
|
type_sql = b""
|
|
self._oid_types[oid] = type_sql
|
|
|
|
if type_sql:
|
|
rv = b"%s::%s" % (rv, type_sql)
|
|
|
|
if not isinstance(rv, bytes):
|
|
rv = bytes(rv)
|
|
return rv
|
|
|
|
def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper:
|
|
"""
|
|
Return a Dumper instance to dump `!obj`.
|
|
"""
|
|
# Normally, the type of the object dictates how to dump it
|
|
key = type(obj)
|
|
|
|
# Reuse an existing Dumper class for objects of the same type
|
|
cache = self._dumpers[format]
|
|
try:
|
|
dumper = cache[key]
|
|
except KeyError:
|
|
# If it's the first time we see this type, look for a dumper
|
|
# configured for it.
|
|
try:
|
|
dcls = self.adapters.get_dumper(key, format)
|
|
except e.ProgrammingError as ex:
|
|
raise ex from None
|
|
else:
|
|
cache[key] = dumper = dcls(key, self)
|
|
|
|
# Check if the dumper requires an upgrade to handle this specific value
|
|
key1 = dumper.get_key(obj, format)
|
|
if key1 is key:
|
|
return dumper
|
|
|
|
# If it does, ask the dumper to create its own upgraded version
|
|
try:
|
|
return cache[key1]
|
|
except KeyError:
|
|
dumper = cache[key1] = dumper.upgrade(obj, format)
|
|
return dumper
|
|
|
|
def _get_none_oid(self) -> int:
|
|
try:
|
|
return self._none_oid
|
|
except AttributeError:
|
|
pass
|
|
|
|
try:
|
|
rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
|
|
except KeyError:
|
|
raise e.InterfaceError("None dumper not found")
|
|
|
|
return rv
|
|
|
|
def get_dumper_by_oid(self, oid: int, format: pq.Format) -> abc.Dumper:
|
|
"""
|
|
Return a Dumper to dump an object to the type with given oid.
|
|
"""
|
|
if not self._oid_dumpers:
|
|
self._oid_dumpers = ({}, {})
|
|
|
|
# Reuse an existing Dumper class for objects of the same type
|
|
cache = self._oid_dumpers[format]
|
|
try:
|
|
return cache[oid]
|
|
except KeyError:
|
|
# If it's the first time we see this type, look for a dumper
|
|
# configured for it.
|
|
dcls = self.adapters.get_dumper_by_oid(oid, format)
|
|
cache[oid] = dumper = dcls(NoneType, self)
|
|
|
|
return dumper
|
|
|
|
def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]:
|
|
res = self._pgresult
|
|
if not res:
|
|
raise e.InterfaceError("result not set")
|
|
|
|
if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
|
|
raise e.InterfaceError(
|
|
f"rows must be included between 0 and {self._ntuples}"
|
|
)
|
|
|
|
records = []
|
|
for row in range(row0, row1):
|
|
record: list[Any] = [None] * self._nfields
|
|
for col in range(self._nfields):
|
|
val = res.get_value(row, col)
|
|
if val is not None:
|
|
record[col] = self._row_loaders[col](val)
|
|
records.append(make_row(record))
|
|
|
|
return records
|
|
|
|
def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None:
|
|
res = self._pgresult
|
|
if not res:
|
|
return None
|
|
|
|
if not 0 <= row < self._ntuples:
|
|
return None
|
|
|
|
record: list[Any] = [None] * self._nfields
|
|
for col in range(self._nfields):
|
|
val = res.get_value(row, col)
|
|
if val is not None:
|
|
record[col] = self._row_loaders[col](val)
|
|
|
|
return make_row(record)
|
|
|
|
def load_sequence(self, record: Sequence[Buffer | None]) -> tuple[Any, ...]:
|
|
if len(self._row_loaders) != len(record):
|
|
raise e.ProgrammingError(
|
|
f"cannot load sequence of {len(record)} items:"
|
|
f" {len(self._row_loaders)} loaders registered"
|
|
)
|
|
|
|
return tuple(
|
|
(self._row_loaders[i](val) if val is not None else None)
|
|
for i, val in enumerate(record)
|
|
)
|
|
|
|
def get_loader(self, oid: int, format: pq.Format) -> abc.Loader:
|
|
try:
|
|
return self._loaders[format][oid]
|
|
except KeyError:
|
|
pass
|
|
|
|
loader_cls = self._adapters.get_loader(oid, format)
|
|
if not loader_cls:
|
|
loader_cls = self._adapters.get_loader(INVALID_OID, format)
|
|
if not loader_cls:
|
|
raise e.InterfaceError("unknown oid loader not found")
|
|
loader = self._loaders[format][oid] = loader_cls(oid, self)
|
|
return loader
|