feat: add comprehensive GitHub workflow and development tools

This commit is contained in:
Stiftung Development
2025-09-06 18:31:54 +02:00
commit ab23d7187e
10224 changed files with 2075210 additions and 0 deletions

View File

@@ -0,0 +1,121 @@
"""
psycopg -- PostgreSQL database adapter for Python
"""
# Copyright (C) 2020 The Psycopg Team
import logging
from . import pq # noqa: F401 import early to stabilize side effects
from . import types
from . import postgres
from ._tpc import Xid
from .copy import Copy, AsyncCopy
from ._enums import IsolationLevel
from .cursor import Cursor
from .errors import Warning, Error, InterfaceError, DatabaseError
from .errors import DataError, OperationalError, IntegrityError
from .errors import InternalError, ProgrammingError, NotSupportedError
from ._column import Column
from ._pipeline import Pipeline, AsyncPipeline
from .connection import Connection
from .transaction import Rollback, Transaction, AsyncTransaction
from .cursor_async import AsyncCursor
from ._capabilities import Capabilities, capabilities
from .server_cursor import AsyncServerCursor, ServerCursor
from .client_cursor import AsyncClientCursor, ClientCursor
from .raw_cursor import AsyncRawCursor, RawCursor
from .raw_cursor import AsyncRawServerCursor, RawServerCursor
from ._connection_base import BaseConnection, Notify
from ._connection_info import ConnectionInfo
from .connection_async import AsyncConnection
from . import dbapi20
from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
from .dbapi20 import Timestamp, TimestampFromTicks
from .version import __version__ as __version__ # noqa: F401
# Set the logger to a quiet default, can be enabled if needed
logger = logging.getLogger("psycopg")
if logger.level == logging.NOTSET:
    logger.setLevel(logging.WARNING)

# DBAPI compliance: module-level attributes required by PEP 249
connect = Connection.connect
apilevel = "2.0"
threadsafety = 2
paramstyle = "pyformat"

# register default adapters for PostgreSQL
adapters = postgres.adapters  # exposed by the package
postgres.register_default_types(adapters.types)
postgres.register_default_adapters(adapters)

# After the default ones, because these can deal with the bytea oid better
dbapi20.register_dbapi20_adapters(adapters)

# Must come after all the types have been registered
types.array.register_all_arrays(adapters)

# Note: defining the exported methods helps both Sphinx in documenting that
# this is the canonical place to obtain them and should be used by MyPy too,
# so that function signatures are consistent with the documentation.
__all__ = [
    "AsyncClientCursor",
    "AsyncConnection",
    "AsyncCopy",
    "AsyncCursor",
    "AsyncPipeline",
    "AsyncRawCursor",
    "AsyncRawServerCursor",
    "AsyncServerCursor",
    "AsyncTransaction",
    "BaseConnection",
    "Capabilities",
    "capabilities",
    "ClientCursor",
    "Column",
    "Connection",
    "ConnectionInfo",
    "Copy",
    "Cursor",
    "IsolationLevel",
    "Notify",
    "Pipeline",
    "RawCursor",
    "RawServerCursor",
    "Rollback",
    "ServerCursor",
    "Transaction",
    "Xid",
    # DBAPI exports
    "connect",
    "apilevel",
    "threadsafety",
    "paramstyle",
    "Warning",
    "Error",
    "InterfaceError",
    "DatabaseError",
    "DataError",
    "OperationalError",
    "IntegrityError",
    "InternalError",
    "ProgrammingError",
    "NotSupportedError",
    # DBAPI type constructors and singletons
    "Binary",
    "Date",
    "DateFromTicks",
    "Time",
    "TimeFromTicks",
    "Timestamp",
    "TimestampFromTicks",
    "BINARY",
    "DATETIME",
    "NUMBER",
    "ROWID",
    "STRING",
]

View File

@@ -0,0 +1,106 @@
"""
Utilities to ease the differences between async and sync code.
These object offer a similar interface between sync and async versions; the
script async_to_sync.py will replace the async names with the sync names
when generating the sync version.
"""
# Copyright (C) 2023 The Psycopg Team
from __future__ import annotations
import queue
import asyncio
import threading
from typing import Any, Callable, Coroutine, TYPE_CHECKING
from ._compat import TypeAlias, TypeVar
# Aliases replaced by async_to_sync.py: the sync "worker" is a thread, the
# async one a Task.
Worker: TypeAlias = threading.Thread
AWorker: TypeAlias = "asyncio.Task[None]"

T = TypeVar("T")

# Hack required on Python 3.8 because subclassing Queue[T] fails at runtime.
# https://stackoverflow.com/questions/45414066/mypy-how-to-define-a-generic-subclass
if TYPE_CHECKING:
    _GQueue: TypeAlias = queue.Queue
    _AGQueue: TypeAlias = asyncio.Queue
else:

    class FakeGenericMeta(type):
        # Make Class[T] a no-op at runtime, so subscripting is allowed.
        def __getitem__(self, item):
            return self

    class _GQueue(queue.Queue, metaclass=FakeGenericMeta):
        pass

    class _AGQueue(asyncio.Queue, metaclass=FakeGenericMeta):
        pass
class Queue(_GQueue[T]):
    """
    A Queue subclass with an interruptible get() method.
    """

    def get(self, block: bool = True, timeout: float | None = None) -> T:
        # Never wait without a timeout: a bounded wait (here one day, i.e.
        # effectively "forever") keeps the blocking get() interruptible.
        wait = 24.0 * 60.0 * 60.0 if timeout is None else timeout
        return super().get(block=block, timeout=wait)
class AQueue(_AGQueue[T]):
    # No interruptible get() is needed here: asyncio task cancellation
    # already interrupts an awaited Queue.get().
    pass
def aspawn(
    f: Callable[..., Coroutine[Any, Any, None]],
    args: tuple[Any, ...] = (),
    name: str | None = None,
) -> asyncio.Task[None]:
    """
    Schedule `!f(*args)` to run as an asyncio Task and return it.

    Equivalent to asyncio.create_task.
    """
    coro = f(*args)
    return asyncio.create_task(coro, name=name)
def spawn(
    f: Callable[..., Any],
    args: tuple[Any, ...] = (),
    name: str | None = None,
) -> threading.Thread:
    """
    Run `!f(*args)` in a daemon thread and return the started thread.

    Equivalent to creating and running a daemon thread.
    """
    worker = threading.Thread(target=f, args=args, name=name, daemon=True)
    worker.start()
    return worker
async def agather(*tasks: asyncio.Task[Any], timeout: float | None = None) -> None:
    """
    Wait for all `!tasks` to complete.

    Equivalent to asyncio.gather or Thread.join(). As with Thread.join(), a
    timeout expiry is not an error: the tasks are simply left running
    (shield() prevents their cancellation).
    """
    gathered = asyncio.gather(*tasks)
    if timeout is None:
        await gathered
        return
    try:
        await asyncio.wait_for(asyncio.shield(gathered), timeout=timeout)
    except asyncio.TimeoutError:
        # Timed out: mirror Thread.join(timeout) by returning silently.
        pass
def gather(*tasks: threading.Thread, timeout: float | None = None) -> None:
    """
    Wait for all `!tasks` threads to complete.

    Equivalent to asyncio.gather or Thread.join(). Note that `!timeout`
    applies to each join, not to the whole wait.
    """
    for worker in tasks:
        if worker.is_alive():
            worker.join(timeout)

View File

@@ -0,0 +1,296 @@
"""
Mapping from types/oids to Dumpers/Loaders
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from typing import Any, cast, TYPE_CHECKING
from . import pq
from . import errors as e
from .abc import Dumper, Loader
from ._enums import PyFormat as PyFormat
from ._compat import TypeVar
from ._cmodule import _psycopg
from ._typeinfo import TypesRegistry
if TYPE_CHECKING:
from ._connection_base import BaseConnection
RV = TypeVar("RV")
class AdaptersMap:
    r"""
    Establish how types should be converted between Python and PostgreSQL in
    an `~psycopg.abc.AdaptContext`.

    `!AdaptersMap` maps Python types to `~psycopg.adapt.Dumper` classes to
    define how Python types are converted to PostgreSQL, and maps OIDs to
    `~psycopg.adapt.Loader` classes to establish how query results are
    converted to Python.

    Every `!AdaptContext` object has an underlying `!AdaptersMap` defining how
    types are converted in that context, exposed as the
    `~psycopg.abc.AdaptContext.adapters` attribute: changing such map allows
    to customise adaptation in a context without changing separated contexts.

    When a context is created from another context (for instance when a
    `~psycopg.Cursor` is created from a `~psycopg.Connection`), the parent's
    `!adapters` are used as template for the child's `!adapters`, so that every
    cursor created from the same connection use the connection's types
    configuration, but separate connections have independent mappings.

    Once created, `!AdaptersMap` are independent. This means that objects
    already created are not affected if a wider scope (e.g. the global one) is
    changed.

    The connections adapters are initialised using a global `!AdaptersMap`
    template, exposed as `psycopg.adapters`: changing such mapping allows to
    customise the type mapping for every connections created afterwards.

    The object can start empty or copy from another object of the same class.
    Copies are copy-on-write: if the maps are updated make a copy. This way
    extending e.g. global map by a connection or a connection map from a cursor
    is cheap: a copy is only made on customisation.
    """

    __module__ = "psycopg.adapt"

    types: TypesRegistry

    # Maps a Python format -> {python type (or dotted name) -> dumper class}
    _dumpers: dict[PyFormat, dict[type | str, type[Dumper]]]
    # Lists indexed by pq.Format (TEXT=0, BINARY=1): {oid -> dumper/loader}
    _dumpers_by_oid: list[dict[int, type[Dumper]]]
    _loaders: list[dict[int, type[Loader]]]

    # Record if a dumper or loader has an optimised version.
    _optimised: dict[type, type] = {}

    def __init__(
        self,
        template: AdaptersMap | None = None,
        types: TypesRegistry | None = None,
    ):
        if template:
            # Copy-on-write: share the template's maps and mark them as not
            # owned on both sides, so whoever writes first copies first.
            self._dumpers = template._dumpers.copy()
            self._own_dumpers = _dumpers_shared.copy()
            template._own_dumpers = _dumpers_shared.copy()

            self._dumpers_by_oid = template._dumpers_by_oid[:]
            self._own_dumpers_by_oid = [False, False]
            template._own_dumpers_by_oid = [False, False]

            self._loaders = template._loaders[:]
            self._own_loaders = [False, False]
            template._own_loaders = [False, False]

            self.types = TypesRegistry(template.types)

        else:
            # Fresh, fully-owned empty maps.
            self._dumpers = {fmt: {} for fmt in PyFormat}
            self._own_dumpers = _dumpers_owned.copy()

            self._dumpers_by_oid = [{}, {}]
            self._own_dumpers_by_oid = [True, True]

            self._loaders = [{}, {}]
            self._own_loaders = [True, True]

            self.types = types or TypesRegistry()

    # implement the AdaptContext protocol too
    @property
    def adapters(self) -> AdaptersMap:
        return self

    @property
    def connection(self) -> BaseConnection[Any] | None:
        # An AdaptersMap is not attached to a connection.
        return None

    def register_dumper(self, cls: type | str | None, dumper: type[Dumper]) -> None:
        """
        Configure the context to use `!dumper` to convert objects of type `!cls`.

        If two dumpers with different `~Dumper.format` are registered for the
        same type, the last one registered will be chosen when the query
        doesn't specify a format (i.e. when the value is used with a ``%s``
        "`~PyFormat.AUTO`" placeholder).

        :param cls: The type to manage.
        :param dumper: The dumper to register for `!cls`.

        If `!cls` is specified as string it will be lazy-loaded, so that it
        will be possible to register it without importing it before. In this
        case it should be the fully qualified name of the object (e.g.
        ``"uuid.UUID"``).

        If `!cls` is None, only use the dumper when looking up using
        `get_dumper_by_oid()`, which happens when we know the Postgres type to
        adapt to, but not the Python type that will be adapted (e.g. in COPY
        after using `~psycopg.Copy.set_types()`).
        """
        if not (cls is None or isinstance(cls, (str, type))):
            raise TypeError(
                f"dumpers should be registered on classes, got {cls} instead"
            )

        if _psycopg:
            dumper = self._get_optimised(dumper)

        # Register the dumper both as its format and as auto
        # so that the last dumper registered is used in auto (%s) format
        if cls:
            for fmt in (PyFormat.from_pq(dumper.format), PyFormat.AUTO):
                if not self._own_dumpers[fmt]:
                    # Copy-on-write: take ownership of the map before writing.
                    self._dumpers[fmt] = self._dumpers[fmt].copy()
                    self._own_dumpers[fmt] = True

                self._dumpers[fmt][cls] = dumper

        # Register the dumper by oid, if the oid of the dumper is fixed
        if dumper.oid:
            if not self._own_dumpers_by_oid[dumper.format]:
                # Copy-on-write: take ownership of the map before writing.
                self._dumpers_by_oid[dumper.format] = self._dumpers_by_oid[
                    dumper.format
                ].copy()
                self._own_dumpers_by_oid[dumper.format] = True

            self._dumpers_by_oid[dumper.format][dumper.oid] = dumper

    def register_loader(self, oid: int | str, loader: type[Loader]) -> None:
        """
        Configure the context to use `!loader` to convert data of oid `!oid`.

        :param oid: The PostgreSQL OID or type name to manage.
        :param loader: The loader to register for `!oid`.

        If `oid` is specified as string, it refers to a type name, which is
        looked up in the `types` registry.
        """
        if isinstance(oid, str):
            oid = self.types[oid].oid
        if not isinstance(oid, int):
            raise TypeError(f"loaders should be registered on oid, got {oid} instead")

        if _psycopg:
            loader = self._get_optimised(loader)

        fmt = loader.format
        if not self._own_loaders[fmt]:
            # Copy-on-write: take ownership of the map before writing.
            self._loaders[fmt] = self._loaders[fmt].copy()
            self._own_loaders[fmt] = True

        self._loaders[fmt][oid] = loader

    def get_dumper(self, cls: type, format: PyFormat) -> type[Dumper]:
        """
        Return the dumper class for the given type and format.

        Raise `~psycopg.ProgrammingError` if a class is not available.

        :param cls: The class to adapt.
        :param format: The format to dump to. If `~psycopg.adapt.PyFormat.AUTO`,
            use the last one of the dumpers registered on `!cls`.
        """
        try:
            # Fast path: the class has a known dumper.
            return self._dumpers[format][cls]
        except KeyError:
            if format not in self._dumpers:
                raise ValueError(f"bad dumper format: {format}")

            # If the KeyError was caused by cls missing from dmap, let's
            # look for different cases.
            dmap = self._dumpers[format]

        # Look for the right class, including looking at superclasses
        for scls in cls.__mro__:
            if scls in dmap:
                return dmap[scls]

            # If the adapter is not found, look for its name as a string
            # (lazy registration, see register_dumper()).
            fqn = scls.__module__ + "." + scls.__qualname__
            if fqn in dmap:
                # Replace the class name with the class itself
                d = dmap[scls] = dmap.pop(fqn)
                return d

        format = PyFormat(format)
        raise e.ProgrammingError(
            f"cannot adapt type {cls.__name__!r} using placeholder '%{format.value}'"
            f" (format: {format.name})"
        )

    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> type[Dumper]:
        """
        Return the dumper class for the given oid and format.

        Raise `~psycopg.ProgrammingError` if a class is not available.

        :param oid: The oid of the type to dump to.
        :param format: The format to dump to.
        """
        try:
            dmap = self._dumpers_by_oid[format]
        except (IndexError, KeyError):
            # _dumpers_by_oid is a list indexed by pq.Format, so an invalid
            # format raises IndexError, not KeyError: catch both to report
            # the intended error message.
            raise ValueError(f"bad dumper format: {format}")

        try:
            return dmap[oid]
        except KeyError:
            info = self.types.get(oid)
            if info:
                msg = (
                    f"cannot find a dumper for type {info.name} (oid {oid})"
                    f" format {pq.Format(format).name}"
                )
            else:
                msg = (
                    f"cannot find a dumper for unknown type with oid {oid}"
                    f" format {pq.Format(format).name}"
                )
            raise e.ProgrammingError(msg)

    def get_loader(self, oid: int, format: pq.Format) -> type[Loader] | None:
        """
        Return the loader class for the given oid and format.

        Return `!None` if not found.

        :param oid: The oid of the type to load.
        :param format: The format to load from.
        """
        return self._loaders[format].get(oid)

    @classmethod
    def _get_optimised(self, cls: type[RV]) -> type[RV]:
        """Return the optimised version of a Dumper or Loader class.

        Return the input class itself if there is no optimised version.

        (The class argument is named `!self` because the name `!cls` is
        already taken by the class to optimise.)
        """
        try:
            return self._optimised[cls]
        except KeyError:
            pass

        # Check if the class comes from psycopg.types and there is a class
        # with the same name in psycopg_c._psycopg.
        from psycopg import types

        if cls.__module__.startswith(types.__name__):
            new = cast("type[RV]", getattr(_psycopg, cls.__name__, None))
            if new:
                self._optimised[cls] = new
                return new

        self._optimised[cls] = cls
        return cls
# Micro-optimization: copying these objects is faster than creating new dicts
_dumpers_owned = {fmt: True for fmt in PyFormat}
_dumpers_shared = {fmt: False for fmt in PyFormat}

View File

@@ -0,0 +1,132 @@
"""
psycopg capabilities objects
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
from . import pq
from . import _cmodule
from .errors import NotSupportedError
class Capabilities:
    """
    An object to check if a feature is supported by the libpq available on the client.
    """

    def __init__(self) -> None:
        # Map feature description -> unsupported-reason message ("" if the
        # feature is supported). Filled lazily by _has_feature().
        self._cache: dict[str, str] = {}

    def has_encrypt_password(self, check: bool = False) -> bool:
        """Check if the `PGconn.encrypt_password()` method is implemented.

        The feature requires libpq 10.0 and greater.
        """
        return self._has_feature("pq.PGconn.encrypt_password()", 100000, check=check)

    def has_hostaddr(self, check: bool = False) -> bool:
        """Check if the `ConnectionInfo.hostaddr` attribute is implemented.

        The feature requires libpq 12.0 and greater.
        """
        return self._has_feature("Connection.info.hostaddr", 120000, check=check)

    def has_pipeline(self, check: bool = False) -> bool:
        """Check if the :ref:`pipeline mode <pipeline-mode>` is supported.

        The feature requires libpq 14.0 and greater.
        """
        return self._has_feature("Connection.pipeline()", 140000, check=check)

    def has_set_trace_flags(self, check: bool = False) -> bool:
        """Check if the `pq.PGconn.set_trace_flags()` method is implemented.

        The feature requires libpq 14.0 and greater.
        """
        return self._has_feature("PGconn.set_trace_flags()", 140000, check=check)

    def has_cancel_safe(self, check: bool = False) -> bool:
        """Check if the `Connection.cancel_safe()` method is implemented.

        The feature requires libpq 17.0 and greater.
        """
        return self._has_feature("Connection.cancel_safe()", 170000, check=check)

    def has_stream_chunked(self, check: bool = False) -> bool:
        """Check if `Cursor.stream()` can handle a `size` parameter value
        greater than 1 to retrieve results by chunks.

        The feature requires libpq 17.0 and greater.
        """
        return self._has_feature(
            "Cursor.stream() with 'size' parameter greater than 1", 170000, check=check
        )

    def has_send_close_prepared(self, check: bool = False) -> bool:
        """Check if the `pq.PGconn.send_close_prepared()` method is implemented.

        The feature requires libpq 17.0 and greater.
        """
        return self._has_feature("PGconn.send_close_prepared()", 170000, check=check)

    def _has_feature(self, feature: str, want_version: int, check: bool) -> bool:
        """
        Check if a feature, requiring libpq version `!want_version`, is supported.

        If `!check` is true and the feature is not supported, raise a
        `NotSupportedError` with a message explaining why; otherwise just
        return `!False`.
        """
        if feature in self._cache:
            msg = self._cache[feature]
        else:
            # Computed once per feature: the libpq version cannot change at
            # runtime, so the message can be cached.
            msg = self._get_unsupported_message(feature, want_version)
            self._cache[feature] = msg

        if not msg:
            return True
        elif check:
            raise NotSupportedError(msg)
        else:
            return False

    def _get_unsupported_message(self, feature: str, want_version: int) -> str:
        """
        Return a descriptive message to explain why a feature is unsupported.

        Return an empty string if the feature is supported.
        """
        # Check the runtime library first, then the version the wrapper was
        # built against: either may be too old independently.
        if pq.version() < want_version:
            return (
                f"the feature '{feature}' is not available:"
                f" the client libpq version (imported from {self._libpq_source()})"
                f" is {pq.version_pretty(pq.version())}; the feature"
                f" requires libpq version {pq.version_pretty(want_version)}"
                " or newer"
            )
        elif pq.__build_version__ < want_version:
            return (
                f"the feature '{feature}' is not available:"
                f" you are using a psycopg[{pq.__impl__}] libpq wrapper built"
                f" with libpq version {pq.version_pretty(pq.__build_version__)};"
                " the feature requires libpq version"
                f" {pq.version_pretty(want_version)} or newer"
            )
        else:
            return ""

    def _libpq_source(self) -> str:
        """Return a string reporting where the libpq comes from."""
        if pq.__impl__ == "binary":
            version: str = _cmodule.__version__ or "unknown"
            return f"the psycopg[binary] package version {version}"
        else:
            return "system libraries"
# The singleton object that will be exposed by the module (and re-exported
# as psycopg.capabilities).
capabilities = Capabilities()

View File

@@ -0,0 +1,24 @@
"""
Simplify access to the _psycopg module
"""
# Copyright (C) 2021 The Psycopg Team
from __future__ import annotations
from . import pq
# Version of the C/binary package providing the optimised module, if any.
__version__: str | None = None

# Note: "c" must be the first attempt so that mypy associates the variable
# with the right module interface. It will not result Optional, but hey.
if pq.__impl__ == "c":
    from psycopg_c import _psycopg as _psycopg
    from psycopg_c import __version__ as __version__  # noqa: F401
elif pq.__impl__ == "binary":
    from psycopg_binary import _psycopg as _psycopg  # type: ignore
    from psycopg_binary import __version__ as __version__  # type: ignore # noqa: F401
elif pq.__impl__ == "python":
    # Pure Python implementation: no optimised C module available.
    _psycopg = None  # type: ignore
else:
    raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")

View File

@@ -0,0 +1,104 @@
"""
The Column object in Cursor.description
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from typing import Any, Sequence, TYPE_CHECKING
from operator import attrgetter
if TYPE_CHECKING:
from ._cursor_base import BaseCursor
class Column(Sequence[Any]):
    """
    An item of `!Cursor.description`, describing one column of a query result.

    Behaves as the 7-item sequence required by the DBAPI, and also exposes
    each item as a named property.
    """

    __module__ = "psycopg"

    def __init__(self, cursor: BaseCursor[Any, Any], index: int):
        res = cursor.pgresult
        assert res

        fname = res.fname(index)
        if fname:
            self._name = fname.decode(cursor._encoding)
        else:
            # COPY_OUT results have columns but no name
            self._name = f"column_{index + 1}"

        self._ftype = res.ftype(index)
        self._type = cursor.adapters.types.get(self._ftype)
        self._fmod = res.fmod(index)
        self._fsize = res.fsize(index)

    # Getters for the 7 DBAPI description items, in order; used by
    # __getitem__ to expose the object as a sequence.
    _attrs = tuple(
        attrgetter(attr)
        for attr in """
            name type_code display_size internal_size precision scale null_ok
            """.split()
    )

    def __repr__(self) -> str:
        return (
            f"<Column {self.name!r},"
            f" type: {self.type_display} (oid: {self.type_code})>"
        )

    def __len__(self) -> int:
        # The DBAPI description is a 7-item sequence.
        return 7

    @property
    def type_display(self) -> str:
        """A pretty representation of the column type.

        It is composed by the type name, followed by eventual modifiers and
        brackets to signify arrays, e.g. :sql:`text`, :sql:`varchar(42)`,
        :sql:`date[]`.
        """
        if not self._type:
            # Unknown oid: fall back on the bare number.
            return str(self.type_code)

        return self._type.get_type_display(oid=self.type_code, fmod=self._fmod)

    def __getitem__(self, index: Any) -> Any:
        if isinstance(index, slice):
            return tuple(getter(self) for getter in self._attrs[index])
        else:
            return self._attrs[index](self)

    @property
    def name(self) -> str:
        """The name of the column."""
        return self._name

    @property
    def type_code(self) -> int:
        """The numeric OID of the column."""
        return self._ftype

    @property
    def display_size(self) -> int | None:
        """The field size, for string types such as :sql:`varchar(n)`."""
        return self._type.get_display_size(self._fmod) if self._type else None

    @property
    def internal_size(self) -> int | None:
        """The internal field size for fixed-size types, None otherwise."""
        fsize = self._fsize
        # A negative fsize means a variable-size type.
        return fsize if fsize >= 0 else None

    @property
    def precision(self) -> int | None:
        """The number of digits for fixed precision types."""
        return self._type.get_precision(self._fmod) if self._type else None

    @property
    def scale(self) -> int | None:
        """The number of digits after the decimal point if available."""
        return self._type.get_scale(self._fmod) if self._type else None

    @property
    def null_ok(self) -> bool | None:
        """Always `!None`"""
        return None

View File

@@ -0,0 +1,59 @@
"""
compatibility functions for different Python versions
"""
# Copyright (C) 2021 The Psycopg Team
import sys
from functools import partial
from typing import Any
if sys.version_info >= (3, 9):
    from asyncio import to_thread
    from zoneinfo import ZoneInfo
    from functools import cache
    from collections import Counter, deque as Deque
    from collections.abc import Callable
else:
    import asyncio
    from typing import Callable, Counter, Deque, TypeVar
    from functools import lru_cache
    from backports.zoneinfo import ZoneInfo

    # functools.cache was added in 3.9; an unbounded lru_cache is equivalent.
    cache = lru_cache(maxsize=None)

    R = TypeVar("R")

    # Backport of asyncio.to_thread (added in 3.9).
    async def to_thread(func: Callable[..., R], /, *args: Any, **kwargs: Any) -> R:
        loop = asyncio.get_running_loop()
        func_call = partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, func_call)

if sys.version_info >= (3, 10):
    from typing import TypeGuard, TypeAlias
else:
    from typing_extensions import TypeGuard, TypeAlias

if sys.version_info >= (3, 11):
    from typing import LiteralString, Self
else:
    from typing_extensions import LiteralString, Self

# NOTE(review): TypeVar is re-exported conditionally even though typing has
# always provided it — presumably to get typing_extensions' newer TypeVar
# features on < 3.13; confirm against the project's typing requirements.
if sys.version_info >= (3, 13):
    from typing import TypeVar
else:
    from typing_extensions import TypeVar

__all__ = [
    "Counter",
    "Deque",
    "LiteralString",
    "Self",
    "TypeAlias",
    "TypeGuard",
    "TypeVar",
    "ZoneInfo",
    "cache",
    "to_thread",
]

View File

@@ -0,0 +1,691 @@
"""
psycopg connection objects
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import sys
import logging
from typing import Callable, Generic, NamedTuple, TYPE_CHECKING
from weakref import ref, ReferenceType
from warnings import warn
from functools import partial
from . import pq
from . import errors as e
from . import postgres
from . import generators
from .abc import PQGen, PQGenConn, Query
from .sql import Composable, SQL
from ._tpc import Xid
from .rows import Row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import LiteralString, Self, TypeAlias, TypeVar
from .pq.misc import connection_summary
from ._pipeline import BasePipeline
from ._preparing import PrepareManager
from ._capabilities import capabilities
from ._connection_info import ConnectionInfo
if TYPE_CHECKING:
from .pq.abc import PGconn, PGresult
from psycopg_pool.base import BasePool
# Row Type variable for Cursor (when it needs to be distinguished from the
# connection's one)
CursorRow = TypeVar("CursorRow")

# Shortcuts for libpq enum values used repeatedly in this module.
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY

OK = pq.ConnStatus.OK
BAD = pq.ConnStatus.BAD

COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR

IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS

# Computed once at import time: the libpq version cannot change at runtime.
_HAS_SEND_CLOSE = capabilities.has_send_close_prepared()

logger = logging.getLogger("psycopg")
class Notify(NamedTuple):
    """An asynchronous notification received from the database."""

    channel: str
    """The name of the channel on which the notification was received."""

    payload: str
    """The message attached to the notification."""

    pid: int
    """The PID of the backend process which sent the notification."""


# Present the class as part of the public psycopg namespace.
Notify.__module__ = "psycopg"

# Signatures of the callables accepted by add_notice_handler() and
# add_notify_handler() below.
NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None]
NotifyHandler: TypeAlias = Callable[[Notify], None]
class BaseConnection(Generic[Row]):
    """
    Base class for different types of connections.

    Share common functionalities such as access to the wrapped PGconn, but
    allow different interfaces (sync/async).
    """

    # DBAPI2 exposed exceptions: PEP 249 suggests making them available as
    # connection attributes too.
    Warning = e.Warning
    Error = e.Error
    InterfaceError = e.InterfaceError
    DatabaseError = e.DatabaseError
    DataError = e.DataError
    OperationalError = e.OperationalError
    IntegrityError = e.IntegrityError
    InternalError = e.InternalError
    ProgrammingError = e.ProgrammingError
    NotSupportedError = e.NotSupportedError
    def __init__(self, pgconn: PGconn):
        self.pgconn = pgconn
        self._autocommit = False

        # None, but set to a copy of the global adapters map as soon as requested.
        self._adapters: AdaptersMap | None = None

        self._notice_handlers: list[NoticeHandler] = []
        self._notify_handlers: list[NotifyHandler] = []

        # Number of transaction blocks currently entered
        self._num_transactions = 0

        self._closed = False  # closed by an explicit close()
        self._prepared: PrepareManager = PrepareManager()
        self._tpc: tuple[Xid, bool] | None = None  # xid, prepared

        # Pass a weak reference to the libpq callbacks so that the partials
        # stored on pgconn don't keep the connection alive.
        wself = ref(self)
        pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
        pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)

        # Attribute is only set if the connection is from a pool so we can tell
        # apart a connection in the pool too (when _pool = None)
        self._pool: BasePool | None

        self._pipeline: BasePipeline | None = None

        # Time after which the connection should be closed
        self._expire_at: float

        # Session characteristics (applied to new transactions); None means
        # "use the server default".
        self._isolation_level: IsolationLevel | None = None
        self._read_only: bool | None = None
        self._deferrable: bool | None = None
        # Cached BEGIN statement, rebuilt when a characteristic changes.
        self._begin_statement = b""
def __del__(self) -> None:
# If fails on connection we might not have this attribute yet
if not hasattr(self, "pgconn"):
return
# Connection correctly closed
if self.closed:
return
# Connection in a pool so terminating with the program is normal
if hasattr(self, "_pool"):
return
warn(
f"connection {self} was deleted while still open."
" Please use 'with' or '.close()' to close the connection",
ResourceWarning,
)
def __repr__(self) -> str:
cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
info = connection_summary(self.pgconn)
return f"<{cls} {info} at 0x{id(self):x}>"
    @property
    def closed(self) -> bool:
        """`!True` if the connection is closed."""
        return self.pgconn.status == BAD

    @property
    def broken(self) -> bool:
        """
        `!True` if the connection was interrupted.

        A broken connection is always `closed`, but wasn't closed in a clean
        way, such as using `close()` or a `!with` block.
        """
        # _closed is set only by an explicit close(): a BAD status without it
        # means the connection dropped unexpectedly.
        return self.pgconn.status == BAD and not self._closed
    @property
    def autocommit(self) -> bool:
        """The autocommit state of the connection."""
        return self._autocommit

    @autocommit.setter
    def autocommit(self, value: bool) -> None:
        self._set_autocommit(value)

    def _set_autocommit(self, value: bool) -> None:
        # Implemented by the sync/async subclasses, which drive the
        # generator below with their own waiting policy.
        raise NotImplementedError

    def _set_autocommit_gen(self, value: bool) -> PQGen[None]:
        # Changing autocommit is only allowed outside a transaction.
        yield from self._check_intrans_gen("autocommit")
        self._autocommit = bool(value)
    @property
    def isolation_level(self) -> IsolationLevel | None:
        """
        The isolation level of the new transactions started on the connection.
        """
        return self._isolation_level

    @isolation_level.setter
    def isolation_level(self, value: IsolationLevel | None) -> None:
        self._set_isolation_level(value)

    def _set_isolation_level(self, value: IsolationLevel | None) -> None:
        # Implemented by the sync/async subclasses.
        raise NotImplementedError

    def _set_isolation_level_gen(self, value: IsolationLevel | None) -> PQGen[None]:
        # Only allowed outside a transaction.
        yield from self._check_intrans_gen("isolation_level")
        self._isolation_level = IsolationLevel(value) if value is not None else None
        # Invalidate the cached BEGIN statement: it depends on this setting.
        self._begin_statement = b""
    @property
    def read_only(self) -> bool | None:
        """
        The read-only state of the new transactions started on the connection.
        """
        return self._read_only

    @read_only.setter
    def read_only(self, value: bool | None) -> None:
        self._set_read_only(value)

    def _set_read_only(self, value: bool | None) -> None:
        # Implemented by the sync/async subclasses.
        raise NotImplementedError

    def _set_read_only_gen(self, value: bool | None) -> PQGen[None]:
        # Only allowed outside a transaction.
        yield from self._check_intrans_gen("read_only")
        self._read_only = bool(value) if value is not None else None
        # Invalidate the cached BEGIN statement: it depends on this setting.
        self._begin_statement = b""
    @property
    def deferrable(self) -> bool | None:
        """
        The deferrable state of the new transactions started on the connection.
        """
        return self._deferrable

    @deferrable.setter
    def deferrable(self, value: bool | None) -> None:
        self._set_deferrable(value)

    def _set_deferrable(self, value: bool | None) -> None:
        # Implemented by the sync/async subclasses.
        raise NotImplementedError

    def _set_deferrable_gen(self, value: bool | None) -> PQGen[None]:
        # Only allowed outside a transaction.
        yield from self._check_intrans_gen("deferrable")
        self._deferrable = bool(value) if value is not None else None
        # Invalidate the cached BEGIN statement: it depends on this setting.
        self._begin_statement = b""
    def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
        """Raise `!ProgrammingError` if a transaction is in progress.

        Guard used by the session-characteristic setters above, which may
        only operate between transactions.
        """
        status = self.pgconn.transaction_status
        # In pipeline mode the reported status may be stale: sync the
        # pipeline first and read it again.
        if status == IDLE and self._pipeline:
            yield from self._pipeline._sync_gen()
            status = self.pgconn.transaction_status
        if status != IDLE:
            if self._num_transactions:
                raise e.ProgrammingError(
                    f"can't change {attribute!r} now: "
                    "connection.transaction() context in progress"
                )
            else:
                raise e.ProgrammingError(
                    f"can't change {attribute!r} now: "
                    "connection in transaction status "
                    f"{pq.TransactionStatus(status).name}"
                )
    @property
    def info(self) -> ConnectionInfo:
        """A `ConnectionInfo` attribute to inspect connection properties."""
        return ConnectionInfo(self.pgconn)

    @property
    def adapters(self) -> AdaptersMap:
        # Created lazily from the global map on first access, so that
        # customising this connection doesn't affect other connections.
        if not self._adapters:
            self._adapters = AdaptersMap(postgres.adapters)

        return self._adapters

    @property
    def connection(self) -> BaseConnection[Row]:
        # implement the AdaptContext protocol
        return self
    def fileno(self) -> int:
        """Return the file descriptor of the connection.

        This function allows to use the connection as file-like object in
        functions waiting for readiness, such as the ones defined in the
        `selectors` module.
        """
        return self.pgconn.socket
def cancel(self) -> None:
"""Cancel the current operation on the connection."""
if self._should_cancel():
c = self.pgconn.get_cancel()
c.cancel()
def _should_cancel(self) -> bool:
"""Check whether the current command should actually be cancelled when
invoking cancel*().
"""
# cancel() is a no-op if the connection is closed;
# this allows to use the method as callback handler without caring
# about its life.
if self.closed:
return False
if self._tpc and self._tpc[1]:
raise e.ProgrammingError(
"cancel() cannot be used with a prepared two-phase transaction"
)
return True
    def _cancel_gen(self, *, timeout: float) -> PQGenConn[None]:
        # Drive a non-blocking cancellation request through the generators
        # machinery, so the caller can apply its own waiting policy.
        cancel_conn = self.pgconn.cancel_conn()
        cancel_conn.start()
        yield from generators.cancel(cancel_conn, timeout=timeout)
    def add_notice_handler(self, callback: NoticeHandler) -> None:
        """
        Register a callable to be invoked when a notice message is received.

        :param callback: the callback to call upon message received.
        :type callback: Callable[[~psycopg.errors.Diagnostic], None]
        """
        self._notice_handlers.append(callback)

    def remove_notice_handler(self, callback: NoticeHandler) -> None:
        """
        Unregister a notice message callable previously registered.

        :param callback: the callback to remove.
        :type callback: Callable[[~psycopg.errors.Diagnostic], None]
        """
        # Note: raises ValueError if the callback was never registered.
        self._notice_handlers.remove(callback)
    @staticmethod
    def _notice_handler(
        wself: ReferenceType[BaseConnection[Row]], res: PGresult
    ) -> None:
        """Dispatch a libpq notice to the registered notice callbacks.

        :param wself: weak reference to the connection, so that the callback
            doesn't keep the connection alive.
        :param res: the libpq result carrying the notice.
        """
        self = wself()
        if not (self and self._notice_handlers):
            # Connection already garbage-collected, or nothing to notify.
            return
        diag = e.Diagnostic(res, self.pgconn._encoding)
        for cb in self._notice_handlers:
            try:
                cb(diag)
            except Exception as ex:
                # A broken callback must not prevent the other ones from running.
                logger.exception("error processing notice callback '%s': %s", cb, ex)
    def add_notify_handler(self, callback: NotifyHandler) -> None:
        """
        Register a callable to be invoked whenever a notification is received.

        :param callback: the callback to call upon notification received.
        :type callback: Callable[[~psycopg.Notify], None]
        """
        self._notify_handlers.append(callback)

    def remove_notify_handler(self, callback: NotifyHandler) -> None:
        """
        Unregister a notification callable previously registered.

        :param callback: the callback to remove.
        :type callback: Callable[[~psycopg.Notify], None]
        :raises ValueError: if the callback was not registered.
        """
        self._notify_handlers.remove(callback)
    @staticmethod
    def _notify_handler(
        wself: ReferenceType[BaseConnection[Row]], pgn: pq.PGnotify
    ) -> None:
        """Dispatch a libpq notification to the registered notify callbacks.

        :param wself: weak reference to the connection, so that the callback
            doesn't keep the connection alive.
        :param pgn: the raw libpq notification.
        """
        self = wself()
        if not (self and self._notify_handlers):
            # Connection garbage-collected, or no handler registered.
            return
        enc = self.pgconn._encoding
        n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
        for cb in self._notify_handlers:
            # NOTE(review): unlike notice handlers, exceptions raised by these
            # callbacks are not caught here and will propagate to the caller.
            cb(n)
    @property
    def prepare_threshold(self) -> int | None:
        """
        Number of times a query is executed before it is prepared.

        - If it is set to 0, every query is prepared the first time it is
          executed.
        - If it is set to `!None`, prepared statements are disabled on the
          connection.

        Default value: 5
        """
        return self._prepared.prepare_threshold

    @prepare_threshold.setter
    def prepare_threshold(self, value: int | None) -> None:
        # Storage is delegated to the prepared-statements state object.
        self._prepared.prepare_threshold = value
    @property
    def prepared_max(self) -> int | None:
        """
        Maximum number of prepared statements on the connection.

        `!None` means no max number of prepared statements. The default value
        is 100.
        """
        rv = self._prepared.prepared_max
        # sys.maxsize is used internally to represent "no limit".
        return rv if rv != sys.maxsize else None

    @prepared_max.setter
    def prepared_max(self, value: int | None) -> None:
        if value is None:
            # Map "no limit" to sys.maxsize to keep internal comparisons simple.
            value = sys.maxsize
        self._prepared.prepared_max = value
# Generators to perform high-level operations on the connection
#
# These operations are expressed in terms of non-blocking generators
# and the task of waiting when needed (when the generators yield) is left
# to the connections subclass, which might wait either in blocking mode
# or through asyncio.
#
# All these generators assume exclusive access to the connection: subclasses
# should have a lock and hold it before calling and consuming them.
    @classmethod
    def _connect_gen(
        cls, conninfo: str = "", *, timeout: float = 0.0
    ) -> PQGenConn[Self]:
        """Generator to connect to the database and create a new instance.

        :param conninfo: the connection string to use.
        :param timeout: passed to `generators.connect()`.
        :return: a new connection instance wrapping the established pgconn.
        """
        pgconn = yield from generators.connect(conninfo, timeout=timeout)
        conn = cls(pgconn)
        return conn
    def _exec_command(
        self, command: Query, result_format: pq.Format = TEXT
    ) -> PQGen[PGresult | None]:
        """
        Generator to send a command and receive the result to the backend.

        Only used to implement internal commands such as "commit", with eventual
        arguments bound client-side. The cursor can do more complex stuff.

        :param command: the command to run, as `!str`, `!bytes` or `Composable`.
        :param result_format: the libpq format to request for the result.
        :return: the command result, or `!None` when the command was only
            queued because the connection is in pipeline mode.
        :raises e.InterfaceError: if the result has an unexpected status.
        """
        self._check_connection_ok()

        # Normalize the command to bytes.
        if isinstance(command, str):
            command = command.encode(self.pgconn._encoding)
        elif isinstance(command, Composable):
            command = command.as_bytes(self)

        if self._pipeline:
            # In pipeline mode, queue the command; the result is processed later.
            cmd = partial(
                self.pgconn.send_query_params,
                command,
                None,
                result_format=result_format,
            )
            self._pipeline.command_queue.append(cmd)
            self._pipeline.result_queue.append(None)
            return None

        # Unless needed, use the simple query protocol, e.g. to interact with
        # pgbouncer. In pipeline mode we always use the advanced query protocol
        # instead, see #350
        if result_format == TEXT:
            self.pgconn.send_query(command)
        else:
            self.pgconn.send_query_params(command, None, result_format=result_format)

        result = (yield from generators.execute(self.pgconn))[-1]
        if result.status != COMMAND_OK and result.status != TUPLES_OK:
            if result.status == FATAL_ERROR:
                raise e.error_from_result(result, encoding=self.pgconn._encoding)
            else:
                raise e.InterfaceError(
                    f"unexpected result {pq.ExecStatus(result.status).name}"
                    f" from command {command.decode()!r}"
                )
        return result
    def _deallocate(self, name: bytes | None) -> PQGen[None]:
        """
        Deallocate one, or all, prepared statement in the session.

        ``name == None`` stands for DEALLOCATE ALL.

        If possible, use protocol-level commands; otherwise use SQL statements.

        Note that PgBouncer doesn't support DEALLOCATE name, but it supports
        protocol-level Close from 1.21 and DEALLOCATE ALL from 1.22.

        :param name: the statement name, or `!None` for all statements.
        :raises e.InterfaceError: if the Close message returns an unexpected
            result status.
        """
        if name is None or not _HAS_SEND_CLOSE:
            # Fall back on SQL-level deallocation.
            stmt = b"DEALLOCATE " + name if name is not None else b"DEALLOCATE ALL"
            yield from self._exec_command(stmt)
            return

        self._check_connection_ok()

        if self._pipeline:
            # In pipeline mode, queue the Close message instead of sending it.
            cmd = partial(
                self.pgconn.send_close_prepared,
                name,
            )
            self._pipeline.command_queue.append(cmd)
            self._pipeline.result_queue.append(None)
            return

        self.pgconn.send_close_prepared(name)

        result = (yield from generators.execute(self.pgconn))[-1]
        if result.status != COMMAND_OK:
            if result.status == FATAL_ERROR:
                raise e.error_from_result(result, encoding=self.pgconn._encoding)
            else:
                raise e.InterfaceError(
                    f"unexpected result {pq.ExecStatus(result.status).name}"
                    " from sending closing prepared statement message"
                )
    def _check_connection_ok(self) -> None:
        """Raise an exception if the connection is not in a usable state.

        :raises e.OperationalError: if the connection is closed (BAD status).
        :raises e.InterfaceError: if the connection is in any other non-OK
            status.
        """
        if self.pgconn.status == OK:
            return

        if self.pgconn.status == BAD:
            raise e.OperationalError("the connection is closed")
        raise e.InterfaceError(
            "cannot execute operations: the connection is"
            f" in status {self.pgconn.status}"
        )
    def _start_query(self) -> PQGen[None]:
        """Generator to start a transaction if necessary."""
        if self._autocommit:
            # In autocommit mode no transaction is implicitly started.
            return

        if self.pgconn.transaction_status != IDLE:
            # A transaction is already in progress: nothing to do.
            return

        yield from self._exec_command(self._get_tx_start_command())
        if self._pipeline:
            yield from self._pipeline._sync_gen()
    def _get_tx_start_command(self) -> bytes:
        """Return the BEGIN statement to start a transaction on this connection.

        The statement reflects the `isolation_level`, `read_only` and
        `deferrable` settings and is cached in `_begin_statement` (the cache
        is reused until that attribute is cleared elsewhere).
        """
        if self._begin_statement:
            return self._begin_statement

        parts = [b"BEGIN"]

        if self.isolation_level is not None:
            # E.g. "ISOLATION LEVEL REPEATABLE READ" from the enum member name.
            val = IsolationLevel(self.isolation_level)
            parts.append(b"ISOLATION LEVEL")
            parts.append(val.name.replace("_", " ").encode())

        if self.read_only is not None:
            parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")

        if self.deferrable is not None:
            parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")

        self._begin_statement = b" ".join(parts)
        return self._begin_statement
    def _commit_gen(self) -> PQGen[None]:
        """Generator implementing `Connection.commit()`.

        :raises e.ProgrammingError: if called within a `transaction()` context
            or during a two-phase transaction.
        """
        if self._num_transactions:
            raise e.ProgrammingError(
                "Explicit commit() forbidden within a Transaction "
                "context. (Transaction will be automatically committed "
                "on successful exit from context.)"
            )
        if self._tpc:
            raise e.ProgrammingError(
                "commit() cannot be used during a two-phase transaction"
            )
        if self.pgconn.transaction_status == IDLE:
            # No transaction in progress: nothing to commit.
            return

        yield from self._exec_command(b"COMMIT")

        if self._pipeline:
            yield from self._pipeline._sync_gen()
    def _rollback_gen(self) -> PQGen[None]:
        """Generator implementing `Connection.rollback()`.

        :raises e.ProgrammingError: if called within a `transaction()` context
            or during a two-phase transaction.
        """
        if self._num_transactions:
            raise e.ProgrammingError(
                "Explicit rollback() forbidden within a Transaction "
                "context. (Either raise Rollback() or allow "
                "an exception to propagate out of the context.)"
            )
        if self._tpc:
            raise e.ProgrammingError(
                "rollback() cannot be used during a two-phase transaction"
            )

        # Get out of a "pipeline aborted" state
        if self._pipeline:
            yield from self._pipeline._sync_gen()

        if self.pgconn.transaction_status == IDLE:
            # No transaction in progress: nothing to roll back.
            return

        yield from self._exec_command(b"ROLLBACK")
        # Statements prepared during the rolled-back transaction are no longer
        # usable: clear the client-side registry and run its maintenance.
        self._prepared.clear()
        yield from self._prepared.maintain_gen(self)

        if self._pipeline:
            yield from self._pipeline._sync_gen()
    def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid:
        """
        Returns a `Xid` to pass to the `!tpc_*()` methods of this connection.

        The argument types and constraints are explained in
        :ref:`two-phase-commit`.

        The values passed to the method will be available on the returned
        object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`.
        """
        self._check_tpc()
        return Xid.from_parts(format_id, gtrid, bqual)
    def _tpc_begin_gen(self, xid: Xid | str) -> PQGen[None]:
        """Generator implementing `tpc_begin()`.

        :param xid: the id of the two-phase transaction to start, as an `Xid`
            instance or a string parsed with `Xid.from_string()`.
        :raises e.ProgrammingError: if a transaction is already in progress or
            the connection is in autocommit mode.
        """
        self._check_tpc()

        if not isinstance(xid, Xid):
            xid = Xid.from_string(xid)

        if self.pgconn.transaction_status != IDLE:
            raise e.ProgrammingError(
                "can't start two-phase transaction: connection in status"
                f" {pq.TransactionStatus(self.pgconn.transaction_status).name}"
            )

        if self._autocommit:
            raise e.ProgrammingError(
                "can't use two-phase transactions in autocommit mode"
            )

        # (xid, prepared?): the transaction is not prepared yet.
        self._tpc = (xid, False)
        yield from self._exec_command(self._get_tx_start_command())
    def _tpc_prepare_gen(self) -> PQGen[None]:
        """Generator implementing `tpc_prepare()`.

        :raises e.ProgrammingError: if no two-phase transaction is in progress,
            or if it was already prepared.
        """
        if not self._tpc:
            raise e.ProgrammingError(
                "'tpc_prepare()' must be called inside a two-phase transaction"
            )

        if self._tpc[1]:
            raise e.ProgrammingError(
                "'tpc_prepare()' cannot be used during a prepared two-phase transaction"
            )

        xid = self._tpc[0]
        # Mark the transaction as prepared before issuing the command.
        self._tpc = (xid, True)
        yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
        if self._pipeline:
            yield from self._pipeline._sync_gen()
def _tpc_finish_gen(
self, action: LiteralString, xid: Xid | str | None
) -> PQGen[None]:
fname = f"tpc_{action.lower()}()"
if xid is None:
if not self._tpc:
raise e.ProgrammingError(
f"{fname} without xid must must be"
" called inside a two-phase transaction"
)
xid = self._tpc[0]
else:
if self._tpc:
raise e.ProgrammingError(
f"{fname} with xid must must be called"
" outside a two-phase transaction"
)
if not isinstance(xid, Xid):
xid = Xid.from_string(xid)
if self._tpc and not self._tpc[1]:
meth: Callable[[], PQGen[None]]
meth = getattr(self, f"_{action.lower()}_gen")
self._tpc = None
yield from meth()
else:
yield from self._exec_command(
SQL("{} PREPARED {}").format(SQL(action), str(xid))
)
self._tpc = None
    def _check_tpc(self) -> None:
        """Raise NotSupportedError if TPC is not supported."""
        # TPC supported on every supported PostgreSQL version.
        # Kept as a no-op hook should a server/version check become necessary.
        pass

View File

@@ -0,0 +1,173 @@
"""
Objects to return information about a PostgreSQL connection.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from pathlib import Path
from datetime import tzinfo
from . import pq
from ._tz import get_tzinfo
from .conninfo import make_conninfo
class ConnectionInfo:
    """Allow access to information about the connection."""

    __module__ = "psycopg"

    def __init__(self, pgconn: pq.abc.PGconn):
        # Thin wrapper: every property below exposes one attribute of the
        # low-level libpq connection in a Pythonic form.
        self.pgconn = pgconn

    @property
    def vendor(self) -> str:
        """A string representing the database vendor connected to."""
        return "PostgreSQL"

    @property
    def host(self) -> str:
        """The server host name of the active connection. See :pq:`PQhost()`."""
        return self._get_pgconn_attr("host")

    @property
    def hostaddr(self) -> str:
        """The server IP address of the connection. See :pq:`PQhostaddr()`."""
        return self._get_pgconn_attr("hostaddr")

    @property
    def port(self) -> int:
        """The port of the active connection. See :pq:`PQport()`."""
        return int(self._get_pgconn_attr("port"))

    @property
    def dbname(self) -> str:
        """The database name of the connection. See :pq:`PQdb()`."""
        return self._get_pgconn_attr("db")

    @property
    def user(self) -> str:
        """The user name of the connection. See :pq:`PQuser()`."""
        return self._get_pgconn_attr("user")

    @property
    def password(self) -> str:
        """The password of the connection. See :pq:`PQpass()`."""
        return self._get_pgconn_attr("password")

    @property
    def options(self) -> str:
        """
        The command-line options passed in the connection request.

        See :pq:`PQoptions`.
        """
        return self._get_pgconn_attr("options")

    def get_parameters(self) -> dict[str, str]:
        """Return the connection parameters values.

        Return all the parameters set to a non-default value, which might come
        either from the connection string and parameters passed to
        `~Connection.connect()` or from environment variables. The password
        is never returned (you can read it using the `password` attribute).
        """
        pyenc = self.encoding

        # Get the known defaults to avoid reporting them
        defaults = {
            i.keyword: i.compiled
            for i in pq.Conninfo.get_defaults()
            if i.compiled is not None
        }
        # Not returned by the libpq. Bug? Bet we're using SSH.
        defaults.setdefault(b"channel_binding", b"prefer")
        defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()

        return {
            i.keyword.decode(pyenc): i.val.decode(pyenc)
            for i in self.pgconn.info
            if i.val is not None
            and i.keyword != b"password"
            and i.val != defaults.get(i.keyword)
        }

    @property
    def dsn(self) -> str:
        """Return the connection string to connect to the database.

        The string contains all the parameters set to a non-default value,
        which might come either from the connection string and parameters
        passed to `~Connection.connect()` or from environment variables. The
        password is never returned (you can read it using the `password`
        attribute).
        """
        return make_conninfo(**self.get_parameters())

    @property
    def status(self) -> pq.ConnStatus:
        """The status of the connection. See :pq:`PQstatus()`."""
        return pq.ConnStatus(self.pgconn.status)

    @property
    def transaction_status(self) -> pq.TransactionStatus:
        """
        The current in-transaction status of the session.

        See :pq:`PQtransactionStatus()`.
        """
        return pq.TransactionStatus(self.pgconn.transaction_status)

    @property
    def pipeline_status(self) -> pq.PipelineStatus:
        """
        The current pipeline status of the client.

        See :pq:`PQpipelineStatus()`.
        """
        return pq.PipelineStatus(self.pgconn.pipeline_status)

    def parameter_status(self, param_name: str) -> str | None:
        """
        Return a parameter setting of the connection.

        Return `None` if the parameter is unknown.
        """
        res = self.pgconn.parameter_status(param_name.encode(self.encoding))
        return res.decode(self.encoding) if res is not None else None

    @property
    def server_version(self) -> int:
        """
        An integer representing the server version. See :pq:`PQserverVersion()`.
        """
        return self.pgconn.server_version

    @property
    def backend_pid(self) -> int:
        """
        The process ID (PID) of the backend process handling this connection.

        See :pq:`PQbackendPID()`.
        """
        return self.pgconn.backend_pid

    @property
    def error_message(self) -> str:
        """
        The error message most recently generated by an operation on the connection.

        See :pq:`PQerrorMessage()`.
        """
        return self._get_pgconn_attr("error_message")

    @property
    def timezone(self) -> tzinfo:
        """The Python timezone info of the connection's timezone."""
        return get_tzinfo(self.pgconn)

    @property
    def encoding(self) -> str:
        """The Python codec name of the connection's client encoding."""
        return self.pgconn._encoding

    def _get_pgconn_attr(self, name: str) -> str:
        # Read a bytes attribute from the pgconn and decode it using the
        # connection's client encoding.
        value: bytes = getattr(self.pgconn, name)
        return value.decode(self.encoding)

View File

@@ -0,0 +1,96 @@
# WARNING: this file is auto-generated by 'async_to_sync.py'
# from the original file '_conninfo_attempts_async.py'
# DO NOT CHANGE! Change the original file instead.
"""
Separate connection attempts from a connection string.
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
import socket
import logging
from random import shuffle
from . import errors as e
from .abc import ConnDict, ConnMapping
from ._conninfo_utils import get_param, is_ip_address, get_param_def
from ._conninfo_utils import split_attempts
logger = logging.getLogger("psycopg")
def conninfo_attempts(params: ConnMapping) -> list[ConnDict]:
    """Split a set of connection params on the single attempts to perform.

    A connection param can perform more than one attempt if more than one
    ``host`` is provided.

    Also perform resolution of the hostnames into hostaddr. Because a host
    can resolve to more than one address, this can lead to yield more attempts
    too. Raise `OperationalError` if no host could be resolved.

    Because the libpq async function doesn't honour the timeout, we need to
    reimplement the repeated attempts.
    """
    last_exc = None
    attempts = []
    for attempt in split_attempts(params):
        try:
            attempts.extend(_resolve_hostnames(attempt))
        except OSError as ex:
            # Keep trying the other attempts; remember the last failure to
            # report it if nothing could be resolved.
            logger.debug("failed to resolve host %r: %s", attempt.get("host"), ex)
            last_exc = ex

    if not attempts:
        assert last_exc
        # We couldn't resolve anything
        raise e.OperationalError(str(last_exc))

    if get_param(params, "load_balance_hosts") == "random":
        shuffle(attempts)

    return attempts
def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
    """
    Perform DNS lookup of the hosts and return a list of connection attempts.

    If a ``host`` param is present but not ``hostaddr``, resolve the host
    addresses.

    :param params: The input parameters, for instance as returned by
        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
        a single entry for host, hostaddr because it is designed to further
        process the input of split_attempts().

    :return: A list of attempts to make (to include the case of a hostname
        resolving to more than one IP).
    """
    host = get_param(params, "host")
    if not host or host.startswith("/") or host[1:2] == ":":
        # Local path, or no host to resolve
        return [params]

    hostaddr = get_param(params, "hostaddr")
    if hostaddr:
        # Already resolved
        return [params]

    if is_ip_address(host):
        # If the host is already an ip address don't try to resolve it
        return [{**params, "hostaddr": host}]

    port = get_param(params, "port")
    if not port:
        # Fall back on the libpq compiled-in default port.
        port_def = get_param_def("port")
        port = port_def and port_def.compiled or "5432"

    ans = socket.getaddrinfo(
        host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
    )
    return [{**params, "hostaddr": item[4][0]} for item in ans]

View File

@@ -0,0 +1,101 @@
"""
Separate connection attempts from a connection string.
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
import socket
import logging
from random import shuffle
from . import errors as e
from .abc import ConnDict, ConnMapping
from ._conninfo_utils import get_param, is_ip_address, get_param_def
from ._conninfo_utils import split_attempts
if True: # ASYNC:
import asyncio
logger = logging.getLogger("psycopg")
async def conninfo_attempts_async(params: ConnMapping) -> list[ConnDict]:
    """Split a set of connection params on the single attempts to perform.

    A connection param can perform more than one attempt if more than one
    ``host`` is provided.

    Also perform async resolution of the hostname into hostaddr. Because a host
    can resolve to more than one address, this can lead to yield more attempts
    too. Raise `OperationalError` if no host could be resolved.

    Because the libpq async function doesn't honour the timeout, we need to
    reimplement the repeated attempts.
    """
    last_exc = None
    attempts = []
    for attempt in split_attempts(params):
        try:
            attempts.extend(await _resolve_hostnames(attempt))
        except OSError as ex:
            # Keep trying the other attempts; remember the last failure to
            # report it if nothing could be resolved.
            logger.debug("failed to resolve host %r: %s", attempt.get("host"), ex)
            last_exc = ex

    if not attempts:
        assert last_exc
        # We couldn't resolve anything
        raise e.OperationalError(str(last_exc))

    if get_param(params, "load_balance_hosts") == "random":
        shuffle(attempts)

    return attempts
async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
    """
    Perform async DNS lookup of the hosts and return a list of connection attempts.

    If a ``host`` param is present but not ``hostaddr``, resolve the host
    addresses asynchronously.

    :param params: The input parameters, for instance as returned by
        `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most
        a single entry for host, hostaddr because it is designed to further
        process the input of split_attempts().

    :return: A list of attempts to make (to include the case of a hostname
        resolving to more than one IP).
    """
    host = get_param(params, "host")
    if not host or host.startswith("/") or host[1:2] == ":":
        # Local path, or no host to resolve
        return [params]

    hostaddr = get_param(params, "hostaddr")
    if hostaddr:
        # Already resolved
        return [params]

    if is_ip_address(host):
        # If the host is already an ip address don't try to resolve it
        return [{**params, "hostaddr": host}]

    port = get_param(params, "port")
    if not port:
        # Fall back on the libpq compiled-in default port.
        port_def = get_param_def("port")
        port = port_def and port_def.compiled or "5432"

    if True: # ASYNC:
        loop = asyncio.get_running_loop()
        ans = await loop.getaddrinfo(
            host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
        )
    else:
        ans = socket.getaddrinfo(
            host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM
        )
    return [{**params, "hostaddr": item[4][0]} for item in ans]

View File

@@ -0,0 +1,124 @@
"""
Internal utilities to manipulate connection strings
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
import os
from functools import lru_cache
from ipaddress import ip_address
from dataclasses import dataclass
from . import pq
from .abc import ConnDict, ConnMapping
from . import errors as e
def split_attempts(params: ConnMapping) -> list[ConnDict]:
    """
    Split connection parameters with a sequence of hosts into separate attempts.

    :param params: the input connection parameters; ``host``, ``hostaddr``,
        ``port`` may contain comma-separated lists of values.
    :return: a list of parameter dicts, one per connection attempt.
    :raises e.OperationalError: if the numbers of hosts, hostaddrs and ports
        cannot be matched.
    """

    def split_val(key: str) -> list[str]:
        val = get_param(params, key)
        return val.split(",") if val else []

    hosts = split_val("host")
    hostaddrs = split_val("hostaddr")
    ports = split_val("port")

    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
        raise e.OperationalError(
            f"could not match {len(hosts)} host names"
            f" with {len(hostaddrs)} hostaddr values"
        )

    nhosts = max(len(hosts), len(hostaddrs))
    if 1 < len(ports) != nhosts:
        # Fix: report nhosts, not len(hosts): attempts may be specified via
        # hostaddr only, in which case len(hosts) is misleadingly 0.
        raise e.OperationalError(
            f"could not match {len(ports)} port numbers to {nhosts} hosts"
        )

    # A single attempt to make. Don't mangle the conninfo string.
    if nhosts <= 1:
        return [{**params}]

    if len(ports) == 1:
        # The same port applies to every host.
        ports *= nhosts

    # Now all lists are either empty or have the same length
    rv = []
    for i in range(nhosts):
        attempt = {**params}
        if hosts:
            attempt["host"] = hosts[i]
        if hostaddrs:
            attempt["hostaddr"] = hostaddrs[i]
        if ports:
            attempt["port"] = ports[i]
        rv.append(attempt)

    return rv
def get_param(params: ConnMapping, name: str) -> str | None:
    """
    Return a value from a connection string.

    The value may be also specified in a PG* env var.
    """
    # Explicitly passed parameters win over the environment.
    try:
        return str(params[name])
    except KeyError:
        pass

    # TODO: check if in service
    paramdef = get_param_def(name)
    if not paramdef:
        return None

    # `os.environ.get` returns None for a missing var, matching our contract.
    return os.environ.get(paramdef.envvar)
@dataclass
class ParamDef:
    """
    Information about defaults and env vars for connection params
    """

    # The connection parameter keyword (e.g. "host").
    keyword: str
    # The environment variable providing a default for the parameter;
    # empty string if the parameter has no associated env var.
    envvar: str
    # The default value compiled into the libpq, or None if there is none.
    compiled: str | None
def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
    """
    Return the ParamDef of a connection string parameter.

    Return `!None` if the keyword is unknown.
    """
    # The mutable default argument is used deliberately as a memoization
    # cache: the defaults are queried from the libpq only once per process.
    if not _cache:
        defs = pq.Conninfo.get_defaults()
        for d in defs:
            cd = ParamDef(
                keyword=d.keyword.decode(),
                envvar=d.envvar.decode() if d.envvar else "",
                compiled=d.compiled.decode() if d.compiled is not None else None,
            )
            _cache[cd.keyword] = cd

    return _cache.get(keyword)
@lru_cache
def is_ip_address(s: str) -> bool:
    """Return `!True` if the string represents a valid ip address."""
    # `ipaddress.ip_address()` accepts IPv4 and IPv6 literals and raises
    # ValueError for anything else (e.g. host names).
    try:
        ip_address(s)
    except ValueError:
        return False
    else:
        return True

View File

@@ -0,0 +1,298 @@
# WARNING: this file is auto-generated by 'async_to_sync.py'
# from the original file '_copy_async.py'
# DO NOT CHANGE! Change the original file instead.
"""
Objects to support the COPY protocol (sync version).
"""
# Copyright (C) 2023 The Psycopg Team
from __future__ import annotations
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, Iterator, Sequence, TYPE_CHECKING
from . import pq
from . import errors as e
from ._compat import Self
from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE, PREFER_FLUSH
from .generators import copy_to, copy_end
from ._acompat import spawn, gather, Queue, Worker
if TYPE_CHECKING:
from .abc import Buffer
from .cursor import Cursor
from .connection import Connection # noqa: F401
COPY_IN = pq.ExecStatus.COPY_IN
COPY_OUT = pq.ExecStatus.COPY_OUT
ACTIVE = pq.TransactionStatus.ACTIVE
class Copy(BaseCopy["Connection[Any]"]):
    """Manage a :sql:`COPY` operation.

    :param cursor: the cursor where the operation is performed.
    :param binary: if `!True`, write binary format.
    :param writer: the object to write to destination. If not specified, write
        to the `!cursor` connection.

    Choosing `!binary` is not necessary if the cursor has executed a
    :sql:`COPY` operation, because the operation result describes the format
    too. The parameter is useful when a `!Copy` object is created manually and
    no operation is performed on the cursor, such as when using ``writer=``\\
    `~psycopg.copy.FileWriter`.
    """

    __module__ = "psycopg"

    writer: Writer

    def __init__(
        self,
        cursor: Cursor[Any],
        *,
        binary: bool | None = None,
        writer: Writer | None = None,
    ):
        super().__init__(cursor, binary=binary)
        if not writer:
            # By default write the copy data to the cursor's connection.
            writer = LibpqWriter(cursor)

        self.writer = writer
        # Cache the bound method to avoid an attribute lookup per write.
        self._write = writer.write

    def __enter__(self) -> Self:
        self._enter()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Pass the exception (if any) so the operation can be terminated cleanly.
        self.finish(exc_val)

    # End user sync interface

    def __iter__(self) -> Iterator[Buffer]:
        """Implement block-by-block iteration on :sql:`COPY TO`."""
        while True:
            data = self.read()
            if not data:
                break
            yield data

    def read(self) -> Buffer:
        """
        Read an unparsed row after a :sql:`COPY TO` operation.

        Return an empty string when the data is finished.
        """
        return self.connection.wait(self._read_gen())

    def rows(self) -> Iterator[tuple[Any, ...]]:
        """
        Iterate on the result of a :sql:`COPY TO` operation record by record.

        Note that the records returned will be tuples of unparsed strings or
        bytes, unless data types are specified using `set_types()`.
        """
        while True:
            record = self.read_row()
            if record is None:
                break
            yield record

    def read_row(self) -> tuple[Any, ...] | None:
        """
        Read a parsed row of data from a table after a :sql:`COPY TO` operation.

        Return `!None` when the data is finished.

        Note that the records returned will be tuples of unparsed strings or
        bytes, unless data types are specified using `set_types()`.
        """
        return self.connection.wait(self._read_row_gen())

    def write(self, buffer: Buffer | str) -> None:
        """
        Write a block of data to a table after a :sql:`COPY FROM` operation.

        If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
        text mode it can be either `!bytes` or `!str`.
        """
        data = self.formatter.write(buffer)
        if data:
            self._write(data)

    def write_row(self, row: Sequence[Any]) -> None:
        """Write a record to a table after a :sql:`COPY FROM` operation."""
        data = self.formatter.write_row(row)
        if data:
            self._write(data)

    def finish(self, exc: BaseException | None) -> None:
        """Terminate the copy operation and free the resources allocated.

        You shouldn't need to call this function yourself: it is usually called
        by exit. It is available if, despite what is documented, you end up
        using the `Copy` object outside a block.
        """
        if self._direction == COPY_IN:
            # Flush any trailer produced by the formatter, then let the
            # writer terminate the operation.
            data = self.formatter.end()
            if data:
                self._write(data)
            self.writer.finish(exc)
            self._finished = True
        else:
            if not exc:
                return

            if self._pgconn.transaction_status != ACTIVE:
                # The server has already finished to send copy data. The connection
                # is already in a good state.
                return

            # Throw a cancel to the server, then consume the rest of the copy data
            # (which might or might not have been already transferred entirely to
            # the client, so we won't necessary see the exception associated with
            # canceling).
            self.connection._try_cancel()
            self.connection.wait(self._end_copy_out_gen())
class Writer(ABC):
    """
    A class to write copy data somewhere.
    """

    @abstractmethod
    def write(self, data: Buffer) -> None:
        """Write some data to destination."""
        ...

    def finish(self, exc: BaseException | None = None) -> None:
        """
        Called when write operations are finished.

        If operations finished with an error, it will be passed to ``exc``.
        """
        pass
class LibpqWriter(Writer):
    """
    A `Writer` to write copy data to a Postgres database.
    """

    __module__ = "psycopg.copy"

    def __init__(self, cursor: Cursor[Any]):
        self.cursor = cursor
        self.connection = cursor.connection
        self._pgconn = self.connection.pgconn

    def write(self, data: Buffer) -> None:
        """Send a block of copy data to the server, splitting it if too large."""
        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            self.connection.wait(copy_to(self._pgconn, data, flush=PREFER_FLUSH))
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                self.connection.wait(
                    copy_to(
                        self._pgconn, data[i : i + MAX_BUFFER_SIZE], flush=PREFER_FLUSH
                    )
                )

    def finish(self, exc: BaseException | None = None) -> None:
        """Terminate the COPY, reporting a Python error to the server if any."""
        bmsg: bytes | None
        if exc:
            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
            bmsg = msg.encode(self._pgconn._encoding, "replace")
        else:
            bmsg = None

        try:
            res = self.connection.wait(copy_end(self._pgconn, bmsg))
        # The QueryCanceled is expected if we sent an exception message to
        # pgconn.put_copy_end(). The Python exception that generated that
        # cancelling is more important, so don't clobber it.
        except e.QueryCanceled:
            if not bmsg:
                raise
        else:
            self.cursor._results = [res]
class QueuedLibpqWriter(LibpqWriter):
    """
    `Writer` using a buffer to queue data to write.

    `write()` returns immediately, so that the main thread can be CPU-bound
    formatting messages, while a worker thread can be IO-bound waiting to write
    on the connection.
    """

    __module__ = "psycopg.copy"

    def __init__(self, cursor: Cursor[Any]):
        super().__init__(cursor)
        # Bounded queue: provides back-pressure on the producing thread.
        self._queue: Queue[Buffer] = Queue(maxsize=QUEUE_SIZE)
        self._worker: Worker | None = None
        self._worker_error: BaseException | None = None

    def worker(self) -> None:
        """Push data to the server when available from the copy queue.

        Terminate reading when the queue receives a false-y value, or in case
        of error.

        The function is designed to be run in a separate task.
        """
        try:
            while True:
                data = self._queue.get()
                if not data:
                    break
                self.connection.wait(copy_to(self._pgconn, data, flush=PREFER_FLUSH))
        except BaseException as ex:
            # Propagate the error to the main thread.
            self._worker_error = ex

    def write(self, data: Buffer) -> None:
        """Queue a block of copy data for the worker thread to send."""
        if not self._worker:
            # warning: reference loop, broken by _write_end
            self._worker = spawn(self.worker)

        # If the worker thread raises an exception, re-raise it to the caller.
        if self._worker_error:
            raise self._worker_error

        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            self._queue.put(data)
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                self._queue.put(data[i : i + MAX_BUFFER_SIZE])

    def finish(self, exc: BaseException | None = None) -> None:
        """Stop the worker thread and terminate the COPY operation."""
        # A false-y value is the sentinel stopping the worker loop.
        self._queue.put(b"")

        if self._worker:
            gather(self._worker)
            self._worker = None  # break reference loops if any

        # Check if the worker thread raised any exception before terminating.
        if self._worker_error:
            raise self._worker_error

        super().finish(exc)

View File

@@ -0,0 +1,299 @@
"""
Objects to support the COPY protocol (async version).
"""
# Copyright (C) 2023 The Psycopg Team
from __future__ import annotations
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, AsyncIterator, Sequence, TYPE_CHECKING
from . import pq
from . import errors as e
from ._compat import Self
from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE, PREFER_FLUSH
from .generators import copy_to, copy_end
from ._acompat import aspawn, agather, AQueue, AWorker
if TYPE_CHECKING:
from .abc import Buffer
from .cursor_async import AsyncCursor
from .connection_async import AsyncConnection # noqa: F401
COPY_IN = pq.ExecStatus.COPY_IN
COPY_OUT = pq.ExecStatus.COPY_OUT
ACTIVE = pq.TransactionStatus.ACTIVE
class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
    """Manage an asynchronous :sql:`COPY` operation.

    :param cursor: the cursor where the operation is performed.
    :param binary: if `!True`, write binary format.
    :param writer: the object to write to destination. If not specified, write
        to the `!cursor` connection.

    Choosing `!binary` is not necessary if the cursor has executed a
    :sql:`COPY` operation, because the operation result describes the format
    too. The parameter is useful when a `!Copy` object is created manually and
    no operation is performed on the cursor, such as when using ``writer=``\\
    `~psycopg.copy.FileWriter`.
    """

    __module__ = "psycopg"

    # The object the copy data is sent to (the connection by default).
    writer: AsyncWriter

    def __init__(
        self,
        cursor: AsyncCursor[Any],
        *,
        binary: bool | None = None,
        writer: AsyncWriter | None = None,
    ):
        super().__init__(cursor, binary=binary)
        if not writer:
            # Default destination: the cursor's connection.
            writer = AsyncLibpqWriter(cursor)

        self.writer = writer
        self._write = writer.write

    async def __aenter__(self) -> Self:
        self._enter()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.finish(exc_val)

    # End user async interface

    async def __aiter__(self) -> AsyncIterator[Buffer]:
        """Implement block-by-block iteration on :sql:`COPY TO`."""
        while True:
            data = await self.read()
            if not data:
                break
            yield data

    async def read(self) -> Buffer:
        """
        Read an unparsed row after a :sql:`COPY TO` operation.

        Return an empty string when the data is finished.
        """
        return await self.connection.wait(self._read_gen())

    async def rows(self) -> AsyncIterator[tuple[Any, ...]]:
        """
        Iterate on the result of a :sql:`COPY TO` operation record by record.

        Note that the records returned will be tuples of unparsed strings or
        bytes, unless data types are specified using `set_types()`.
        """
        while True:
            record = await self.read_row()
            if record is None:
                break
            yield record

    async def read_row(self) -> tuple[Any, ...] | None:
        """
        Read a parsed row of data from a table after a :sql:`COPY TO` operation.

        Return `!None` when the data is finished.

        Note that the records returned will be tuples of unparsed strings or
        bytes, unless data types are specified using `set_types()`.
        """
        return await self.connection.wait(self._read_row_gen())

    async def write(self, buffer: Buffer | str) -> None:
        """
        Write a block of data to a table after a :sql:`COPY FROM` operation.

        If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
        text mode it can be either `!bytes` or `!str`.
        """
        data = self.formatter.write(buffer)
        if data:
            await self._write(data)

    async def write_row(self, row: Sequence[Any]) -> None:
        """Write a record to a table after a :sql:`COPY FROM` operation."""
        data = self.formatter.write_row(row)
        if data:
            await self._write(data)

    async def finish(self, exc: BaseException | None) -> None:
        """Terminate the copy operation and free the resources allocated.

        You shouldn't need to call this function yourself: it is usually called
        by exit. It is available if, despite what is documented, you end up
        using the `Copy` object outside a block.
        """
        if self._direction == COPY_IN:
            # Flush the formatter and complete the COPY FROM handshake.
            data = self.formatter.end()
            if data:
                await self._write(data)
            await self.writer.finish(exc)
            self._finished = True
        else:
            if not exc:
                return

            if self._pgconn.transaction_status != ACTIVE:
                # The server has already finished sending copy data. The
                # connection is already in a good state.
                return

            # Throw a cancel to the server, then consume the rest of the copy
            # data (which might or might not have been already transferred
            # entirely to the client, so we won't necessarily see the
            # exception associated with canceling).
            await self.connection._try_cancel()
            await self.connection.wait(self._end_copy_out_gen())
class AsyncWriter(ABC):
    """
    A class to write copy data somewhere (for async connections).
    """

    @abstractmethod
    async def write(self, data: Buffer) -> None:
        """Write some data to destination."""
        ...

    async def finish(self, exc: BaseException | None = None) -> None:
        """
        Called when write operations are finished.

        If operations finished with an error, it will be passed to ``exc``.
        """
        # Default: nothing to clean up.
        pass
class AsyncLibpqWriter(AsyncWriter):
    """
    An `AsyncWriter` to write copy data to a Postgres database.
    """

    __module__ = "psycopg.copy"

    def __init__(self, cursor: AsyncCursor[Any]):
        self.cursor = cursor
        self.connection = cursor.connection
        self._pgconn = self.connection.pgconn

    async def write(self, data: Buffer) -> None:
        """Send *data* to the server, splitting oversized buffers."""
        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            await self.connection.wait(copy_to(self._pgconn, data, flush=PREFER_FLUSH))
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                await self.connection.wait(
                    copy_to(
                        self._pgconn, data[i : i + MAX_BUFFER_SIZE], flush=PREFER_FLUSH
                    )
                )

    async def finish(self, exc: BaseException | None = None) -> None:
        """Terminate the COPY, reporting *exc* to the server if not None."""
        bmsg: bytes | None
        if exc:
            # Forward the Python-side error so the server aborts the COPY.
            msg = f"error from Python: {type(exc).__qualname__} - {exc}"
            bmsg = msg.encode(self._pgconn._encoding, "replace")
        else:
            bmsg = None

        try:
            res = await self.connection.wait(copy_end(self._pgconn, bmsg))
        # The QueryCanceled is expected if we sent an exception message to
        # pgconn.put_copy_end(). The Python exception that generated that
        # cancelling is more important, so don't clobber it.
        except e.QueryCanceled:
            if not bmsg:
                raise
        else:
            # Store the final result so the cursor reports rowcount etc.
            self.cursor._results = [res]
class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
    """
    `AsyncWriter` using a buffer to queue data to write.

    `write()` returns immediately, so that the main thread can be CPU-bound
    formatting messages, while a worker thread can be IO-bound waiting to write
    on the connection.
    """

    __module__ = "psycopg.copy"

    def __init__(self, cursor: AsyncCursor[Any]):
        super().__init__(cursor)

        self._queue: AQueue[Buffer] = AQueue(maxsize=QUEUE_SIZE)
        self._worker: AWorker | None = None
        # Error raised by the worker task, to re-raise in the caller's task.
        self._worker_error: BaseException | None = None

    async def worker(self) -> None:
        """Push data to the server when available from the copy queue.

        Terminate reading when the queue receives a false-y value, or in case
        of error.

        The function is designed to be run in a separate task.
        """
        try:
            while True:
                data = await self._queue.get()
                if not data:
                    break
                await self.connection.wait(
                    copy_to(self._pgconn, data, flush=PREFER_FLUSH)
                )
        except BaseException as ex:
            # Propagate the error to the main task.
            self._worker_error = ex

    async def write(self, data: Buffer) -> None:
        """Queue *data* for the worker task, splitting oversized buffers."""
        if not self._worker:
            # warning: reference loop, broken by _write_end
            self._worker = aspawn(self.worker)

        # If the worker task raises an exception, re-raise it to the caller.
        if self._worker_error:
            raise self._worker_error

        if len(data) <= MAX_BUFFER_SIZE:
            # Most used path: we don't need to split the buffer in smaller
            # bits, so don't make a copy.
            await self._queue.put(data)
        else:
            # Copy a buffer too large in chunks to avoid causing a memory
            # error in the libpq, which may cause an infinite loop (#255).
            for i in range(0, len(data), MAX_BUFFER_SIZE):
                await self._queue.put(data[i : i + MAX_BUFFER_SIZE])

    async def finish(self, exc: BaseException | None = None) -> None:
        """Stop the worker task and terminate the copy operation."""
        # An empty (false-y) buffer is the sentinel telling the worker to stop.
        await self._queue.put(b"")

        if self._worker:
            await agather(self._worker)
            self._worker = None  # break reference loops if any

        # Check if the worker task raised any exception before terminating.
        if self._worker_error:
            raise self._worker_error

        await super().finish(exc)

View File

@@ -0,0 +1,438 @@
"""
psycopg copy support
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import re
import sys
import struct
from abc import ABC, abstractmethod
from typing import Any, Generic, Match, Sequence, TYPE_CHECKING
from . import pq
from . import adapt
from . import errors as e
from .abc import Buffer, ConnectionType, PQGen, Transformer
from .pq.misc import connection_summary
from ._cmodule import _psycopg
from .generators import copy_from
if TYPE_CHECKING:
from ._cursor_base import BaseCursor
# Adaptation formats used when dumping Python values.
PY_TEXT = adapt.PyFormat.TEXT
PY_BINARY = adapt.PyFormat.BINARY

# Wire formats of the copy protocol.
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY

COPY_IN = pq.ExecStatus.COPY_IN
COPY_OUT = pq.ExecStatus.COPY_OUT

# Size of data to accumulate before sending it down the network. We fill a
# buffer this size field by field, and when it passes the threshold size
# we ship it, so it may end up being bigger than this.
BUFFER_SIZE = 32 * 1024

# Maximum data size we want to queue to send to the libpq copy. Sending a
# buffer too big to be handled can cause an infinite loop in the libpq
# (#255) so we want to split it in more digestible chunks.
MAX_BUFFER_SIZE = 4 * BUFFER_SIZE
# Note: making this buffer too large, e.g.
# MAX_BUFFER_SIZE = 1024 * 1024
# makes operations *way* slower! Probably triggering some quadraticity
# in the libpq memory management and data sending.

# Max size of the write queue of buffers. More than that copy will block.
# Each buffer should be around BUFFER_SIZE size.
QUEUE_SIZE = 1024

# On certain systems, memmove seems particularly slow and flushing often is
# more performing than accumulating a larger buffer. See #746 for details.
PREFER_FLUSH = sys.platform == "darwin"
class BaseCopy(Generic[ConnectionType]):
    """
    Base implementation for the copy user interface.

    Two subclasses expose real methods with the sync/async differences.

    The difference between the text and binary format is managed by two
    different `Formatter` subclasses.

    Writing (the I/O part) is implemented in the subclasses by a `Writer` or
    `AsyncWriter` instance. Normally writing implies sending copy data to a
    database, but a different writer might be chosen, e.g. to stream data into
    a file for later use.
    """

    # Chosen in __init__ according to the copy format (text or binary).
    formatter: Formatter

    def __init__(
        self,
        cursor: BaseCursor[ConnectionType, Any],
        *,
        binary: bool | None = None,
    ):
        self.cursor = cursor
        self.connection = cursor.connection
        self._pgconn = self.connection.pgconn

        result = cursor.pgresult
        if result:
            self._direction = result.status
            if self._direction != COPY_IN and self._direction != COPY_OUT:
                raise e.ProgrammingError(
                    "the cursor should have performed a COPY operation;"
                    f" its status is {pq.ExecStatus(self._direction).name} instead"
                )
        else:
            # No result on the cursor: assume a manually created Copy object.
            self._direction = COPY_IN

        if binary is None:
            # Infer the format from the result of the COPY operation.
            binary = bool(result and result.binary_tuples)

        tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor)
        if binary:
            self.formatter = BinaryFormatter(tx)
        else:
            self.formatter = TextFormatter(tx, encoding=self._pgconn._encoding)

        self._finished = False

    def __repr__(self) -> str:
        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        info = connection_summary(self._pgconn)
        return f"<{cls} {info} at 0x{id(self):x}>"

    def _enter(self) -> None:
        # Guard against reusing an exhausted copy block.
        if self._finished:
            raise TypeError("copy blocks can be used only once")

    def set_types(self, types: Sequence[int | str]) -> None:
        """
        Set the types expected in a COPY operation.

        The types must be specified as a sequence of oid or PostgreSQL type
        names (e.g. ``int4``, ``timestamptz[]``).

        This operation overcomes the lack of metadata returned by PostgreSQL
        when a COPY operation begins:

        - On :sql:`COPY TO`, `!set_types()` allows to specify what types the
          operation returns. If `!set_types()` is not used, the data will be
          returned as unparsed strings or bytes instead of Python objects.

        - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the
          database expects. This is especially useful in binary copy, because
          PostgreSQL will apply no cast rule.
        """
        registry = self.cursor.adapters.types
        oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]

        if self._direction == COPY_IN:
            self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
        else:
            self.formatter.transformer.set_loader_types(oids, self.formatter.format)

    # High level copy protocol generators (state change of the Copy object)

    def _read_gen(self) -> PQGen[Buffer]:
        # Return the next chunk of COPY OUT data; empty buffer at the end.
        if self._finished:
            return memoryview(b"")

        res = yield from copy_from(self._pgconn)
        if isinstance(res, memoryview):
            return res

        # res is the final PGresult
        self._finished = True

        # This result is a COMMAND_OK which has info about the number of rows
        # returned, but not about the columns, which is instead an information
        # that was received on the COPY_OUT result at the beginning of COPY.
        # So, don't replace the results in the cursor, just update the rowcount.
        nrows = res.command_tuples
        self.cursor._rowcount = nrows if nrows is not None else -1
        return memoryview(b"")

    def _read_row_gen(self) -> PQGen[tuple[Any, ...] | None]:
        # Return the next parsed record; None when the data is finished.
        data = yield from self._read_gen()
        if not data:
            return None

        row = self.formatter.parse_row(data)
        if row is None:
            # Get the final result to finish the copy operation
            yield from self._read_gen()
            self._finished = True
            return None

        return row

    def _end_copy_out_gen(self) -> PQGen[None]:
        # Consume the rest of the COPY OUT data, swallowing the cancellation
        # error which may result from interrupting the operation.
        try:
            while (yield from self._read_gen()):
                pass
        except e.QueryCanceled:
            pass
class Formatter(ABC):
    """
    A class which understands a copy format (text, binary).
    """

    # Set by concrete subclasses to the libpq format code.
    format: pq.Format

    def __init__(self, transformer: Transformer):
        self.transformer = transformer
        self._write_buffer = bytearray()
        self._row_mode = False  # true if the user is using write_row()

    @abstractmethod
    def parse_row(self, data: Buffer) -> tuple[Any, ...] | None:
        """Parse a chunk of copy data into a record, or None at end of data."""
        ...

    @abstractmethod
    def write(self, buffer: Buffer | str) -> Buffer:
        """Accept a block of raw copy data, returning the data to send."""
        ...

    @abstractmethod
    def write_row(self, row: Sequence[Any]) -> Buffer:
        """Format a record, returning the (possibly buffered) data to send."""
        ...

    @abstractmethod
    def end(self) -> Buffer:
        """Return the final data to send at the end of the copy operation."""
        ...
class TextFormatter(Formatter):
    """`Formatter` implementing the text copy format."""

    format = TEXT

    def __init__(self, transformer: Transformer, encoding: str = "utf-8"):
        super().__init__(transformer)
        self._encoding = encoding

    def parse_row(self, data: Buffer) -> tuple[Any, ...] | None:
        """Parse one line of copy text into a record; None at end of data."""
        return parse_row_text(data, self.transformer) if data else None

    def write(self, buffer: Buffer | str) -> Buffer:
        """Pass a raw block through, encoding str input to bytes."""
        out = self._ensure_bytes(buffer)
        self._signature_sent = True
        return out

    def write_row(self, row: Sequence[Any]) -> Buffer:
        """Format one record, returning accumulated data once large enough."""
        # Note down that we are writing record by record: the end-of-copy
        # marker will have to be handled too.
        self._row_mode = True

        format_row_text(row, self.transformer, self._write_buffer)
        if len(self._write_buffer) <= BUFFER_SIZE:
            # Keep accumulating until the buffer passes the threshold.
            return b""
        out, self._write_buffer = self._write_buffer, bytearray()
        return out

    def end(self) -> Buffer:
        """Flush and return whatever is left in the write buffer."""
        out, self._write_buffer = self._write_buffer, bytearray()
        return out

    def _ensure_bytes(self, data: Buffer | str) -> Buffer:
        if not isinstance(data, str):
            # Assume, for simplicity, that the user is not passing stupid
            # things to the write function. If that's the case, things
            # will fail downstream.
            return data
        return data.encode(self._encoding)
class BinaryFormatter(Formatter):
    """`Formatter` implementing the binary copy format."""

    format = BINARY

    def __init__(self, transformer: Transformer):
        super().__init__(transformer)
        # True once the file signature has been consumed (parsing) or
        # emitted (writing).
        self._signature_sent = False

    def parse_row(self, data: Buffer) -> tuple[Any, ...] | None:
        if not self._signature_sent:
            if data[: len(_binary_signature)] != _binary_signature:
                raise e.DataError(
                    "binary copy doesn't start with the expected signature"
                )
            self._signature_sent = True
            data = data[len(_binary_signature) :]

        elif data == _binary_trailer:
            # End-of-data marker: no more records.
            return None

        return parse_row_binary(data, self.transformer)

    def write(self, buffer: Buffer | str) -> Buffer:
        data = self._ensure_bytes(buffer)
        # In block mode, assume the user's data already includes the signature.
        self._signature_sent = True
        return data

    def write_row(self, row: Sequence[Any]) -> Buffer:
        # Note down that we are writing in row mode: it means we will have
        # to take care of the end-of-copy marker too
        self._row_mode = True

        if not self._signature_sent:
            self._write_buffer += _binary_signature
            self._signature_sent = True

        format_row_binary(row, self.transformer, self._write_buffer)
        if len(self._write_buffer) > BUFFER_SIZE:
            buffer, self._write_buffer = self._write_buffer, bytearray()
            return buffer
        else:
            return b""

    def end(self) -> Buffer:
        # If we have sent no data we need to send the signature
        # and the trailer
        if not self._signature_sent:
            self._write_buffer += _binary_signature
            self._write_buffer += _binary_trailer

        elif self._row_mode:
            # if we have sent data already, we have sent the signature
            # too (either with the first row, or we assume that in
            # block mode the signature is included).
            # Write the trailer only if we are sending rows (with the
            # assumption that who is copying binary data is sending the
            # whole format).
            self._write_buffer += _binary_trailer

        buffer, self._write_buffer = self._write_buffer, bytearray()
        return buffer

    def _ensure_bytes(self, data: Buffer | str) -> Buffer:
        if isinstance(data, str):
            raise TypeError("cannot copy str data in binary mode: use bytes instead")
        else:
            # Assume, for simplicity, that the user is not passing stupid
            # things to the write function. If that's the case, things
            # will fail downstream.
            return data
def _format_row_text(
row: Sequence[Any], tx: Transformer, out: bytearray | None = None
) -> bytearray:
"""Convert a row of objects to the data to send for copy."""
if out is None:
out = bytearray()
if not row:
out += b"\n"
return out
adapted = tx.dump_sequence(row, [PY_TEXT] * len(row))
for b in adapted:
out += _dump_re.sub(_dump_sub, b) if b is not None else rb"\N"
out += b"\t"
out[-1:] = b"\n"
return out
def _format_row_binary(
    row: Sequence[Any], tx: "Transformer", out: bytearray | None = None
) -> bytearray:
    """Serialize a row of Python objects into copy binary format.

    Append to *out* if given (and return it), else return a new bytearray.
    """
    if out is None:
        out = bytearray()

    # Each record starts with the 16-bit count of its fields.
    out += _pack_int2(len(row))

    for dumped in tx.dump_sequence(row, [PY_BINARY] * len(row)):
        if dumped is None:
            # NULL is a special length marker with no payload.
            out += _binary_null
        else:
            # 32-bit length prefix followed by the field payload.
            out += _pack_int4(len(dumped))
            out += dumped

    return out
# Pre-compiled struct helpers for the binary copy framing (network order).
_pack_int2 = struct.Struct("!h").pack
_pack_int4 = struct.Struct("!i").pack
_unpack_int2 = struct.Struct("!h").unpack_from
_unpack_int4 = struct.Struct("!i").unpack_from

# Fixed markers of the binary copy format.
_binary_signature = (
    b"PGCOPY\n\xff\r\n\0"  # Signature
    b"\x00\x00\x00\x00"  # flags
    b"\x00\x00\x00\x00"  # extra length
)
_binary_trailer = b"\xff\xff"
_binary_null = b"\xff\xff\xff\xff"

# Characters that must be escaped in copy text, and the reverse mapping
# used to unescape them while parsing.
_dump_repl = {
    b"\b": b"\\b",
    b"\t": b"\\t",
    b"\n": b"\\n",
    b"\v": b"\\v",
    b"\f": b"\\f",
    b"\r": b"\\r",
    b"\\": b"\\\\",
}
_load_repl = {escaped: raw for raw, escaped in _dump_repl.items()}

_dump_re = re.compile(rb"[\b\t\n\v\f\r\\]")
_load_re = re.compile(rb"\\[btnvfr\\]")


def _dump_sub(m: Match[bytes], __map: dict[bytes, bytes] = _dump_repl) -> bytes:
    """Regexp replacement callback: escape one special character."""
    return __map[m.group(0)]


def _load_sub(m: Match[bytes], __map: dict[bytes, bytes] = _load_repl) -> bytes:
    """Regexp replacement callback: unescape one escape sequence."""
    return __map[m.group(0)]


def _parse_row_text(data: "Buffer", tx: "Transformer") -> tuple[Any, ...]:
    """Split one line of copy text into a record of unparsed fields."""
    if not isinstance(data, bytes):
        data = bytes(data)
    fields = data.split(b"\t")
    fields[-1] = fields[-1][:-1]  # drop the trailing newline
    # \N marks a NULL; every other field gets its escape sequences resolved.
    unescaped = [
        None if field == b"\\N" else _load_re.sub(_load_sub, field)
        for field in fields
    ]
    return tx.load_sequence(unescaped)


def _parse_row_binary(data: "Buffer", tx: "Transformer") -> tuple[Any, ...]:
    """Decode one record of copy binary data into unparsed fields."""
    fields: list[Buffer | None] = []
    (nfields,) = _unpack_int2(data, 0)
    pos = 2
    for _ in range(nfields):
        # Each field is a 4-byte length (-1 for NULL) followed by the payload.
        (length,) = _unpack_int4(data, pos)
        pos += 4
        if length < 0:
            fields.append(None)
        else:
            fields.append(data[pos : pos + length])
            pos += length

    return tx.load_sequence(fields)
# Override functions with fast versions if available: the C speedup module
# provides drop-in replacements for the row (de)serialization functions.
if _psycopg:
    format_row_text = _psycopg.format_row_text
    format_row_binary = _psycopg.format_row_binary
    parse_row_text = _psycopg.parse_row_text
    parse_row_binary = _psycopg.parse_row_binary

else:
    # Pure Python implementations, used when the C module is not compiled.
    format_row_text = _format_row_text
    format_row_binary = _format_row_binary
    parse_row_text = _parse_row_text
    parse_row_binary = _parse_row_binary

View File

@@ -0,0 +1,627 @@
"""
Psycopg BaseCursor object
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from functools import partial
from typing import Any, Generic, Iterable, NoReturn, Sequence
from typing import TYPE_CHECKING
from . import pq
from . import adapt
from . import errors as e
from .abc import ConnectionType, Query, Params, PQGen
from .rows import Row, RowMaker
from ._capabilities import capabilities
from ._column import Column
from .pq.misc import connection_summary
from ._queries import PostgresQuery, PostgresClientQuery
from ._preparing import Prepare
from .generators import execute, fetch, send
if TYPE_CHECKING:
from .abc import Transformer
from .pq.abc import PGconn, PGresult
# Shortcuts for the libpq enum values used throughout this module.
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY

EMPTY_QUERY = pq.ExecStatus.EMPTY_QUERY
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
COPY_OUT = pq.ExecStatus.COPY_OUT
COPY_IN = pq.ExecStatus.COPY_IN
COPY_BOTH = pq.ExecStatus.COPY_BOTH
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
TUPLES_CHUNK = pq.ExecStatus.TUPLES_CHUNK
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED

ACTIVE = pq.TransactionStatus.ACTIVE
class BaseCursor(Generic[ConnectionType, Row]):
    """Base implementation for cursors; subclassed for sync/async use."""

    __slots__ = """
        _conn format _adapters arraysize _closed _results pgresult _pos
        _iresult _rowcount _query _tx _last_query _row_factory _make_row
        _pgconn _execmany_returning
        __weakref__
        """.split()

    _tx: Transformer
    _make_row: RowMaker[Row]
    _pgconn: PGconn

    # The class used to convert queries; overridden by client-side cursors.
    _query_cls: type[PostgresQuery] = PostgresQuery

    def __init__(self, connection: ConnectionType):
        self._conn = connection
        self.format = TEXT
        self._pgconn = connection.pgconn
        self._adapters = adapt.AdaptersMap(connection.adapters)
        self.arraysize = 1
        self._closed = False
        self._last_query: Query | None = None
        self._reset()

    def _reset(self, reset_query: bool = True) -> None:
        # Clear the per-execution state of the cursor.
        self._results: list[PGresult] = []
        self.pgresult: PGresult | None = None
        self._pos = 0
        self._iresult = 0
        self._rowcount = -1
        self._query: PostgresQuery | None
        # None if executemany() not executing, True/False according to returning state
        self._execmany_returning: bool | None = None
        if reset_query:
            self._query = None
def __repr__(self) -> str:
    cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
    info = connection_summary(self._pgconn)
    if self._closed:
        status = "closed"
    elif self.pgresult:
        status = pq.ExecStatus(self.pgresult.status).name
    else:
        status = "no result"
    return f"<{cls} [{status}] {info} at 0x{id(self):x}>"

@property
def connection(self) -> ConnectionType:
    """The connection this cursor is using."""
    return self._conn

@property
def adapters(self) -> adapt.AdaptersMap:
    """The adapters configuration used by this cursor."""
    return self._adapters

@property
def closed(self) -> bool:
    """`True` if the cursor is closed."""
    return self._closed

@property
def description(self) -> list[Column] | None:
    """
    A list of `Column` objects describing the current resultset.

    `!None` if the current resultset didn't return tuples.
    """
    res = self.pgresult

    # We return columns if we have nfields, but also if we don't but
    # the query said we got tuples (mostly to handle the super useful
    # query "SELECT ;")
    if res and (
        res.nfields
        or res.status == TUPLES_OK
        or res.status == SINGLE_TUPLE
        or res.status == TUPLES_CHUNK
    ):
        return [Column(self, i) for i in range(res.nfields)]
    else:
        return None

@property
def rowcount(self) -> int:
    """Number of records affected by the precedent operation."""
    return self._rowcount

@property
def rownumber(self) -> int | None:
    """Index of the next row to fetch in the current result.

    `!None` if there is no result to fetch.
    """
    tuples = self.pgresult and self.pgresult.status == TUPLES_OK
    return self._pos if tuples else None

def setinputsizes(self, sizes: Sequence[Any]) -> None:
    # no-op: present only for DBAPI compliance
    pass

def setoutputsize(self, size: Any, column: int | None = None) -> None:
    # no-op: present only for DBAPI compliance
    pass

def nextset(self) -> bool | None:
    """
    Move to the result set of the next query executed through `executemany()`
    or to the next result set if `execute()` returned more than one.

    Return `!True` if a new result is available, which will be the one
    methods `!fetch*()` will operate on.
    """
    if self._iresult < len(self._results) - 1:
        self._select_current_result(self._iresult + 1)
        return True
    else:
        return None

@property
def statusmessage(self) -> str | None:
    """
    The command status tag from the last SQL command executed.

    `!None` if the cursor doesn't have a result available.
    """
    msg = self.pgresult.command_status if self.pgresult else None
    return msg.decode() if msg else None

def _make_row_maker(self) -> RowMaker[Row]:
    # Implemented by concrete subclasses according to their row factory.
    raise NotImplementedError

#
# Generators for the high level operations on the cursor
#
# Like for sync/async connections, these are implemented as generators
# so that different concurrency strategies (threads, asyncio) can use their
# own way of waiting (or better, `connection.wait()`).
#
def _execute_gen(
    self,
    query: Query,
    params: Params | None = None,
    *,
    prepare: bool | None = None,
    binary: bool | None = None,
) -> PQGen[None]:
    """Generator implementing `Cursor.execute()`."""
    yield from self._start_query(query)
    pgq = self._convert_query(query, params)
    yield from self._maybe_prepare_gen(pgq, prepare=prepare, binary=binary)
    if self._conn._pipeline:
        yield from self._conn._pipeline._communicate_gen()

    self._last_query = query

    # Run any deferred maintenance on the prepared statements cache.
    yield from self._conn._prepared.maintain_gen(self._conn)

def _executemany_gen_pipeline(
    self, query: Query, params_seq: Iterable[Params], returning: bool
) -> PQGen[None]:
    """
    Generator implementing `Cursor.executemany()` with pipelines available.
    """
    pipeline = self._conn._pipeline
    assert pipeline

    yield from self._start_query(query)
    if not returning:
        self._rowcount = 0

    assert self._execmany_returning is None
    self._execmany_returning = returning

    first = True
    for params in params_seq:
        if first:
            # Convert the query once; reuse it for the following param sets.
            pgq = self._convert_query(query, params)
            self._query = pgq
            first = False
        else:
            pgq.dump(params)

        yield from self._maybe_prepare_gen(pgq, prepare=True)
        yield from pipeline._communicate_gen()

    self._last_query = query

    if returning:
        yield from pipeline._fetch_gen(flush=True)

    yield from self._conn._prepared.maintain_gen(self._conn)

def _executemany_gen_no_pipeline(
    self, query: Query, params_seq: Iterable[Params], returning: bool
) -> PQGen[None]:
    """
    Generator implementing `Cursor.executemany()` with pipelines not available.
    """
    yield from self._start_query(query)
    if not returning:
        self._rowcount = 0

    assert self._execmany_returning is None
    self._execmany_returning = returning

    first = True
    for params in params_seq:
        if first:
            # Convert the query once; reuse it for the following param sets.
            pgq = self._convert_query(query, params)
            self._query = pgq
            first = False
        else:
            pgq.dump(params)

        yield from self._maybe_prepare_gen(pgq, prepare=True)

    self._last_query = query
    yield from self._conn._prepared.maintain_gen(self._conn)
def _maybe_prepare_gen(
    self,
    pgq: PostgresQuery,
    *,
    prepare: bool | None = None,
    binary: bool | None = None,
) -> PQGen[None]:
    """Generator sending *pgq* to the server, preparing it if appropriate."""
    # Check if the query is prepared or needs preparing
    prep, name = self._get_prepared(pgq, prepare)
    if prep is Prepare.NO:
        # The query must be executed without preparing
        self._execute_send(pgq, binary=binary)
    else:
        # If the query is not already prepared, prepare it.
        if prep is Prepare.SHOULD:
            self._send_prepare(name, pgq)
            if not self._conn._pipeline:
                (result,) = yield from execute(self._pgconn)
                if result.status == FATAL_ERROR:
                    raise e.error_from_result(result, encoding=self._encoding)
        # Then execute it.
        self._send_query_prepared(name, pgq, binary=binary)

    # Update the prepare state of the query.
    # If an operation requires to flush our prepared statements cache,
    # it will be added to the maintenance commands to execute later.
    key = self._conn._prepared.maybe_add_to_cache(pgq, prep, name)

    if self._conn._pipeline:
        # In pipeline mode results are processed later, by the pipeline.
        queued = None
        if key is not None:
            queued = (key, prep, name)
        self._conn._pipeline.result_queue.append((self, queued))
        return

    # run the query
    results = yield from execute(self._pgconn)

    if key is not None:
        self._conn._prepared.validate(key, prep, name, results)

    self._check_results(results)
    self._set_results(results)

def _get_prepared(
    self, pgq: PostgresQuery, prepare: bool | None = None
) -> tuple[Prepare, bytes]:
    # Ask the connection's prepared-statement registry what to do with *pgq*.
    return self._conn._prepared.get(pgq, prepare)
def _stream_send_gen(
    self,
    query: Query,
    params: Params | None = None,
    *,
    binary: bool | None = None,
    size: int,
) -> PQGen[None]:
    """Generator to send the query for `Cursor.stream()`."""
    yield from self._start_query(query)
    pgq = self._convert_query(query, params)
    self._execute_send(pgq, binary=binary, force_extended=True)
    if size < 1:
        raise ValueError("size must be >= 1")
    elif size == 1:
        self._pgconn.set_single_row_mode()
    else:
        # Chunked rows mode needs libpq support: check explicitly.
        capabilities.has_stream_chunked(check=True)
        self._pgconn.set_chunked_rows_mode(size)
    self._last_query = query
    yield from send(self._pgconn)

def _stream_fetchone_gen(self, first: bool) -> PQGen[PGresult | None]:
    """Generator fetching one streamed result; None when the stream is over."""
    res = yield from fetch(self._pgconn)
    if res is None:
        return None

    status = res.status
    if status == SINGLE_TUPLE or status == TUPLES_CHUNK:
        self.pgresult = res
        self._tx.set_pgresult(res, set_loaders=first)
        if first:
            self._make_row = self._make_row_maker()
        return res

    elif status == TUPLES_OK or status == COMMAND_OK:
        # End of single row results
        while res:
            res = yield from fetch(self._pgconn)
        if status != TUPLES_OK:
            raise e.ProgrammingError(
                "the operation in stream() didn't produce a result"
            )
        return None

    else:
        # Errors, unexpected values
        return self._raise_for_result(res)

def _start_query(self, query: Query | None = None) -> PQGen[None]:
    """Generator to start the processing of a query.

    It is implemented as generator because it may send additional queries,
    such as `begin`.
    """
    if self.closed:
        raise e.InterfaceError("the cursor is closed")

    self._reset()
    if not self._last_query or (self._last_query is not query):
        # New query: drop the cached adaptation state.
        self._last_query = None
        self._tx = adapt.Transformer(self)
    yield from self._conn._start_query()
def _start_copy_gen(
    self, statement: Query, params: Params | None = None
) -> PQGen[None]:
    """Generator implementing sending a command for `Cursor.copy()`."""
    # The connection gets in an unrecoverable state if we attempt COPY in
    # pipeline mode. Forbid it explicitly.
    if self._conn._pipeline:
        raise e.NotSupportedError("COPY cannot be used in pipeline mode")

    yield from self._start_query()

    # Merge the params client-side
    if params:
        pgq = PostgresClientQuery(self._tx)
        pgq.convert(statement, params)
        statement = pgq.query

    query = self._convert_query(statement)

    self._execute_send(query, binary=False)
    results = yield from execute(self._pgconn)
    if len(results) != 1:
        raise e.ProgrammingError("COPY cannot be mixed with other operations")

    self._check_copy_result(results[0])
    self._set_results(results)
def _execute_send(
    self,
    query: PostgresQuery,
    *,
    force_extended: bool = False,
    binary: bool | None = None,
) -> None:
    """
    Implement part of execute() before waiting common to sync and async.

    This is not a generator, but a normal non-blocking function.
    """
    if binary is None:
        fmt = self.format
    else:
        fmt = BINARY if binary else TEXT

    self._query = query

    if self._conn._pipeline:
        # In pipeline mode always use PQsendQueryParams - see #314
        # Multiple statements in the same query are not allowed anyway.
        self._conn._pipeline.command_queue.append(
            partial(
                self._pgconn.send_query_params,
                query.query,
                query.params,
                param_formats=query.formats,
                param_types=query.types,
                result_format=fmt,
            )
        )
    elif force_extended or query.params or fmt == BINARY:
        # Extended protocol, required for parameters or binary results.
        self._pgconn.send_query_params(
            query.query,
            query.params,
            param_formats=query.formats,
            param_types=query.types,
            result_format=fmt,
        )
    else:
        # If we can, let's use simple query protocol,
        # as it can execute more than one statement in a single query.
        self._pgconn.send_query(query.query)

def _convert_query(
    self, query: Query, params: Params | None = None
) -> PostgresQuery:
    """Return *query* converted to a `PostgresQuery` with params adapted."""
    pgq = self._query_cls(self._tx)
    pgq.convert(query, params)
    return pgq
def _check_results(self, results: list[PGresult]) -> None:
    """
    Verify that the results of a query are valid.

    Verify that the query returned at least one result and that they all
    represent a valid result from the database.
    """
    if not results:
        raise e.InternalError("got no result from the query")

    for res in results:
        status = res.status
        if status != TUPLES_OK and status != COMMAND_OK and status != EMPTY_QUERY:
            self._raise_for_result(res)

def _raise_for_result(self, result: PGresult) -> NoReturn:
    """
    Raise an appropriate error message for an unexpected database result
    """
    status = result.status
    if status == FATAL_ERROR:
        # Map the server error to the matching Python exception class.
        raise e.error_from_result(result, encoding=self._encoding)
    elif status == PIPELINE_ABORTED:
        raise e.PipelineAborted("pipeline aborted")
    elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
        raise e.ProgrammingError(
            "COPY cannot be used with this method; use copy() instead"
        )
    else:
        raise e.InternalError(
            "unexpected result status from query:" f" {pq.ExecStatus(status).name}"
        )
    def _select_current_result(self, i: int, format: pq.Format | None = None) -> None:
        """
        Select one of the results in the cursor as the active one.

        :param i: index of the result to activate in the cursor's results list.
        :param format: if not `!None`, override the format used to configure
            the row loaders instead of the one reported by the result.
        """
        self._iresult = i
        res = self.pgresult = self._results[i]
        # Note: the only reason to override format is to correctly set
        # binary loaders on server-side cursors, because send_describe_portal
        # only returns a text result.
        self._tx.set_pgresult(res, format=format)
        # Restart fetching from the beginning of the newly selected result.
        self._pos = 0
        if res.status == TUPLES_OK:
            self._rowcount = self.pgresult.ntuples
        # COPY_OUT has never info about nrows. We need such result for the
        # columns in order to return a `description`, but not overwrite the
        # cursor rowcount (which was set by the Copy object).
        elif res.status != COPY_OUT:
            nrows = self.pgresult.command_tuples
            self._rowcount = nrows if nrows is not None else -1
        self._make_row = self._make_row_maker()
    def _set_results(self, results: list[PGresult]) -> None:
        """
        Attach the results received from the server to the cursor state.

        Bookkeeping differs depending on whether the results come from
        `!execute()` or from `!executemany()` (with or without RETURNING),
        as signalled by `!self._execmany_returning`.
        """
        if self._execmany_returning is None:
            # Received from execute()
            self._results[:] = results
            self._select_current_result(0)
        else:
            # Received from executemany()
            if self._execmany_returning:
                # With RETURNING: accumulate the results and make the first
                # batch immediately available for fetching.
                first_batch = not self._results
                self._results.extend(results)
                if first_batch:
                    self._select_current_result(0)
            else:
                # In non-returning case, set rowcount to the cumulated number of
                # rows of executed queries.
                for res in results:
                    self._rowcount += res.command_tuples or 0
def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
if self._conn._pipeline:
self._conn._pipeline.command_queue.append(
partial(
self._pgconn.send_prepare,
name,
query.query,
param_types=query.types,
)
)
self._conn._pipeline.result_queue.append(None)
else:
self._pgconn.send_prepare(name, query.query, param_types=query.types)
def _send_query_prepared(
self, name: bytes, pgq: PostgresQuery, *, binary: bool | None = None
) -> None:
if binary is None:
fmt = self.format
else:
fmt = BINARY if binary else TEXT
if self._conn._pipeline:
self._conn._pipeline.command_queue.append(
partial(
self._pgconn.send_query_prepared,
name,
pgq.params,
param_formats=pgq.formats,
result_format=fmt,
)
)
else:
self._pgconn.send_query_prepared(
name, pgq.params, param_formats=pgq.formats, result_format=fmt
)
def _check_result_for_fetch(self) -> None:
if self.closed:
raise e.InterfaceError("the cursor is closed")
res = self.pgresult
if not res:
raise e.ProgrammingError("no result available")
status = res.status
if status == TUPLES_OK:
return
elif status == FATAL_ERROR:
raise e.error_from_result(res, encoding=self._encoding)
elif status == PIPELINE_ABORTED:
raise e.PipelineAborted("pipeline aborted")
else:
raise e.ProgrammingError("the last operation didn't produce a result")
def _check_copy_result(self, result: PGresult) -> None:
"""
Check that the value returned in a copy() operation is a legit COPY.
"""
status = result.status
if status == COPY_IN or status == COPY_OUT:
return
elif status == FATAL_ERROR:
raise e.error_from_result(result, encoding=self._encoding)
else:
raise e.ProgrammingError(
"copy() should be used only with COPY ... TO STDOUT or COPY ..."
f" FROM STDIN statements, got {pq.ExecStatus(status).name}"
)
def _scroll(self, value: int, mode: str) -> None:
self._check_result_for_fetch()
assert self.pgresult
if mode == "relative":
newpos = self._pos + value
elif mode == "absolute":
newpos = value
else:
raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
if not 0 <= newpos < self.pgresult.ntuples:
raise IndexError("position out of bound")
self._pos = newpos
    def _close(self) -> None:
        """Non-blocking part of closing. Common to sync/async."""
        # Don't reset the query because it may be useful to investigate after
        # an error.
        self._reset(reset_query=False)
        # Mark the cursor closed: further fetch/execute operations will fail.
        self._closed = True
    @property
    def _encoding(self) -> str:
        # Python codec name of the connection, delegated to the pgconn wrapper.
        return self._pgconn._encoding

View File

@@ -0,0 +1,247 @@
# type: ignore # dnspython is currently optional and mypy fails if missing
"""
DNS query support
"""
# Copyright (C) 2021 The Psycopg Team
from __future__ import annotations
import os
import re
import warnings
from random import randint
from typing import Any, DefaultDict, NamedTuple, Sequence, TYPE_CHECKING
from collections import defaultdict
try:
from dns.resolver import Resolver, Cache
from dns.asyncresolver import Resolver as AsyncResolver
from dns.exception import DNSException
except ImportError:
raise ImportError(
"the module psycopg._dns requires the package 'dnspython' installed"
)
from . import errors as e
from . import conninfo
if TYPE_CHECKING:
from dns.rdtypes.IN.SRV import SRV
# Module-wide resolvers (sync and async), each with a cache so repeated
# lookups of the same name within the process are not re-queried.
resolver = Resolver()
resolver.cache = Cache()
async_resolver = AsyncResolver()
async_resolver.cache = Cache()
async def resolve_hostaddr_async(params: dict[str, Any]) -> dict[str, Any]:
    """
    Perform async DNS lookup of the hosts and return a new params dict.

    :param params: connection parameters, e.g. as returned by
        `!conninfo_to_dict()`.
    :return: a new dict with *host*, *hostaddr*, *port* replaced by the
        resolved values (each entry omitted if the resolved list is empty).

    .. deprecated:: 3.1
        The use of this function is not necessary anymore, because
        `psycopg.AsyncConnection.connect()` performs non-blocking name
        resolution automatically.
    """
    warnings.warn(
        "from psycopg 3.1, resolve_hostaddr_async() is not needed anymore",
        DeprecationWarning,
    )
    hosts: list[str] = []
    hostaddrs: list[str] = []
    ports: list[str] = []
    for attempt in await conninfo.conninfo_attempts_async(params):
        if attempt.get("host") is not None:
            hosts.append(attempt["host"])
        if attempt.get("hostaddr") is not None:
            hostaddrs.append(attempt["hostaddr"])
        if attempt.get("port") is not None:
            ports.append(str(attempt["port"]))
    out = params.copy()
    shosts = ",".join(hosts)
    if shosts:
        out["host"] = shosts
    shostaddrs = ",".join(hostaddrs)
    if shostaddrs:
        out["hostaddr"] = shostaddrs
    sports = ",".join(ports)
    # Test the joined string, not the list, for consistency with the host and
    # hostaddr branches above: don't emit an empty "port" entry.
    if sports:
        out["port"] = sports
    return out
# Thin convenience wrappers over Rfc2782Resolver.
def resolve_srv(params: dict[str, Any]) -> dict[str, Any]:
    """Apply SRV DNS lookup as defined in :RFC:`2782`."""
    return Rfc2782Resolver().resolve(params)
async def resolve_srv_async(params: dict[str, Any]) -> dict[str, Any]:
    """Async equivalent of `resolve_srv()`."""
    return await Rfc2782Resolver().resolve_async(params)
class HostPort(NamedTuple):
    """A host/port pair, possibly a candidate for SRV resolution."""
    host: str
    port: str
    # Whether an SRV lookup should be attempted for this entry.
    totry: bool = False
    # Target extracted from a "_service._proto.target" host name, if any.
    target: str | None = None
class Rfc2782Resolver:
    """Implement SRV RR Resolution as per RFC 2782

    The class is organised to minimise code duplication between the sync and
    the async paths.
    """
    # Matches host names of the form "_service._proto.target".
    re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
    def resolve(self, params: dict[str, Any]) -> dict[str, Any]:
        """Update the parameters host and port after SRV lookup."""
        attempts = self._get_attempts(params)
        if not attempts:
            return params
        hps = []
        for hp in attempts:
            if hp.totry:
                hps.extend(self._resolve_srv(hp))
            else:
                hps.append(hp)
        return self._return_params(params, hps)
    async def resolve_async(self, params: dict[str, Any]) -> dict[str, Any]:
        """Update the parameters host and port after SRV lookup."""
        attempts = self._get_attempts(params)
        if not attempts:
            return params
        hps = []
        for hp in attempts:
            if hp.totry:
                hps.extend(await self._resolve_srv_async(hp))
            else:
                hps.append(hp)
        return self._return_params(params, hps)
    def _get_attempts(self, params: dict[str, Any]) -> list[HostPort]:
        """
        Return the list of host, and for each host if SRV lookup must be tried.

        Return an empty list if no lookup is requested.
        """
        # If hostaddr is defined don't do any resolution.
        if params.get("hostaddr", os.environ.get("PGHOSTADDR", "")):
            return []
        host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
        hosts_in = host_arg.split(",")
        port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
        ports_in = port_arg.split(",")
        if len(ports_in) == 1:
            # If only one port is specified, it applies to all the hosts.
            ports_in *= len(hosts_in)
        if len(ports_in) != len(hosts_in):
            # ProgrammingError would have been more appropriate, but this is
            # what the libpq raises too, if it fails to connect in the same case.
            raise e.OperationalError(
                f"cannot match {len(hosts_in)} hosts with {len(ports_in)} port numbers"
            )
        out = []
        srv_found = False
        for host, port in zip(hosts_in, ports_in):
            m = self.re_srv_rr.match(host)
            # Either an RFC 2782 host name or an explicit port of "srv"
            # marks the entry for SRV lookup.
            if m or port.lower() == "srv":
                srv_found = True
                target = m.group("target") if m else None
                hp = HostPort(host=host, port=port, totry=True, target=target)
            else:
                hp = HostPort(host=host, port=port)
            out.append(hp)
        return out if srv_found else []
    def _resolve_srv(self, hp: HostPort) -> list[HostPort]:
        """Perform an SRV lookup for *hp*; a failed lookup yields no entries."""
        try:
            ans = resolver.resolve(hp.host, "SRV")
        except DNSException:
            ans = ()
        return self._get_solved_entries(hp, ans)
    async def _resolve_srv_async(self, hp: HostPort) -> list[HostPort]:
        """Async equivalent of `_resolve_srv()`."""
        try:
            ans = await async_resolver.resolve(hp.host, "SRV")
        except DNSException:
            ans = ()
        return self._get_solved_entries(hp, ans)
    def _get_solved_entries(
        self, hp: HostPort, entries: Sequence[SRV]
    ) -> list[HostPort]:
        """Convert SRV records to `HostPort` entries, ordered per RFC 2782."""
        if not entries:
            # No SRV entry found. Delegate the libpq a QNAME=target lookup
            if hp.target and hp.port.lower() != "srv":
                return [HostPort(host=hp.target, port=hp.port)]
            else:
                return []
        # If there is precisely one SRV RR, and its Target is "." (the root
        # domain), abort.
        if len(entries) == 1 and str(entries[0].target) == ".":
            return []
        return [
            HostPort(host=str(entry.target).rstrip("."), port=str(entry.port))
            for entry in self.sort_rfc2782(entries)
        ]
    def _return_params(
        self, params: dict[str, Any], hps: list[HostPort]
    ) -> dict[str, Any]:
        """Return a copy of *params* with host/port rewritten from *hps*."""
        if not hps:
            # Nothing found, we ended up with an empty list
            raise e.OperationalError("no host found after SRV RR lookup")
        out = params.copy()
        out["host"] = ",".join(hp.host for hp in hps)
        out["port"] = ",".join(str(hp.port) for hp in hps)
        return out
    def sort_rfc2782(self, ans: Sequence[SRV]) -> list[SRV]:
        """
        Implement the priority/weight ordering defined in RFC 2782.
        """
        # Divide the entries by priority:
        priorities: DefaultDict[int, list[SRV]] = defaultdict(list)
        out: list[SRV] = []
        for entry in ans:
            priorities[entry.priority].append(entry)
        for pri, entries in sorted(priorities.items()):
            if len(entries) == 1:
                out.append(entries[0])
                continue
            # Weighted random selection within the same priority: pick a
            # random point in the cumulative weight range and take the first
            # entry whose running sum reaches it; repeat until exhausted.
            entries.sort(key=lambda ent: ent.weight)
            total_weight = sum(ent.weight for ent in entries)
            while entries:
                r = randint(0, total_weight)
                csum = 0
                for i, ent in enumerate(entries):
                    csum += ent.weight
                    if csum >= r:
                        break
                out.append(ent)
                total_weight -= ent.weight
                del entries[i]
        return out

View File

@@ -0,0 +1,154 @@
"""
Mappings between PostgreSQL and Python encodings.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import re
import string
import codecs
from typing import Any, TYPE_CHECKING
from .pq._enums import ConnStatus
from .errors import NotSupportedError
from ._compat import cache
if TYPE_CHECKING:
from ._connection_base import BaseConnection
# Shortcut for the libpq "connection OK" status.
OK = ConnStatus.OK
# Map PostgreSQL encoding names to the equivalent Python codec names.
_py_codecs = {
    "BIG5": "big5",
    "EUC_CN": "gb2312",
    "EUC_JIS_2004": "euc_jis_2004",
    "EUC_JP": "euc_jp",
    "EUC_KR": "euc_kr",
    # "EUC_TW": not available in Python
    "GB18030": "gb18030",
    "GBK": "gbk",
    "ISO_8859_5": "iso8859-5",
    "ISO_8859_6": "iso8859-6",
    "ISO_8859_7": "iso8859-7",
    "ISO_8859_8": "iso8859-8",
    "JOHAB": "johab",
    "KOI8R": "koi8-r",
    "KOI8U": "koi8-u",
    "LATIN1": "iso8859-1",
    "LATIN10": "iso8859-16",
    "LATIN2": "iso8859-2",
    "LATIN3": "iso8859-3",
    "LATIN4": "iso8859-4",
    "LATIN5": "iso8859-9",
    "LATIN6": "iso8859-10",
    "LATIN7": "iso8859-13",
    "LATIN8": "iso8859-14",
    "LATIN9": "iso8859-15",
    # "MULE_INTERNAL": not available in Python
    "SHIFT_JIS_2004": "shift_jis_2004",
    "SJIS": "shift_jis",
    # this actually means no encoding, see PostgreSQL docs
    # it is special-cased by the text loader.
    "SQL_ASCII": "ascii",
    "UHC": "cp949",
    "UTF8": "utf-8",
    "WIN1250": "cp1250",
    "WIN1251": "cp1251",
    "WIN1252": "cp1252",
    "WIN1253": "cp1253",
    "WIN1254": "cp1254",
    "WIN1255": "cp1255",
    "WIN1256": "cp1256",
    "WIN1257": "cp1257",
    "WIN1258": "cp1258",
    "WIN866": "cp866",
    "WIN874": "cp874",
}
# Same mapping, keyed by bytes, used for lookups of server-provided names.
py_codecs: dict[bytes, str] = {}
py_codecs.update((k.encode(), v) for k, v in _py_codecs.items())
# Add an alias without underscore, for lenient lookups
py_codecs.update(
    (k.replace("_", "").encode(), v) for k, v in _py_codecs.items() if "_" in k
)
# Inverse map: Python codec name -> PostgreSQL encoding name.
pg_codecs = {v: k.encode() for k, v in _py_codecs.items()}
def conn_encoding(conn: BaseConnection[Any] | None) -> str:
    """
    Return the Python encoding name of a psycopg connection.

    Default to utf8 if the connection has no encoding info.
    """
    if not conn:
        return "utf-8"
    return conn.pgconn._encoding
def conninfo_encoding(conninfo: str) -> str:
    """
    Return the Python encoding name passed in a conninfo string. Default to utf8.

    Because the input is likely to come from the user and not normalised by the
    server, be somewhat lenient (non-case-sensitive lookup, ignore noise chars).
    """
    from .conninfo import conninfo_to_dict

    pgenc = conninfo_to_dict(conninfo).get("client_encoding")
    if not pgenc:
        return "utf-8"
    try:
        return pg2pyenc(str(pgenc).encode())
    except NotSupportedError:
        # Unknown to Python: fall back to the default.
        return "utf-8"
@cache
def py2pgenc(name: str) -> bytes:
    """Convert a Python encoding name to PostgreSQL encoding name.

    Raise LookupError if the Python encoding is unknown.
    """
    # Normalise through the codec registry, so that aliases such as "utf8",
    # "UTF-8", "U8" all resolve to the same canonical codec name.
    canonical = codecs.lookup(name).name
    return pg_codecs[canonical]
@cache
def pg2pyenc(name: bytes) -> str:
    """Convert a PostgreSQL encoding name to Python encoding name.

    Raise NotSupportedError if the PostgreSQL encoding is not supported by
    Python.
    """
    # Lenient lookup: drop separator characters and compare uppercase.
    key = name.replace(b"-", b"").replace(b"_", b"").upper()
    try:
        return py_codecs[key]
    except KeyError:
        sname = name.decode("utf8", "replace")
        raise NotSupportedError(f"codec not available in Python: {sname!r}")
def _as_python_identifier(s: str, prefix: str = "f") -> str:
"""
Reduce a string to a valid Python identifier.
Replace all non-valid chars with '_' and prefix the value with `!prefix` if
the first letter is an '_'.
"""
if not s.isidentifier():
if s[0] in "1234567890":
s = prefix + s
if not s.isidentifier():
s = _re_clean.sub("_", s)
# namedtuple fields cannot start with underscore. So...
if s[0] == "_":
s = prefix + s
return s
_re_clean = re.compile(
f"[^{string.ascii_lowercase}{string.ascii_uppercase}{string.digits}_]"
)

View File

@@ -0,0 +1,80 @@
"""
Enum values for psycopg
These values are defined by us and are not necessarily dependent on
libpq-defined enums.
"""
# Copyright (C) 2020 The Psycopg Team
from enum import Enum, IntEnum
from selectors import EVENT_READ, EVENT_WRITE
from . import pq
class Wait(IntEnum):
    """Socket events to wait for: read, write, or both."""
    R = EVENT_READ
    W = EVENT_WRITE
    RW = EVENT_READ | EVENT_WRITE
class Ready(IntEnum):
    """Socket events a wait call reported ready (`NONE` if none)."""
    NONE = 0
    R = EVENT_READ
    W = EVENT_WRITE
    RW = EVENT_READ | EVENT_WRITE
class PyFormat(str, Enum):
    """
    Enum representing the format wanted for a query argument.

    The value `AUTO` allows psycopg to choose the best format for a certain
    parameter.
    """
    __module__ = "psycopg.adapt"
    AUTO = "s"
    """Automatically chosen (``%s`` placeholder)."""
    TEXT = "t"
    """Text parameter (``%t`` placeholder)."""
    BINARY = "b"
    """Binary parameter (``%b`` placeholder)."""
    @classmethod
    def from_pq(cls, fmt: pq.Format) -> "PyFormat":
        """Return the `!PyFormat` equivalent of a libpq format."""
        return _pg2py[fmt]
    @classmethod
    def as_pq(cls, fmt: "PyFormat") -> pq.Format:
        """Return the libpq format equivalent of *fmt* (`AUTO` has none)."""
        return _py2pg[fmt]
class IsolationLevel(IntEnum):
    """
    Enum representing the isolation level for a transaction.
    """
    __module__ = "psycopg"
    # Numeric values are ordered by increasing isolation strictness.
    READ_UNCOMMITTED = 1
    """:sql:`READ UNCOMMITTED` isolation level."""
    READ_COMMITTED = 2
    """:sql:`READ COMMITTED` isolation level."""
    REPEATABLE_READ = 3
    """:sql:`REPEATABLE READ` isolation level."""
    SERIALIZABLE = 4
    """:sql:`SERIALIZABLE` isolation level."""
# Translation tables between PyFormat and libpq formats. `PyFormat.AUTO`
# has no libpq equivalent, hence it is deliberately absent from _py2pg.
_py2pg = {
    PyFormat.TEXT: pq.Format.TEXT,
    PyFormat.BINARY: pq.Format.BINARY,
}
_pg2py = {
    pq.Format.TEXT: PyFormat.TEXT,
    pq.Format.BINARY: PyFormat.BINARY,
}

View File

@@ -0,0 +1,97 @@
"""
PostgreSQL known type OIDs
This is an internal module. Types are publicly exposed by
`psycopg.postgres.types`. This module is only used to know the OIDs at import
time and avoid circular import problems.
"""
# Copyright (C) 2020 The Psycopg Team
# A couple of special cases used a bit everywhere.
INVALID_OID = 0
TEXT_ARRAY_OID = 1009
# Use tools/update_oids.py to update this data.
# NOTE: do not edit the list below by hand; it would be clobbered by the
# next regeneration.
# autogenerated: start
# Generated from PostgreSQL 16.2
ACLITEM_OID = 1033
BIT_OID = 1560
BOOL_OID = 16
BOX_OID = 603
BPCHAR_OID = 1042
BYTEA_OID = 17
CHAR_OID = 18
CID_OID = 29
CIDR_OID = 650
CIRCLE_OID = 718
DATE_OID = 1082
DATEMULTIRANGE_OID = 4535
DATERANGE_OID = 3912
FLOAT4_OID = 700
FLOAT8_OID = 701
GTSVECTOR_OID = 3642
INET_OID = 869
INT2_OID = 21
INT2VECTOR_OID = 22
INT4_OID = 23
INT4MULTIRANGE_OID = 4451
INT4RANGE_OID = 3904
INT8_OID = 20
INT8MULTIRANGE_OID = 4536
INT8RANGE_OID = 3926
INTERVAL_OID = 1186
JSON_OID = 114
JSONB_OID = 3802
JSONPATH_OID = 4072
LINE_OID = 628
LSEG_OID = 601
MACADDR_OID = 829
MACADDR8_OID = 774
MONEY_OID = 790
NAME_OID = 19
NUMERIC_OID = 1700
NUMMULTIRANGE_OID = 4532
NUMRANGE_OID = 3906
OID_OID = 26
OIDVECTOR_OID = 30
PATH_OID = 602
PG_LSN_OID = 3220
POINT_OID = 600
POLYGON_OID = 604
RECORD_OID = 2249
REFCURSOR_OID = 1790
REGCLASS_OID = 2205
REGCOLLATION_OID = 4191
REGCONFIG_OID = 3734
REGDICTIONARY_OID = 3769
REGNAMESPACE_OID = 4089
REGOPER_OID = 2203
REGOPERATOR_OID = 2204
REGPROC_OID = 24
REGPROCEDURE_OID = 2202
REGROLE_OID = 4096
REGTYPE_OID = 2206
TEXT_OID = 25
TID_OID = 27
TIME_OID = 1083
TIMESTAMP_OID = 1114
TIMESTAMPTZ_OID = 1184
TIMETZ_OID = 1266
TSMULTIRANGE_OID = 4533
TSQUERY_OID = 3615
TSRANGE_OID = 3908
TSTZMULTIRANGE_OID = 4534
TSTZRANGE_OID = 3910
TSVECTOR_OID = 3614
TXID_SNAPSHOT_OID = 2970
UUID_OID = 2950
VARBIT_OID = 1562
VARCHAR_OID = 1043
XID_OID = 28
XID8_OID = 5069
XML_OID = 142
# autogenerated: end

View File

@@ -0,0 +1,268 @@
"""
commands pipeline management
"""
# Copyright (C) 2021 The Psycopg Team
from __future__ import annotations
import logging
from types import TracebackType
from typing import Any, TYPE_CHECKING
from . import pq
from . import errors as e
from .abc import PipelineCommand, PQGen
from ._compat import Deque, Self, TypeAlias
from .pq.misc import connection_summary
from .generators import pipeline_communicate, fetch_many, send
from ._capabilities import capabilities
if TYPE_CHECKING:
from .pq.abc import PGresult
from .connection import Connection
from ._preparing import Key, Prepare # noqa: F401
from ._cursor_base import BaseCursor # noqa: F401
from ._connection_base import BaseConnection
from .connection_async import AsyncConnection
# A queued-result descriptor: the cursor expecting the results and, when the
# command was a statement preparation, the info needed to validate the
# prepared-statement cache. `None` marks commands with no cursor attached
# (e.g. a pipeline sync).
PendingResult: TypeAlias = (
    "tuple[BaseCursor[Any, Any], tuple[Key, Prepare, bytes] | None] | None"
)
# Shortcuts for frequently checked libpq enum values.
FATAL_ERROR = pq.ExecStatus.FATAL_ERROR
PIPELINE_ABORTED = pq.ExecStatus.PIPELINE_ABORTED
BAD = pq.ConnStatus.BAD
ACTIVE = pq.TransactionStatus.ACTIVE
logger = logging.getLogger("psycopg")
class BasePipeline:
    """
    Common state and generator-based logic for sync and async pipelines.

    Subclasses (`Pipeline`, `AsyncPipeline`) only add the locking and waiting
    strategy appropriate to their concurrency model.
    """
    # Commands not yet sent to the server.
    command_queue: Deque[PipelineCommand]
    # One entry per queued command, describing who receives its results.
    result_queue: Deque[PendingResult]
    def __init__(self, conn: BaseConnection[Any]) -> None:
        self._conn = conn
        self.pgconn = conn.pgconn
        self.command_queue = Deque[PipelineCommand]()
        self.result_queue = Deque[PendingResult]()
        # Nesting level: pipeline mode is entered/exited only at level 0.
        self.level = 0
    def __repr__(self) -> str:
        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        info = connection_summary(self._conn.pgconn)
        return f"<{cls} {info} at 0x{id(self):x}>"
    @property
    def status(self) -> pq.PipelineStatus:
        """The libpq pipeline status of the connection."""
        return pq.PipelineStatus(self.pgconn.pipeline_status)
    @classmethod
    def is_supported(cls) -> bool:
        """Return `!True` if the psycopg libpq wrapper supports pipeline mode."""
        return capabilities.has_pipeline()
    def _enter_gen(self) -> PQGen[None]:
        # Enter pipeline mode (or just bump the nesting level if already in).
        capabilities.has_pipeline(check=True)
        if self.level == 0:
            self.pgconn.enter_pipeline_mode()
        elif self.command_queue or self.pgconn.transaction_status == ACTIVE:
            # Nested pipeline case.
            # Transaction might be ACTIVE when the pipeline uses an "implicit
            # transaction", typically in autocommit mode. But when entering a
            # Psycopg transaction(), we expect the IDLE state. By sync()-ing,
            # we make sure all previous commands are completed and the
            # transaction gets back to IDLE.
            yield from self._sync_gen()
        self.level += 1
    def _exit(self, exc: BaseException | None) -> None:
        # Leave one nesting level; exit pipeline mode at the outermost level.
        self.level -= 1
        if self.level == 0 and self.pgconn.status != BAD:
            try:
                self.pgconn.exit_pipeline_mode()
            except e.OperationalError as exc2:
                # Notice that this error might be pretty irrecoverable. It
                # happens on COPY, for instance: even if sync succeeds, exiting
                # fails with "cannot exit pipeline mode with uncollected results"
                if exc:
                    logger.warning("error ignored exiting %r: %s", self, exc2)
                else:
                    raise exc2.with_traceback(None)
    def _sync_gen(self) -> PQGen[None]:
        # Send a sync and process every result made available by it.
        self._enqueue_sync()
        yield from self._communicate_gen()
        yield from self._fetch_gen(flush=False)
    def _exit_gen(self) -> PQGen[None]:
        """
        Exit current pipeline by sending a Sync and fetch back all remaining results.
        """
        try:
            self._enqueue_sync()
            yield from self._communicate_gen()
        finally:
            # Fetch even if communicate failed, to drain pending results.
            yield from self._fetch_gen(flush=True)
    def _communicate_gen(self) -> PQGen[None]:
        """Communicate with pipeline to send commands and possibly fetch
        results, which are then processed.
        """
        fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
        # Process all fetched result sets; remember only the first error so
        # the whole batch is consumed before raising.
        exception = None
        for results in fetched:
            queued = self.result_queue.popleft()
            try:
                self._process_results(queued, results)
            except e.Error as exc:
                if exception is None:
                    exception = exc
        if exception is not None:
            raise exception
    def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
        """Fetch available results from the connection and process them with
        pipeline queued items.

        If 'flush' is True, a PQsendFlushRequest() is issued in order to make
        sure results can be fetched. Otherwise, the caller may emit a
        PQpipelineSync() call to ensure the output buffer gets flushed before
        fetching.
        """
        if not self.result_queue:
            return
        if flush:
            self.pgconn.send_flush_request()
            yield from send(self.pgconn)
        # As in _communicate_gen(): consume everything, raise the first error.
        exception = None
        while self.result_queue:
            results = yield from fetch_many(self.pgconn)
            if not results:
                # No more results to fetch, but there may still be pending
                # commands.
                break
            queued = self.result_queue.popleft()
            try:
                self._process_results(queued, results)
            except e.Error as exc:
                if exception is None:
                    exception = exc
        if exception is not None:
            raise exception
    def _process_results(self, queued: PendingResult, results: list[PGresult]) -> None:
        """Process a results set fetched from the current pipeline.

        This matches 'results' with its respective element in the pipeline
        queue. For commands (None value in the pipeline queue), results are
        checked directly. For prepare statement creation requests, update the
        cache. Otherwise, results are attached to their respective cursor.
        """
        if queued is None:
            (result,) = results
            if result.status == FATAL_ERROR:
                raise e.error_from_result(result, encoding=self.pgconn._encoding)
            elif result.status == PIPELINE_ABORTED:
                raise e.PipelineAborted("pipeline aborted")
        else:
            cursor, prepinfo = queued
            if prepinfo:
                key, prep, name = prepinfo
                # Update the prepare state of the query.
                cursor._conn._prepared.validate(key, prep, name, results)
            cursor._check_results(results)
            cursor._set_results(results)
    def _enqueue_sync(self) -> None:
        """Enqueue a PQpipelineSync() command."""
        self.command_queue.append(self.pgconn.pipeline_sync)
        self.result_queue.append(None)
class Pipeline(BasePipeline):
    """Handler for connection in pipeline mode."""
    __module__ = "psycopg"
    _conn: Connection[Any]
    def __init__(self, conn: Connection[Any]) -> None:
        super().__init__(conn)
    def sync(self) -> None:
        """Sync the pipeline, send any pending command and receive and process
        all available results.
        """
        try:
            with self._conn.lock:
                self._conn.wait(self._sync_gen())
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)
    def __enter__(self) -> Self:
        """Enter pipeline mode (or one more nesting level of it)."""
        with self._conn.lock:
            self._conn.wait(self._enter_gen())
        return self
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Drain pending results and leave pipeline mode."""
        try:
            with self._conn.lock:
                self._conn.wait(self._exit_gen())
        except Exception as exc2:
            # Don't clobber an exception raised in the block with this one
            if exc_val:
                logger.warning("error ignored terminating %r: %s", self, exc2)
            else:
                raise exc2.with_traceback(None)
        finally:
            self._exit(exc_val)
class AsyncPipeline(BasePipeline):
    """Handler for async connection in pipeline mode."""
    __module__ = "psycopg"
    _conn: AsyncConnection[Any]
    def __init__(self, conn: AsyncConnection[Any]) -> None:
        super().__init__(conn)
    async def sync(self) -> None:
        """Sync the pipeline, send any pending command and receive and process
        all available results.
        """
        try:
            async with self._conn.lock:
                await self._conn.wait(self._sync_gen())
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)
    async def __aenter__(self) -> Self:
        """Enter pipeline mode (or one more nesting level of it)."""
        async with self._conn.lock:
            await self._conn.wait(self._enter_gen())
        return self
    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Drain pending results and leave pipeline mode."""
        try:
            async with self._conn.lock:
                await self._conn.wait(self._exit_gen())
        except Exception as exc2:
            # Don't clobber an exception raised in the block with this one
            if exc_val:
                logger.warning("error ignored terminating %r: %s", self, exc2)
            else:
                raise exc2.with_traceback(None)
        finally:
            self._exit(exc_val)

View File

@@ -0,0 +1,201 @@
"""
Support for prepared statements
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from enum import IntEnum, auto
from typing import Any, Sequence, TYPE_CHECKING
from collections import OrderedDict
from . import pq
from .abc import PQGen
from ._compat import Deque, TypeAlias
from ._queries import PostgresQuery
if TYPE_CHECKING:
from .pq.abc import PGresult
from ._connection_base import BaseConnection
# Cache key of a prepared statement: (query bytes, parameter type oids).
Key: TypeAlias = "tuple[bytes, tuple[int, ...]]"
# Shortcuts for frequently checked libpq result statuses.
COMMAND_OK = pq.ExecStatus.COMMAND_OK
TUPLES_OK = pq.ExecStatus.TUPLES_OK
class Prepare(IntEnum):
    """Whether a query should be executed prepared."""
    NO = auto()  # execute without preparing
    YES = auto()  # already prepared in this session: execute by name
    SHOULD = auto()  # not prepared yet: prepare it before executing
class PrepareManager:
    """
    Track the lifecycle of prepared statements on a connection.

    Count how many times each query is seen, decide when to prepare it,
    remember the server-side statement names, and schedule deallocations
    when entries are evicted or the cache must be discarded.
    """
    # Number of times a query is executed before it is prepared.
    prepare_threshold: int | None = 5
    # Maximum number of prepared statements on the connection.
    prepared_max: int = 100
    def __init__(self) -> None:
        # Map (query, types) to the number of times the query was seen.
        self._counts: OrderedDict[Key, int] = OrderedDict()
        # Map (query, types) to the name of the statement if prepared.
        self._names: OrderedDict[Key, bytes] = OrderedDict()
        # Counter to generate prepared statements names
        self._prepared_idx = 0
        # Statement names to deallocate; `None` means "deallocate all".
        self._to_flush = Deque["bytes | None"]()
    @staticmethod
    def key(query: PostgresQuery) -> Key:
        """Return the cache key for *query*."""
        return (query.query, query.types)
    def get(
        self, query: PostgresQuery, prepare: bool | None = None
    ) -> tuple[Prepare, bytes]:
        """
        Check if a query is prepared, tell back whether to prepare it.
        """
        if prepare is False or self.prepare_threshold is None:
            # The user doesn't want this query to be prepared
            return Prepare.NO, b""
        key = self.key(query)
        name = self._names.get(key)
        if name:
            # The query was already prepared in this session
            return Prepare.YES, name
        count = self._counts.get(key, 0)
        if count >= self.prepare_threshold or prepare:
            # The query has been executed enough times and needs to be prepared
            name = f"_pg3_{self._prepared_idx}".encode()
            self._prepared_idx += 1
            return Prepare.SHOULD, name
        else:
            # The query is not to be prepared yet
            return Prepare.NO, b""
    def _should_discard(self, prep: Prepare, results: Sequence[PGresult]) -> bool:
        """Check if we need to discard our entire state: it should happen on
        rollback or on dropping objects, because the same object may get
        recreated and postgres would fail internal lookups.
        """
        if self._names or prep == Prepare.SHOULD:
            for result in results:
                if result.status != COMMAND_OK:
                    continue
                cmdstat = result.command_status
                if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
                    return self.clear()
        return False
    @staticmethod
    def _check_results(results: Sequence[PGresult]) -> bool:
        """Return False if 'results' are invalid for prepared statement cache."""
        if len(results) != 1:
            # We cannot prepare a multiple statement
            return False
        status = results[0].status
        if COMMAND_OK != status != TUPLES_OK:
            # We don't prepare failed queries or other weird results
            return False
        return True
    def _rotate(self) -> None:
        """Evict an old value from the cache.

        If it was prepared, deallocate it. Do it only once: if the cache was
        resized, deallocate gradually.
        """
        if len(self._counts) > self.prepared_max:
            self._counts.popitem(last=False)
        if len(self._names) > self.prepared_max:
            name = self._names.popitem(last=False)[1]
            self._to_flush.append(name)
    def maybe_add_to_cache(
        self, query: PostgresQuery, prep: Prepare, name: bytes
    ) -> Key | None:
        """Handle 'query' for possible addition to the cache.

        If a new entry has been added, return its key. Return None otherwise
        (meaning the query is already in cache or cache is not enabled).
        """
        # don't do anything if prepared statements are disabled
        if self.prepare_threshold is None:
            return None
        key = self.key(query)
        if key in self._counts:
            if prep is Prepare.SHOULD:
                # Promote the entry from "counted" to "prepared".
                del self._counts[key]
                self._names[key] = name
            else:
                self._counts[key] += 1
                self._counts.move_to_end(key)
            return None
        elif key in self._names:
            # Refresh the LRU position of an already prepared entry.
            self._names.move_to_end(key)
            return None
        else:
            if prep is Prepare.SHOULD:
                self._names[key] = name
            else:
                self._counts[key] = 1
            return key
    def validate(
        self,
        key: Key,
        prep: Prepare,
        name: bytes,
        results: Sequence[PGresult],
    ) -> None:
        """Validate cached entry with 'key' by checking query 'results'.

        Possibly record a command to perform maintenance on database side.
        """
        if self._should_discard(prep, results):
            return
        if not self._check_results(results):
            # Bad candidate for caching: drop it entirely.
            self._names.pop(key, None)
            self._counts.pop(key, None)
        else:
            self._rotate()
    def clear(self) -> bool:
        """Clear the cache of the maintenance commands.

        Clear the internal state and prepare a command to clear the state of
        the server.
        """
        self._counts.clear()
        if self._names:
            self._names.clear()
            self._to_flush.clear()
            # `None` requests a full server-side deallocation.
            self._to_flush.append(None)
            return True
        else:
            return False
    def maintain_gen(self, conn: BaseConnection[Any]) -> PQGen[None]:
        """
        Generator to send the commands to perform periodic maintenance

        Deallocate unneeded command in the server, or flush the prepared
        statements server state entirely if necessary.
        """
        while self._to_flush:
            name = self._to_flush.popleft()
            yield from conn._deallocate(name)

View File

@@ -0,0 +1,361 @@
"""
Helper object to transform values between Python and PostgreSQL
Python implementation of the object. Use the `_transformer module to import
the right implementation (Python or C). The public place where the object
is exported is `psycopg.adapt` (which we may not use to avoid circular
dependencies problems).
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from typing import Any, Sequence, DefaultDict, TYPE_CHECKING
from collections import defaultdict
from . import pq
from . import abc
from . import errors as e
from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, NoneType
from .rows import Row, RowMaker
from ._oids import INVALID_OID, TEXT_OID
from ._compat import TypeAlias
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .abc import DumperKey # noqa: F401
from .adapt import AdaptersMap
from .pq.abc import PGresult
from ._connection_base import BaseConnection
# Per-format caches kept by the Transformer instance.
DumperCache: TypeAlias = "dict[DumperKey, abc.Dumper]"
OidDumperCache: TypeAlias = "dict[int, abc.Dumper]"
LoaderCache: TypeAlias = "dict[int, abc.Loader]"
# Shortcuts for the default (text) formats.
TEXT = pq.Format.TEXT
PY_TEXT = PyFormat.TEXT
class Transformer(AdaptContext):
"""
An object that can adapt efficiently between Python and PostgreSQL.
The life cycle of the object is the query, so it is assumed that attributes
such as the server version or the connection encoding will not change. The
object have its state so adapting several values of the same type can be
optimised.
"""
__module__ = "psycopg.adapt"
__slots__ = """
types formats
_conn _adapters _pgresult _dumpers _loaders _encoding _none_oid
_oid_dumpers _oid_types _row_dumpers _row_loaders
""".split()
types: tuple[int, ...] | None
formats: list[pq.Format] | None
_adapters: AdaptersMap
_pgresult: PGresult | None
_none_oid: int
    def __init__(self, context: AdaptContext | None = None):
        """
        :param context: the context (connection, cursor, ...) to take the
            adapters map and connection from; use the global adapters and no
            connection if `!None`.
        """
        self._pgresult = self.types = self.formats = None
        # WARNING: don't store context, or you'll create a loop with the Cursor
        if context:
            self._adapters = context.adapters
            self._conn = context.connection
        else:
            from . import postgres
            self._adapters = postgres.adapters
            self._conn = None
        # mapping fmt, class -> Dumper instance
        self._dumpers: DefaultDict[PyFormat, DumperCache]
        self._dumpers = defaultdict(dict)
        # mapping fmt, oid -> Dumper instance
        # Not often used, so create it only if needed.
        self._oid_dumpers: tuple[OidDumperCache, OidDumperCache] | None
        self._oid_dumpers = None
        # mapping fmt, oid -> Loader instance
        self._loaders: tuple[LoaderCache, LoaderCache] = ({}, {})
        self._row_dumpers: list[abc.Dumper] | None = None
        # sequence of load functions from value to python
        # the length of the result columns
        self._row_loaders: list[LoadFunc] = []
        # mapping oid -> type sql representation
        self._oid_types: dict[int, bytes] = {}
        # Resolved lazily from the connection by the `encoding` property.
        self._encoding = ""
@classmethod
def from_context(cls, context: AdaptContext | None) -> Transformer:
"""
Return a Transformer from an AdaptContext.
If the context is a Transformer instance, just return it.
"""
if isinstance(context, Transformer):
return context
else:
return cls(context)
@property
def connection(self) -> BaseConnection[Any] | None:
return self._conn
@property
def encoding(self) -> str:
if not self._encoding:
self._encoding = conn_encoding(self.connection)
return self._encoding
@property
def adapters(self) -> AdaptersMap:
return self._adapters
@property
def pgresult(self) -> PGresult | None:
return self._pgresult
def set_pgresult(
self,
result: PGresult | None,
*,
set_loaders: bool = True,
format: pq.Format | None = None,
) -> None:
self._pgresult = result
if not result:
self._nfields = self._ntuples = 0
if set_loaders:
self._row_loaders = []
return
self._ntuples = result.ntuples
nf = self._nfields = result.nfields
if not set_loaders:
return
if not nf:
self._row_loaders = []
return
fmt: pq.Format
fmt = result.fformat(0) if format is None else format # type: ignore
self._row_loaders = [
self.get_loader(result.ftype(i), fmt).load for i in range(nf)
]
def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
self.types = tuple(types)
self.formats = [format] * len(types)
def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
self._row_loaders = [self.get_loader(oid, format).load for oid in types]
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
) -> Sequence[Buffer | None]:
nparams = len(params)
out: list[Buffer | None] = [None] * nparams
# If we have dumpers, it means set_dumper_types had been called, in
# which case self.types and self.formats are set to sequences of the
# right size.
if self._row_dumpers:
for i in range(nparams):
param = params[i]
if param is not None:
out[i] = self._row_dumpers[i].dump(param)
return out
types = [self._get_none_oid()] * nparams
pqformats = [TEXT] * nparams
for i in range(nparams):
param = params[i]
if param is None:
continue
dumper = self.get_dumper(param, formats[i])
out[i] = dumper.dump(param)
types[i] = dumper.oid
pqformats[i] = dumper.format
self.types = tuple(types)
self.formats = pqformats
return out
def as_literal(self, obj: Any) -> bytes:
dumper = self.get_dumper(obj, PY_TEXT)
rv = dumper.quote(obj)
# If the result is quoted, and the oid not unknown or text,
# add an explicit type cast.
# Check the last char because the first one might be 'E'.
oid = dumper.oid
if oid and rv and rv[-1] == b"'"[0] and oid != TEXT_OID:
try:
type_sql = self._oid_types[oid]
except KeyError:
ti = self.adapters.types.get(oid)
if ti:
if oid < 8192:
# builtin: prefer "timestamptz" to "timestamp with time zone"
type_sql = ti.name.encode(self.encoding)
else:
type_sql = ti.regtype.encode(self.encoding)
if oid == ti.array_oid:
type_sql += b"[]"
else:
type_sql = b""
self._oid_types[oid] = type_sql
if type_sql:
rv = b"%s::%s" % (rv, type_sql)
if not isinstance(rv, bytes):
rv = bytes(rv)
return rv
def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper:
"""
Return a Dumper instance to dump `!obj`.
"""
# Normally, the type of the object dictates how to dump it
key = type(obj)
# Reuse an existing Dumper class for objects of the same type
cache = self._dumpers[format]
try:
dumper = cache[key]
except KeyError:
# If it's the first time we see this type, look for a dumper
# configured for it.
try:
dcls = self.adapters.get_dumper(key, format)
except e.ProgrammingError as ex:
raise ex from None
else:
cache[key] = dumper = dcls(key, self)
# Check if the dumper requires an upgrade to handle this specific value
key1 = dumper.get_key(obj, format)
if key1 is key:
return dumper
# If it does, ask the dumper to create its own upgraded version
try:
return cache[key1]
except KeyError:
dumper = cache[key1] = dumper.upgrade(obj, format)
return dumper
def _get_none_oid(self) -> int:
try:
return self._none_oid
except AttributeError:
pass
try:
rv = self._none_oid = self._adapters.get_dumper(NoneType, PY_TEXT).oid
except KeyError:
raise e.InterfaceError("None dumper not found")
return rv
def get_dumper_by_oid(self, oid: int, format: pq.Format) -> abc.Dumper:
"""
Return a Dumper to dump an object to the type with given oid.
"""
if not self._oid_dumpers:
self._oid_dumpers = ({}, {})
# Reuse an existing Dumper class for objects of the same type
cache = self._oid_dumpers[format]
try:
return cache[oid]
except KeyError:
# If it's the first time we see this type, look for a dumper
# configured for it.
dcls = self.adapters.get_dumper_by_oid(oid, format)
cache[oid] = dumper = dcls(NoneType, self)
return dumper
def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]:
res = self._pgresult
if not res:
raise e.InterfaceError("result not set")
if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
raise e.InterfaceError(
f"rows must be included between 0 and {self._ntuples}"
)
records = []
for row in range(row0, row1):
record: list[Any] = [None] * self._nfields
for col in range(self._nfields):
val = res.get_value(row, col)
if val is not None:
record[col] = self._row_loaders[col](val)
records.append(make_row(record))
return records
def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None:
res = self._pgresult
if not res:
return None
if not 0 <= row < self._ntuples:
return None
record: list[Any] = [None] * self._nfields
for col in range(self._nfields):
val = res.get_value(row, col)
if val is not None:
record[col] = self._row_loaders[col](val)
return make_row(record)
def load_sequence(self, record: Sequence[Buffer | None]) -> tuple[Any, ...]:
if len(self._row_loaders) != len(record):
raise e.ProgrammingError(
f"cannot load sequence of {len(record)} items:"
f" {len(self._row_loaders)} loaders registered"
)
return tuple(
(self._row_loaders[i](val) if val is not None else None)
for i, val in enumerate(record)
)
def get_loader(self, oid: int, format: pq.Format) -> abc.Loader:
try:
return self._loaders[format][oid]
except KeyError:
pass
loader_cls = self._adapters.get_loader(oid, format)
if not loader_cls:
loader_cls = self._adapters.get_loader(INVALID_OID, format)
if not loader_cls:
raise e.InterfaceError("unknown oid loader not found")
loader = self._loaders[format][oid] = loader_cls(oid, self)
return loader

View File

@@ -0,0 +1,425 @@
"""
Utility module to manipulate queries
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import re
from typing import Any, Callable, Mapping, Match, NamedTuple
from typing import Sequence, TYPE_CHECKING
from functools import lru_cache
from . import pq
from . import errors as e
from .sql import Composable
from .abc import Buffer, Query, Params
from ._enums import PyFormat
from ._compat import TypeAlias, TypeGuard
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .abc import Transformer
MAX_CACHED_STATEMENT_LENGTH = 4096
MAX_CACHED_STATEMENT_PARAMS = 50
class QueryPart(NamedTuple):
    """A fragment of a query split around one placeholder."""

    # The query text preceding the placeholder.
    pre: bytes
    # The placeholder index (positional) or name (named); 0 in the final part.
    item: int | str
    # The adaptation format requested by the placeholder (%s, %t, %b).
    format: PyFormat
class PostgresQuery:
    """
    Helper to convert a Python query and parameters into Postgres format.
    """

    __slots__ = """
        query params types formats
        _tx _want_formats _parts _encoding _order
        """.split()

    def __init__(self, transformer: Transformer):
        self._tx = transformer

        self.params: Sequence[Buffer | None] | None = None
        # these are tuples so they can be used as keys e.g. in prepared stmts
        self.types: tuple[int, ...] = ()

        # The format requested by the user and the ones to really pass Postgres
        self._want_formats: list[PyFormat] | None = None
        self.formats: Sequence[pq.Format] | None = None

        self._encoding = conn_encoding(transformer.connection)
        self._parts: list[QueryPart]
        self.query = b""
        # Order of appearance of the named parameters, if the query uses them.
        self._order: list[str] | None = None

    def convert(self, query: Query, vars: Params | None) -> None:
        """
        Set up the query and parameters to convert.

        The results of this function can be obtained accessing the object
        attributes (`query`, `params`, `types`, `formats`).
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is not None:
            # Avoid caching queries extremely long or with a huge number of
            # parameters. They are usually generated by ORMs and have poor
            # cacheablility (e.g. INSERT ... VALUES (...), (...) with varying
            # numbers of tuples.
            # see https://github.com/psycopg/psycopg/discussions/628
            if (
                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
                and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
            ):
                f: _Query2Pg = _query2pg
            else:
                f = _query2pg_nocache

            (self.query, self._want_formats, self._order, self._parts) = f(
                bquery, self._encoding
            )
        else:
            self.query = bquery
            self._want_formats = self._order = None

        self.dump(vars)

    def dump(self, vars: Params | None) -> None:
        """
        Process a new set of variables on the query processed by `convert()`.

        This method updates `params` and `types`.
        """
        if vars is not None:
            params = self.validate_and_reorder_params(self._parts, vars, self._order)
            assert self._want_formats is not None
            self.params = self._tx.dump_sequence(params, self._want_formats)
            self.types = self._tx.types or ()
            self.formats = self._tx.formats
        else:
            self.params = None
            self.types = ()
            self.formats = None

    @staticmethod
    def is_params_sequence(vars: Params) -> TypeGuard[Sequence[Any]]:
        """Return True if `!vars` is a sequence, False if it is a mapping.

        Raise TypeError if it is neither.
        """
        # Try concrete types, then abstract types
        t = type(vars)
        if t is list or t is tuple:
            sequence = True
        elif t is dict:
            sequence = False
        elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
            sequence = True
        elif isinstance(vars, Mapping):
            sequence = False
        else:
            raise TypeError(
                "query parameters should be a sequence or a mapping,"
                f" got {type(vars).__name__}"
            )
        return sequence

    @staticmethod
    def validate_and_reorder_params(
        parts: list[QueryPart], vars: Params, order: list[str] | None
    ) -> Sequence[Any]:
        """
        Verify the compatibility between a query and a set of params.

        Return the params as a sequence, in the order the query expects them.
        """
        if PostgresQuery.is_params_sequence(vars):
            if len(vars) != len(parts) - 1:
                raise e.ProgrammingError(
                    f"the query has {len(parts) - 1} placeholders but"
                    f" {len(vars)} parameters were passed"
                )
            if vars and not isinstance(parts[0].item, int):
                raise TypeError("named placeholders require a mapping of parameters")
            return vars

        else:
            if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
                raise TypeError(
                    "positional placeholders (%s) require a sequence of parameters"
                )
            try:
                if order:
                    return [vars[item] for item in order]  # type: ignore[call-overload]
                else:
                    return ()
            except KeyError:
                raise e.ProgrammingError(
                    "query parameter missing:"
                    f" {', '.join(sorted(i for i in order or () if i not in vars))}"
                )
# The type of the _query2pg() and _query2pg_nocache() methods
# (the cached and uncached variants share this signature).
_Query2Pg: TypeAlias = Callable[
    [bytes, str], "tuple[bytes, list[PyFormat], list[str] | None, list[QueryPart]]"
]
def _query2pg_nocache(
    query: bytes, encoding: str
) -> tuple[bytes, list[PyFormat], list[str] | None, list[QueryPart]]:
    """
    Convert a Python query with placeholders into Postgres native format.

    - Python placeholders (``%s``, ``%(name)s``) become Postgres ones
      (``$1``, ``$2``); ``%s``/``%t``/``%b`` select auto/text/binary format.
    - Return a 4-tuple: the converted ``query`` (bytes), the ``formats``
      requested for the parameters, the ``order`` of the named parameters
      (`!None` for positional ones), and the raw query ``parts``.
    """
    parts = _split_query(query, encoding)
    pieces: list[bytes] = []
    fmts: list[PyFormat] = []
    order: list[str] | None = None

    first = parts[0].item
    if isinstance(first, int):
        # Positional placeholders: number them in order of appearance.
        for part in parts[:-1]:
            assert isinstance(part.item, int)
            pieces.append(part.pre)
            pieces.append(b"$%d" % (part.item + 1))
            fmts.append(part.format)

    elif isinstance(first, str):
        # Named placeholders: each distinct name maps to a single $n param.
        known: dict[str, tuple[bytes, PyFormat]] = {}
        order = []
        for part in parts[:-1]:
            assert isinstance(part.item, str)
            pieces.append(part.pre)
            prev = known.get(part.item)
            if prev is None:
                ph = b"$%d" % (len(known) + 1)
                known[part.item] = (ph, part.format)
                order.append(part.item)
                pieces.append(ph)
                fmts.append(part.format)
            else:
                # Repeated name: reuse the same $n, but the format must match.
                if prev[1] != part.format:
                    raise e.ProgrammingError(
                        f"placeholder '{part.item}' cannot have different formats"
                    )
                pieces.append(prev[0])

    # Append the query tail after the last placeholder.
    pieces.append(parts[-1].pre)

    return b"".join(pieces), fmts, order, parts
# Note: the cache size is 128 items, but someone has reported throwing ~12k
# queries (of type `INSERT ... VALUES (...), (...)` with a varying amount of
# records), and the resulting cache size is >100Mb. So, we will avoid to cache
# large queries or queries with a large number of params. See
# https://github.com/sqlalchemy/sqlalchemy/discussions/10270
# (Calling lru_cache directly on the function uses the default maxsize=128.)
_query2pg = lru_cache(_query2pg_nocache)
class PostgresClientQuery(PostgresQuery):
    """
    PostgresQuery subclass merging query and arguments client-side.
    """

    # The query with `%s`-style placeholders, filled in by dump().
    __slots__ = ("template",)

    def convert(self, query: Query, vars: Params | None) -> None:
        """
        Set up the query and parameters to convert.

        The results of this function can be obtained accessing the object
        attributes (`query`, `params`, `types`, `formats`).
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is not None:
            # Skip the LRU cache for queries unlikely to repeat (same
            # heuristic as PostgresQuery.convert()).
            if (
                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
                and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
            ):
                f: _Query2PgClient = _query2pg_client
            else:
                f = _query2pg_client_nocache

            (self.template, self._order, self._parts) = f(bquery, self._encoding)
        else:
            self.query = bquery
            self._order = None

        self.dump(vars)

    def dump(self, vars: Params | None) -> None:
        """
        Process a new set of variables on the query processed by `convert()`.

        This method updates `params` and `types`.
        """
        if vars is not None:
            params = self.validate_and_reorder_params(self._parts, vars, self._order)
            # Merge values into the template client-side: each one is dumped
            # as a quoted SQL literal (NULL for None).
            self.params = tuple(
                self._tx.as_literal(p) if p is not None else b"NULL" for p in params
            )
            self.query = self.template % self.params
        else:
            self.params = None
# The signature shared by _query2pg_client() and _query2pg_client_nocache().
_Query2PgClient: TypeAlias = Callable[
    [bytes, str], "tuple[bytes, list[str] | None, list[QueryPart]]"
]
def _query2pg_client_nocache(
    query: bytes, encoding: str
) -> tuple[bytes, list[str] | None, list[QueryPart]]:
    """
    Convert a Python query with placeholders into a client-side binding
    template.

    Every placeholder becomes a ``%s`` to be merged with `%`; ``%%`` is kept
    escaped. Return the template, the order of the named parameters (one
    entry per occurrence, `!None` for positional queries) and the raw parts.
    """
    parts = _split_query(query, encoding, collapse_double_percent=False)
    pieces: list[bytes] = []
    order: list[str] | None = None

    first = parts[0].item
    if isinstance(first, int):
        # Positional placeholders: each one becomes a %s in the template.
        for part in parts[:-1]:
            assert isinstance(part.item, int)
            pieces.append(part.pre)
            pieces.append(b"%s")

    elif isinstance(first, str):
        # Named placeholders: record every occurrence in `order`, because
        # each one is substituted separately in the template.
        known: dict[str, tuple[bytes, PyFormat]] = {}
        order = []
        for part in parts[:-1]:
            assert isinstance(part.item, str)
            pieces.append(part.pre)
            prev = known.get(part.item)
            if prev is None:
                ph = b"%s"
                known[part.item] = (ph, part.format)
                order.append(part.item)
                pieces.append(ph)
            else:
                pieces.append(prev[0])
                order.append(part.item)

    # Append the query tail after the last placeholder.
    pieces.append(parts[-1].pre)

    return b"".join(pieces), order, parts
# Cached variant of _query2pg_client_nocache (default lru_cache maxsize=128).
_query2pg_client = lru_cache(_query2pg_client_nocache)
_re_placeholder = re.compile(
rb"""(?x)
% # a literal %
(?:
(?:
\( ([^)]+) \) # or a name in (braces)
. # followed by a format
)
|
(?:.) # or any char, really
)
"""
)
def _split_query(
    query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
) -> list[QueryPart]:
    """Split *query* around its placeholders, validating them.

    Return a list of `QueryPart`; the last part has a dummy item (0) and
    carries the query tail after the last placeholder. Raise
    `~psycopg.ProgrammingError` for malformed or mixed placeholders.
    """
    parts: list[tuple[bytes, Match[bytes] | None]] = []
    cur = 0

    # pairs [(fragment, match], with the last match None
    m = None
    for m in _re_placeholder.finditer(query):
        pre = query[cur : m.span(0)[0]]
        parts.append((pre, m))
        cur = m.span(0)[1]
    if m:
        parts.append((query[cur:], None))
    else:
        # No placeholder at all: the whole query is a single fragment.
        parts.append((query, None))

    rv = []

    # drop the "%%", validate
    i = 0
    phtype = None
    while i < len(parts):
        pre, m = parts[i]
        if m is None:
            # last part
            rv.append(QueryPart(pre, 0, PyFormat.AUTO))
            break

        ph = m.group(0)
        if ph == b"%%":
            # unescape '%%' to '%' if necessary, then merge the parts
            if collapse_double_percent:
                ph = b"%"
            pre1, m1 = parts[i + 1]
            parts[i + 1] = (pre + ph + pre1, m1)
            # Shrink the list in place and re-examine the same index.
            del parts[i]
            continue

        if ph == b"%(":
            raise e.ProgrammingError(
                "incomplete placeholder:"
                f" '{query[m.span(0)[0]:].split()[0].decode(encoding)}'"
            )
        elif ph == b"% ":
            # explicit message for a typical error
            raise e.ProgrammingError(
                "incomplete placeholder: '%'; if you want to use '%' as an"
                " operator you can double it up, i.e. use '%%'"
            )
        elif ph[-1:] not in b"sbt":
            raise e.ProgrammingError(
                "only '%s', '%b', '%t' are allowed as placeholders, got"
                f" '{m.group(0).decode(encoding)}'"
            )

        # Index or name
        item: int | str
        item = m.group(1).decode(encoding) if m.group(1) else i

        # Positional and named placeholders cannot coexist in one query.
        if not phtype:
            phtype = type(item)
        elif phtype is not type(item):
            raise e.ProgrammingError(
                "positional and named placeholders cannot be mixed"
            )

        format = _ph_to_fmt[ph[-1:]]
        rv.append(QueryPart(pre, item, format))
        i += 1

    return rv
# Map the placeholder's last character to the requested dump format.
_ph_to_fmt = {
    b"s": PyFormat.AUTO,
    b"t": PyFormat.TEXT,
    b"b": PyFormat.BINARY,
}

View File

@@ -0,0 +1,57 @@
"""
Utility functions to deal with binary structs.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import struct
from typing import Callable, cast, Protocol
from . import errors as e
from .abc import Buffer
from ._compat import TypeAlias
# Callable aliases for struct pack/unpack functions, giving the cast()
# calls below a precise signature.
PackInt: TypeAlias = Callable[[int], bytes]
UnpackInt: TypeAlias = Callable[[Buffer], "tuple[int]"]
PackFloat: TypeAlias = Callable[[float], bytes]
UnpackFloat: TypeAlias = Callable[[Buffer], "tuple[float]"]
class UnpackLen(Protocol):
    """Signature of `struct.Struct.unpack_from` used to read a length prefix."""

    def __call__(self, data: Buffer, start: int | None) -> tuple[int]: ...
# Pre-compiled packers/unpackers, all in network byte order (big-endian, "!").
pack_int2 = cast(PackInt, struct.Struct("!h").pack)
pack_uint2 = cast(PackInt, struct.Struct("!H").pack)
pack_int4 = cast(PackInt, struct.Struct("!i").pack)
pack_uint4 = cast(PackInt, struct.Struct("!I").pack)
pack_int8 = cast(PackInt, struct.Struct("!q").pack)
pack_float4 = cast(PackFloat, struct.Struct("!f").pack)
pack_float8 = cast(PackFloat, struct.Struct("!d").pack)

unpack_int2 = cast(UnpackInt, struct.Struct("!h").unpack)
unpack_uint2 = cast(UnpackInt, struct.Struct("!H").unpack)
unpack_int4 = cast(UnpackInt, struct.Struct("!i").unpack)
unpack_uint4 = cast(UnpackInt, struct.Struct("!I").unpack)
unpack_int8 = cast(UnpackInt, struct.Struct("!q").unpack)
unpack_float4 = cast(UnpackFloat, struct.Struct("!f").unpack)
unpack_float8 = cast(UnpackFloat, struct.Struct("!d").unpack)

# Signed int4, used to pack/unpack length prefixes; unpack_len reads from
# an offset within a buffer (unpack_from).
_struct_len = struct.Struct("!i")
pack_len = cast(Callable[[int], bytes], _struct_len.pack)
unpack_len = cast(UnpackLen, _struct_len.unpack_from)
def pack_float4_bug_304(x: float) -> bytes:
    """Float4 packer replacement used when bug #304 is detected.

    Packing would produce wrong data on the affected platform, so refuse
    to dump at all instead.
    """
    msg = (
        "cannot dump Float4: Python affected by bug #304. Note that the psycopg-c"
        " and psycopg-binary packages are not affected by this issue."
        " See https://github.com/psycopg/psycopg/issues/304"
    )
    raise e.InterfaceError(msg)
# If issue #304 is detected, raise an error instead of dumping wrong data.
# Sanity check: packing 1.0 as big-endian float4 must yield 0x3f800000.
if struct.Struct("!f").pack(1.0) != bytes.fromhex("3f800000"):
    pack_float4 = pack_float4_bug_304

View File

@@ -0,0 +1,115 @@
"""
psycopg two-phase commit support
"""
# Copyright (C) 2021 The Psycopg Team
from __future__ import annotations
import re
import datetime as dt
from base64 import b64encode, b64decode
from dataclasses import dataclass, replace
# Serialized XA triple: "<format_id>_<base64 gtrid>_<base64 bqual>".
_re_xid = re.compile(r"^(\d+)_([^_]*)_([^_]*)$")
@dataclass(frozen=True)
class Xid:
    """A two-phase commit transaction identifier.

    The object can also be unpacked as a 3-item tuple (`format_id`, `gtrid`,
    `bqual`).
    """

    format_id: int | None
    gtrid: str
    bqual: str | None
    # The following fields are only populated on Xids built by _from_record().
    prepared: dt.datetime | None = None
    owner: str | None = None
    database: str | None = None

    @classmethod
    def from_string(cls, s: str) -> Xid:
        """Try to parse an XA triple from the string.

        This may fail for several reasons. In such case return an unparsed Xid.
        """
        try:
            return cls._parse_string(s)
        except Exception:
            return Xid(None, s, None)

    def __str__(self) -> str:
        return self._as_tid()

    def __len__(self) -> int:
        # Present the xid as a 3-item tuple (format_id, gtrid, bqual).
        return 3

    def __getitem__(self, index: int) -> int | str | None:
        return (self.format_id, self.gtrid, self.bqual)[index]

    @classmethod
    def _parse_string(cls, s: str) -> Xid:
        # Parse the "<format_id>_<b64 gtrid>_<b64 bqual>" serialization
        # produced by _as_tid(); raise ValueError if it doesn't match.
        m = _re_xid.match(s)
        if not m:
            raise ValueError("bad Xid format")
        format_id = int(m.group(1))
        gtrid = b64decode(m.group(2)).decode()
        bqual = b64decode(m.group(3)).decode()
        return cls.from_parts(format_id, gtrid, bqual)

    @classmethod
    def from_parts(cls, format_id: int | None, gtrid: str, bqual: str | None) -> Xid:
        """Build an Xid from its components, validating the XA constraints.

        If `!format_id` is None the xid is "unparsed": only `!gtrid` is
        meaningful and `!bqual` must be None as well.

        :raises TypeError: if `!format_id` and `!bqual` are not both set or
            both unset.
        :raises ValueError: if a component is out of the XA bounds.
        """
        if format_id is not None:
            if bqual is None:
                raise TypeError("if format_id is specified, bqual must be too")
            if not 0 <= format_id < 0x80000000:
                raise ValueError("format_id must be a non-negative 32-bit integer")
            if len(bqual) > 64:
                raise ValueError("bqual must be not longer than 64 chars")
            if len(gtrid) > 64:
                raise ValueError("gtrid must be not longer than 64 chars")
        elif bqual is not None:
            # BUGFIX: this condition was inverted ("elif bqual is None"),
            # which rejected valid unparsed xids (bqual=None) and accepted
            # the invalid combination the message describes.
            raise TypeError("if format_id is None, bqual must be None too")

        return Xid(format_id, gtrid, bqual)

    def _as_tid(self) -> str:
        """
        Return the PostgreSQL transaction_id for this XA xid.

        PostgreSQL wants just a string, while the DBAPI supports the XA
        standard and thus a triple. We use the same conversion algorithm
        implemented by JDBC in order to allow some form of interoperation.

        see also: the pgjdbc implementation
        http://cvs.pgfoundry.org/cgi-bin/cvsweb.cgi/jdbc/pgjdbc/org/
        postgresql/xa/RecoveredXid.java?rev=1.2
        """
        if self.format_id is None or self.bqual is None:
            # Unparsed xid: return the gtrid.
            return self.gtrid

        # XA xid: mash together the components.
        egtrid = b64encode(self.gtrid.encode()).decode()
        ebqual = b64encode(self.bqual.encode()).decode()

        return f"{self.format_id}_{egtrid}_{ebqual}"

    @classmethod
    def _get_recover_query(cls) -> str:
        # Query listing the transactions currently prepared on the server.
        return "SELECT gid, prepared, owner, database FROM pg_prepared_xacts"

    @classmethod
    def _from_record(
        cls, gid: str, prepared: dt.datetime, owner: str, database: str
    ) -> Xid:
        # Build an Xid from a record returned by _get_recover_query().
        xid = Xid.from_string(gid)
        return replace(xid, prepared=prepared, owner=owner, database=database)


Xid.__module__ = "psycopg"

View File

@@ -0,0 +1,21 @@
"""
Helper object to transform values between Python and PostgreSQL
This module exports the requested implementation to the rest of the package.
"""
# Copyright (C) 2023 The Psycopg Team
from __future__ import annotations
from . import abc
from ._cmodule import _psycopg
# Declared here so both branches below conform to the same declared type.
Transformer: type[abc.Transformer]

if _psycopg:
    # The C speedup module is available: use the optimized implementation.
    Transformer = _psycopg.Transformer
else:
    # Fall back on the pure Python implementation.
    from . import _py_transformer

    Transformer = _py_transformer.Transformer

View File

@@ -0,0 +1,343 @@
"""
Information about PostgreSQL types
These types allow to read information from the system catalog and provide
information to the adapters if needed.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from typing import Any, Iterator, overload, Sequence, TYPE_CHECKING
from . import sql
from . import errors as e
from .abc import AdaptContext, Query
from .rows import dict_row
from ._compat import TypeAlias, TypeVar
from ._typemod import TypeModifier
from ._encodings import conn_encoding
if TYPE_CHECKING:
from .connection import Connection
from .connection_async import AsyncConnection
from ._connection_base import BaseConnection
# TypeVar so that fetch()/get() return the TypeInfo subclass they are
# invoked on.
T = TypeVar("T", bound="TypeInfo")
# Keys accepted by TypesRegistry lookups: type name, oid, or (subclass, oid).
RegistryKey: TypeAlias = "str | int | tuple[type, int]"
class TypeInfo:
    """
    Hold information about a PostgreSQL base type.
    """

    __module__ = "psycopg.types"

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        delimiter: str = ",",
        typemod: type[TypeModifier] = TypeModifier,
    ):
        self.name = name
        self.oid = oid
        self.array_oid = array_oid
        self.regtype = regtype or name
        self.delimiter = delimiter
        # Parser for this type's typmod (length/precision/scale details).
        self.typemod = typemod(oid)

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__qualname__}:"
            f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
        )

    @overload
    @classmethod
    def fetch(
        cls: type[T], conn: Connection[Any], name: str | sql.Identifier
    ) -> T | None: ...

    @overload
    @classmethod
    async def fetch(
        cls: type[T], conn: AsyncConnection[Any], name: str | sql.Identifier
    ) -> T | None: ...

    @classmethod
    def fetch(
        cls: type[T], conn: BaseConnection[Any], name: str | sql.Identifier
    ) -> Any:
        """Query a system catalog to read information about a type."""
        # Deferred imports to avoid circular dependencies.
        from .connection import Connection
        from .connection_async import AsyncConnection

        if isinstance(name, sql.Composable):
            name = name.as_string(conn)

        # Dispatch to the sync or async implementation (the async one
        # returns a coroutine, matching the overloads above).
        if isinstance(conn, Connection):
            return cls._fetch(conn, name)
        elif isinstance(conn, AsyncConnection):
            return cls._fetch_async(conn, name)
        else:
            raise TypeError(
                f"expected Connection or AsyncConnection, got {type(conn).__name__}"
            )

    @classmethod
    def _fetch(cls: type[T], conn: Connection[Any], name: str) -> T | None:
        # This might result in a nested transaction. What we want is to leave
        # the function with the connection in the state we found (either idle
        # or intrans)
        try:
            from psycopg import Cursor

            with conn.transaction(), Cursor(conn, row_factory=dict_row) as cur:
                if conn_encoding(conn) == "ascii":
                    # Use a richer encoding to decode catalog identifiers.
                    cur.execute("set local client_encoding to utf8")
                cur.execute(cls._get_info_query(conn), {"name": name})
                recs = cur.fetchall()
        except e.UndefinedObject:
            return None

        return cls._from_records(name, recs)

    @classmethod
    async def _fetch_async(
        cls: type[T], conn: AsyncConnection[Any], name: str
    ) -> T | None:
        # Async twin of _fetch(); keep the two implementations in sync.
        try:
            from psycopg import AsyncCursor

            async with conn.transaction():
                async with AsyncCursor(conn, row_factory=dict_row) as cur:
                    if conn_encoding(conn) == "ascii":
                        await cur.execute("set local client_encoding to utf8")
                    await cur.execute(cls._get_info_query(conn), {"name": name})
                    recs = await cur.fetchall()
        except e.UndefinedObject:
            return None

        return cls._from_records(name, recs)

    @classmethod
    def _from_records(
        cls: type[T], name: str, recs: Sequence[dict[str, Any]]
    ) -> T | None:
        # Build an instance from catalog records; expect 0 or 1 matches.
        if len(recs) == 1:
            return cls(**recs[0])
        elif not recs:
            return None
        else:
            raise e.ProgrammingError(f"found {len(recs)} different types named {name}")

    def register(self, context: AdaptContext | None = None) -> None:
        """
        Register the type information, globally or in the specified `!context`.
        """
        if context:
            types = context.adapters.types
        else:
            from . import postgres

            types = postgres.types

        types.add(self)

        if self.array_oid:
            from .types.array import register_array

            register_array(self, context)

    @classmethod
    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
        # Catalog query returning the keyword arguments for the constructor.
        return sql.SQL(
            """\
SELECT
    typname AS name, oid, typarray AS array_oid,
    oid::regtype::text AS regtype, typdelim AS delimiter
FROM pg_type t
WHERE t.oid = {regtype}
ORDER BY t.oid
"""
        ).format(regtype=cls._to_regtype(conn))

    @classmethod
    def _has_to_regtype_function(cls, conn: BaseConnection[Any]) -> bool:
        # to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
        info = conn.info
        if info.vendor == "PostgreSQL":
            return info.server_version >= 90400
        elif info.vendor == "CockroachDB":
            return info.server_version >= 220200
        else:
            return False

    @classmethod
    def _to_regtype(cls, conn: BaseConnection[Any]) -> sql.SQL:
        # `to_regtype()` returns the type oid or NULL, unlike the :: operator,
        # which returns the type or raises an exception, which requires
        # a transaction rollback and leaves traces in the server logs.
        if cls._has_to_regtype_function(conn):
            return sql.SQL("to_regtype(%(name)s)")
        else:
            return sql.SQL("%(name)s::regtype")

    def _added(self, registry: TypesRegistry) -> None:
        """Method called by the `!registry` when the object is added there."""
        pass

    def get_type_display(self, oid: int | None = None, fmod: int | None = None) -> str:
        """Return a display string for the type, e.g. ``name(mod,...)[]``."""
        parts = []
        parts.append(self.name)
        mod = self.typemod.get_modifier(fmod) if fmod is not None else ()
        if mod:
            parts.append(f"({','.join(map(str, mod))})")

        if oid == self.array_oid:
            parts.append("[]")

        return "".join(parts)

    def get_display_size(self, fmod: int) -> int | None:
        # Delegate typmod parsing to the type's TypeModifier.
        return self.typemod.get_display_size(fmod)

    def get_precision(self, fmod: int) -> int | None:
        return self.typemod.get_precision(fmod)

    def get_scale(self, fmod: int) -> int | None:
        return self.typemod.get_scale(fmod)
class TypesRegistry:
    """
    Container for the information about types in a database.
    """

    __module__ = "psycopg.types"

    def __init__(self, template: TypesRegistry | None = None):
        self._registry: dict[RegistryKey, TypeInfo]

        # Make a shallow copy: it will become a proper copy if the registry
        # is edited.
        if template:
            self._registry = template._registry
            self._own_state = False
            # Mark the template as shared too, so it copies before writing.
            template._own_state = False
        else:
            self.clear()

    def clear(self) -> None:
        # Start from an empty state, owned by this registry.
        self._registry = {}
        self._own_state = True

    def add(self, info: TypeInfo) -> None:
        """Register `!info`, indexing it by oid, array oid, name and regtype."""
        self._ensure_own_state()
        if info.oid:
            self._registry[info.oid] = info
        if info.array_oid:
            self._registry[info.array_oid] = info

        self._registry[info.name] = info
        if info.regtype and info.regtype not in self._registry:
            self._registry[info.regtype] = info

        # Allow info to customise further their relation with the registry
        info._added(self)

    def __iter__(self) -> Iterator[TypeInfo]:
        # Yield each registered TypeInfo once, even if indexed by many keys.
        seen = set()
        for t in self._registry.values():
            if id(t) not in seen:
                seen.add(id(t))
                yield t

    @overload
    def __getitem__(self, key: str | int) -> TypeInfo: ...

    @overload
    def __getitem__(self, key: tuple[type[T], int]) -> T: ...

    def __getitem__(self, key: RegistryKey) -> TypeInfo:
        """
        Return info about a type, specified by name or oid

        :param key: the name or oid of the type to look for.

        Raise KeyError if not found.
        """
        if isinstance(key, str):
            # "name[]" is accepted as an alias of the array of "name".
            if key.endswith("[]"):
                key = key[:-2]
        elif not isinstance(key, (int, tuple)):
            raise TypeError(f"the key must be an oid or a name, got {type(key)}")
        try:
            return self._registry[key]
        except KeyError:
            raise KeyError(f"couldn't find the type {key!r} in the types registry")

    @overload
    def get(self, key: str | int) -> TypeInfo | None: ...

    @overload
    def get(self, key: tuple[type[T], int]) -> T | None: ...

    def get(self, key: RegistryKey) -> TypeInfo | None:
        """
        Return info about a type, specified by name or oid

        :param key: the name or oid of the type to look for.

        Unlike `__getitem__`, return None if not found.
        """
        try:
            return self[key]
        except KeyError:
            return None

    def get_oid(self, name: str) -> int:
        """
        Return the oid of a PostgreSQL type by name.

        :param key: the name of the type to look for.

        Return the array oid if the type ends with "``[]``"

        Raise KeyError if the name is unknown.
        """
        t = self[name]
        if name.endswith("[]"):
            return t.array_oid
        else:
            return t.oid

    def get_by_subtype(self, cls: type[T], subtype: int | str) -> T | None:
        """
        Return info about a `TypeInfo` subclass by its element name or oid.

        :param cls: the subtype of `!TypeInfo` to look for. Currently
            supported are `~psycopg.types.range.RangeInfo` and
            `~psycopg.types.multirange.MultirangeInfo`.
        :param subtype: The name or OID of the subtype of the element to look for.
        :return: The `!TypeInfo` object of class `!cls` whose subtype is
            `!subtype`. `!None` if the element or its range are not found.
        """
        try:
            info = self[subtype]
        except KeyError:
            return None
        return self.get((cls, info.oid))

    def _ensure_own_state(self) -> None:
        # Time to write! so, copy.
        if not self._own_state:
            self._registry = self._registry.copy()
            self._own_state = True

View File

@@ -0,0 +1,86 @@
"""
PostgreSQL type modifiers.
The type modifiers parse catalog information to obtain the type modifier
of a column - the numeric part of varchar(10) or decimal(6,2).
"""
# Copyright (C) 2024 The Psycopg Team
from __future__ import annotations
class TypeModifier:
    """Fallback modifier parser for types with no known typmod layout.

    Every accessor answers `!None`: such types expose no length, precision,
    or scale information through their type modifier.
    """

    def __init__(self, oid: int):
        # Keep the oid of the type this parser is attached to.
        self.oid = oid

    def get_modifier(self, typemod: int) -> tuple[int, ...] | None:
        """Return the modifier components to display, if any."""
        return None

    def get_display_size(self, typemod: int) -> int | None:
        """Return the display size encoded in the typmod, if any."""
        return None

    def get_precision(self, typemod: int) -> int | None:
        """Return the precision encoded in the typmod, if any."""
        return None

    def get_scale(self, typemod: int) -> int | None:
        """Return the scale encoded in the typmod, if any."""
        return None
class NumericTypeModifier(TypeModifier):
    """Handle numeric type modifier."""

    def get_modifier(self, typemod: int) -> tuple[int, ...] | None:
        # Both components must be available to report a (precision, scale).
        prec = self.get_precision(typemod)
        scale = self.get_scale(typemod)
        if prec is None or scale is None:
            return None
        return (prec, scale)

    def get_precision(self, typemod: int) -> int | None:
        # Precision is stored in the high 16 bits of the typmod.
        if typemod < 0:
            return None
        return typemod >> 16

    def get_scale(self, typemod: int) -> int | None:
        # Scale is in the low 16 bits of (typmod - VARHDRSZ), stored as a
        # signed 11-bit-style value: fold values >= 0x400 to negative.
        if typemod < 0:
            return None
        scale = (typemod - 4) & 0xFFFF
        return scale - 0x800 if scale >= 0x400 else scale
class CharTypeModifier(TypeModifier):
    """Handle char/varchar type modifier."""

    def get_modifier(self, typemod: int) -> tuple[int, ...] | None:
        size = self.get_display_size(typemod)
        # A falsy size (None or 0) means no modifier to report.
        return None if not size else (size,)

    def get_display_size(self, typemod: int) -> int | None:
        # The typmod carries the length plus the 4-byte varlena header.
        if typemod < 0:
            return None
        return typemod - 4
class BitTypeModifier(TypeModifier):
    """Handle bit/varbit type modifier."""

    def get_modifier(self, typemod: int) -> tuple[int, ...] | None:
        size = self.get_display_size(typemod)
        # A falsy size (None or 0) means no modifier to report.
        return None if not size else (size,)

    def get_display_size(self, typemod: int) -> int | None:
        # For bit types the typmod is the length itself, no header offset.
        if typemod < 0:
            return None
        return typemod
class TimeTypeModifier(TypeModifier):
    """Handle time-related types modifier."""

    def get_modifier(self, typemod: int) -> tuple[int, ...] | None:
        prec = self.get_precision(typemod)
        # Note: a precision of 0 is meaningful for time types, so only
        # None means "no modifier".
        if prec is None:
            return None
        return (prec,)

    def get_precision(self, typemod: int) -> int | None:
        # Fractional-seconds precision lives in the low 16 bits.
        if typemod < 0:
            return None
        return typemod & 0xFFFF

View File

@@ -0,0 +1,45 @@
"""
Timezone utility functions.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import logging
from datetime import timezone, tzinfo
from .pq.abc import PGconn
from ._compat import ZoneInfo
logger = logging.getLogger("psycopg")

# Cache mapping the server TimeZone name to a Python tzinfo, pre-seeded so
# that the common cases (no connection, UTC) never hit ZoneInfo.
_timezones: dict[bytes | None, tzinfo] = {
    None: timezone.utc,
    b"UTC": timezone.utc,
}


def get_tzinfo(pgconn: PGconn | None) -> tzinfo:
    """Return the Python timezone info of the connection's timezone."""
    tzname = pgconn.parameter_status(b"TimeZone") if pgconn else None
    if tzname in _timezones:
        return _timezones[tzname]
    sname = tzname.decode() if tzname else "UTC"
    try:
        zi: tzinfo = ZoneInfo(sname)
    except (KeyError, OSError):
        # Name unknown to the tz database (or database unreadable).
        logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
        zi = timezone.utc
    except Exception as ex:
        # Defensive catch-all: never let timezone parsing break a query.
        logger.warning(
            "error handling PostgreSQL timezone: %r; will use UTC (%s - %s)",
            sname,
            type(ex).__name__,
            ex,
        )
        zi = timezone.utc
    # Remember the result for the next connection using the same setting.
    _timezones[tzname] = zi
    return zi

View File

@@ -0,0 +1,137 @@
"""
Wrappers for numeric types.
"""
# Copyright (C) 2020 The Psycopg Team
# Wrappers to force numbers to be cast as specific PostgreSQL types
# These types are implemented here but exposed by `psycopg.types.numeric`.
# They are defined here to avoid a circular import.
# Public module under which these wrappers are exposed.
_MODULE = "psycopg.types.numeric"


class Int2(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`smallint/int2`.
    """

    __module__ = _MODULE
    __slots__ = ()

    def __new__(cls, arg: int) -> "Int2":
        return super().__new__(cls, arg)

    def __str__(self) -> str:
        # Plain numeric text, without the wrapper class name.
        return int.__repr__(self)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, int.__repr__(self))
class Int4(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`integer/int4`.
    """

    __module__ = "psycopg.types.numeric"  # exposed under the public module
    __slots__ = ()

    def __new__(cls, arg: int) -> "Int4":
        return super().__new__(cls, arg)

    def __str__(self) -> str:
        # Plain numeric text, without the wrapper class name.
        return int.__repr__(self)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, int.__repr__(self))
class Int8(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`bigint/int8`.
    """

    __module__ = "psycopg.types.numeric"  # exposed under the public module
    __slots__ = ()

    def __new__(cls, arg: int) -> "Int8":
        return super().__new__(cls, arg)

    def __str__(self) -> str:
        # Plain numeric text, without the wrapper class name.
        return int.__repr__(self)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, int.__repr__(self))
class IntNumeric(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`numeric/decimal`.
    """

    __module__ = "psycopg.types.numeric"  # exposed under the public module
    __slots__ = ()

    def __new__(cls, arg: int) -> "IntNumeric":
        return super().__new__(cls, arg)

    def __str__(self) -> str:
        # Plain numeric text, without the wrapper class name.
        return int.__repr__(self)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, int.__repr__(self))
class Float4(float):
    """
    Force dumping a Python `!float` as a PostgreSQL :sql:`float4/real`.
    """

    __module__ = "psycopg.types.numeric"  # exposed under the public module
    __slots__ = ()

    def __new__(cls, arg: float) -> "Float4":
        return super().__new__(cls, arg)

    def __str__(self) -> str:
        # Plain numeric text, without the wrapper class name.
        return float.__repr__(self)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, float.__repr__(self))
class Float8(float):
    """
    Force dumping a Python `!float` as a PostgreSQL :sql:`float8/double precision`.
    """

    __module__ = "psycopg.types.numeric"  # exposed under the public module
    __slots__ = ()

    def __new__(cls, arg: float) -> "Float8":
        return super().__new__(cls, arg)

    def __str__(self) -> str:
        # Plain numeric text, without the wrapper class name.
        return float.__repr__(self)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, float.__repr__(self))
class Oid(int):
    """
    Force dumping a Python `!int` as a PostgreSQL :sql:`oid`.
    """

    __module__ = "psycopg.types.numeric"  # exposed under the public module
    __slots__ = ()

    def __new__(cls, arg: int) -> "Oid":
        return super().__new__(cls, arg)

    def __str__(self) -> str:
        # Plain numeric text, without the wrapper class name.
        return int.__repr__(self)

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, int.__repr__(self))

View File

@@ -0,0 +1,251 @@
"""
Protocol objects representing different implementations of the same classes.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from typing import Any, Callable, Generator, Mapping
from typing import Protocol, Sequence, TYPE_CHECKING
from . import pq
from ._enums import PyFormat as PyFormat
from ._compat import TypeAlias, TypeVar
if TYPE_CHECKING:
from . import sql # noqa: F401
from .rows import Row, RowMaker
from .pq.abc import PGresult
from .waiting import Wait, Ready # noqa: F401
from ._compat import LiteralString # noqa: F401
from ._adapters_map import AdaptersMap
from ._connection_base import BaseConnection
# Handy runtime alias: typing deprecates direct use of type(None).
NoneType: type = type(None)
# An object implementing the buffer protocol
Buffer: TypeAlias = "bytes | bytearray | memoryview"
# Anything accepted as a query: a (literal) string, pre-encoded bytes, or a
# composable sql object.
Query: TypeAlias = "LiteralString | bytes | sql.SQL | sql.Composed"
# Query parameters: positional (sequence) or named (mapping).
Params: TypeAlias = "Sequence[Any] | Mapping[str, Any]"
# Generic variable used to parameterize cursors over sync/async connections.
ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
# A deferred libpq call queued while in pipeline mode.
PipelineCommand: TypeAlias = Callable[[], None]
# Key used to look up a dumper; may nest recursively for containers.
DumperKey: TypeAlias = "type | tuple[DumperKey, ...]"
# Value types accepted for a single connection parameter.
ConnParam: TypeAlias = "str | int | None"
ConnDict: TypeAlias = "dict[str, ConnParam]"
ConnMapping: TypeAlias = Mapping[str, ConnParam]
# Waiting protocol types
RV = TypeVar("RV")
PQGenConn: TypeAlias = Generator["tuple[int, Wait]", "Ready | int", RV]
"""Generator for processes where the connection file number can change.
This can happen in connection and reset, but not in normal querying.
"""
PQGen: TypeAlias = Generator["Wait", "Ready | int", RV]
"""Generator for processes where the connection file number won't change.
"""
class WaitFunc(Protocol):
    """
    Wait on the connection which generated `PQgen` and return its final result.
    """
    # `fileno` is the connection socket to wait on; `interval` is the maximum
    # time to block between readiness checks (None = implementation default).
    def __call__(
        self, gen: PQGen[RV], fileno: int, interval: float | None = None
    ) -> RV: ...
# Adaptation types
# Signature of Dumper.dump(): Python object -> wire representation (None = NULL).
DumpFunc: TypeAlias = Callable[[Any], "Buffer | None"]
# Signature of Loader.load(): wire representation -> Python object.
LoadFunc: TypeAlias = Callable[[Buffer], Any]
class AdaptContext(Protocol):
    """
    A context describing how types are adapted.
    Example of `~AdaptContext` are `~psycopg.Connection`, `~psycopg.Cursor`,
    `~psycopg.adapt.Transformer`, `~psycopg.adapt.AdaptersMap`.
    Note that this is a `~typing.Protocol`, so objects implementing
    `!AdaptContext` don't need to explicitly inherit from this class.
    """
    @property
    def adapters(self) -> AdaptersMap:
        """The adapters configuration that this object uses."""
        ...
    # The connection is optional: e.g. a standalone AdaptersMap has none.
    @property
    def connection(self) -> BaseConnection[Any] | None:
        """The connection used by this object, if available.
        :rtype: `~psycopg.Connection` or `~psycopg.AsyncConnection` or `!None`
        """
        ...
class Dumper(Protocol):
    """
    Convert Python objects of type `!cls` to PostgreSQL representation.
    """
    format: pq.Format
    """
    The format that this class `dump()` method produces,
    `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
    This is a class attribute.
    """
    oid: int
    """The oid to pass to the server, if known; 0 otherwise (class attribute)."""
    # Constructor signature expected from implementations.
    def __init__(self, cls: type, context: AdaptContext | None = None): ...
    def dump(self, obj: Any) -> Buffer | None:
        """Convert the object `!obj` to PostgreSQL representation.
        :param obj: the object to convert.
        """
        ...
    def quote(self, obj: Any) -> Buffer:
        """Convert the object `!obj` to escaped representation.
        :param obj: the object to convert.
        """
        ...
    def get_key(self, obj: Any, format: PyFormat) -> DumperKey:
        """Return an alternative key to upgrade the dumper to represent `!obj`.
        :param obj: The object to convert
        :param format: The format to convert to
        Normally the type of the object is all it takes to define how to dump
        the object to the database. For instance, a Python `~datetime.date` can
        be simply converted into a PostgreSQL :sql:`date`.
        In a few cases, just the type is not enough. For example:
        - A Python `~datetime.datetime` could be represented as a
        :sql:`timestamptz` or a :sql:`timestamp`, according to whether it
        specifies a `!tzinfo` or not.
        - A Python int could be stored as several Postgres types: int2, int4,
        int8, numeric. If a type too small is used, it may result in an
        overflow. If a type too large is used, PostgreSQL may not want to
        cast it to a smaller type.
        - Python lists should be dumped according to the type they contain to
        convert them to e.g. array of strings, array of ints (and which
        size of int?...)
        In these cases, a dumper can implement `!get_key()` and return a new
        class, or sequence of classes, that can be used to identify the same
        dumper again. If the mechanism is not needed, the method should return
        the same `!cls` object passed in the constructor.
        If a dumper implements `get_key()` it should also implement
        `upgrade()`.
        """
        ...
    def upgrade(self, obj: Any, format: PyFormat) -> Dumper:
        """Return a new dumper to manage `!obj`.
        :param obj: The object to convert
        :param format: The format to convert to
        Once `Transformer.get_dumper()` has been notified by `get_key()` that
        this Dumper class cannot handle `!obj` itself, it will invoke
        `!upgrade()`, which should return a new `Dumper` instance, which will
        be reused for every objects for which `!get_key()` returns the same
        result.
        """
        ...
class Loader(Protocol):
    """
    Convert PostgreSQL values with type OID `!oid` to Python objects.
    """
    format: pq.Format
    """
    The format that this class `load()` method can convert,
    `~psycopg.pq.Format.TEXT` or `~psycopg.pq.Format.BINARY`.
    This is a class attribute.
    """
    # Constructor signature expected from implementations.
    def __init__(self, oid: int, context: AdaptContext | None = None): ...
    def load(self, data: Buffer) -> Any:
        """
        Convert the data returned by the database into a Python object.
        :param data: the data to convert.
        """
        ...
class Transformer(Protocol):
    """
    Protocol of the object converting a whole set of query parameters to
    PostgreSQL format and a whole set of result values to Python objects.
    """
    # Oids of the types of the parameters of the last query dumped, if known.
    types: tuple[int, ...] | None
    # Formats (text/binary) of the parameters of the last query dumped, if known.
    formats: list[pq.Format] | None
    def __init__(self, context: AdaptContext | None = None): ...
    @classmethod
    def from_context(cls, context: AdaptContext | None) -> Transformer: ...
    @property
    def connection(self) -> BaseConnection[Any] | None: ...
    @property
    def encoding(self) -> str: ...
    @property
    def adapters(self) -> AdaptersMap: ...
    @property
    def pgresult(self) -> PGresult | None: ...
    # Set the result to load values from, optionally reconfiguring loaders.
    def set_pgresult(
        self,
        result: PGresult | None,
        *,
        set_loaders: bool = True,
        format: pq.Format | None = None,
    ) -> None: ...
    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...
    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...
    # Convert a sequence of Python parameters to wire representation.
    def dump_sequence(
        self, params: Sequence[Any], formats: Sequence[PyFormat]
    ) -> Sequence[Buffer | None]: ...
    def as_literal(self, obj: Any) -> bytes: ...
    def get_dumper(self, obj: Any, format: PyFormat) -> Dumper: ...
    # Load a range of rows / a single row / a flat record from the result.
    def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: ...
    def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: ...
    def load_sequence(self, record: Sequence[Buffer | None]) -> tuple[Any, ...]: ...
    def get_loader(self, oid: int, format: pq.Format) -> Loader: ...

View File

@@ -0,0 +1,153 @@
"""
Entry point into the adaptation system.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, TYPE_CHECKING
from . import pq, abc
# Objects exported here
from ._enums import PyFormat as PyFormat
from ._transformer import Transformer as Transformer
from ._adapters_map import AdaptersMap as AdaptersMap # noqa: F401 # reexport
if TYPE_CHECKING:
from ._connection_base import BaseConnection
# Re-exported for the convenience of adapter implementers.
Buffer = abc.Buffer
# Ascii code of the backslash; used to search in a Buffer because
# b"\\" in memoryview doesn't work.
ORD_BS = ord("\\")
class Dumper(abc.Dumper, ABC):
    """
    Convert Python object of the type `!cls` to PostgreSQL representation.
    """
    oid: int = 0
    """The oid to pass to the server, if known."""
    format: pq.Format = pq.Format.TEXT
    """The format of the data dumped."""
    def __init__(self, cls: type, context: abc.AdaptContext | None = None):
        # The Python class this dumper converts.
        self.cls = cls
        # Connection taken from the context, if any; used by quote() to
        # pick the server-aware escaping functions.
        self.connection: BaseConnection[Any] | None
        self.connection = context.connection if context else None
    def __repr__(self) -> str:
        return (
            f"<{type(self).__module__}.{type(self).__qualname__}"
            f" (oid={self.oid}) at 0x{id(self):x}>"
        )
    @abstractmethod
    def dump(self, obj: Any) -> Buffer | None: ...
    def quote(self, obj: Any) -> Buffer:
        """
        By default return the `dump()` value quoted and sanitised, so
        that the result can be used to build a SQL string. This works well
        for most types and you won't likely have to implement this method in a
        subclass.
        """
        value = self.dump(obj)
        if value is None:
            return b"NULL"
        if self.connection:
            esc = pq.Escaping(self.connection.pgconn)
            # escaping and quoting
            return esc.escape_literal(value)
        # This path is taken when quote is asked without a connection,
        # usually it means by psycopg.sql.quote() or by
        # 'Composible.as_string(None)'. Most often than not this is done by
        # someone generating a SQL file to consume elsewhere.
        # No quoting, only quote escaping, random bs escaping. See further.
        esc = pq.Escaping()
        out = esc.escape_string(value)
        # b"\\" in memoryview doesn't work so search for the ascii value
        if ORD_BS not in out:
            # If the string has no backslash, the result is correct and we
            # don't need to bother with standard_conforming_strings.
            return b"'" + out + b"'"
        # The libpq has a crazy behaviour: PQescapeString uses the last
        # standard_conforming_strings setting seen on a connection. This
        # means that backslashes might be escaped or might not.
        #
        # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
        # if scs is off, '\\' raises a warning and '\' is an error.
        #
        # Check what the libpq does, and if it doesn't escape the backslash
        # let's do it on our own. Never mind the race condition.
        rv: bytes = b" E'" + out + b"'"
        if esc.escape_string(b"\\") == b"\\":
            rv = rv.replace(b"\\", b"\\\\")
        return rv
    def get_key(self, obj: Any, format: PyFormat) -> abc.DumperKey:
        """
        Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
        `~psycopg.abc.Dumper` protocol. Look at its definition for details.
        This implementation returns the `!cls` passed in the constructor.
        Subclasses needing to specialise the PostgreSQL type according to the
        *value* of the object dumped (not only according to its type)
        should override this class.
        """
        return self.cls
    def upgrade(self, obj: Any, format: PyFormat) -> Dumper:
        """
        Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
        `~psycopg.abc.Dumper` protocol. Look at its definition for details.
        This implementation just returns `!self`. If a subclass implements
        `get_key()` it should probably override `!upgrade()` too.
        """
        return self
class Loader(abc.Loader, ABC):
    """
    Convert PostgreSQL values with type OID `!oid` to Python objects.
    """

    format: pq.Format = pq.Format.TEXT
    """The format of the data loaded."""

    def __init__(self, oid: int, context: abc.AdaptContext | None = None):
        # The oid of the PostgreSQL type this loader converts.
        self.oid = oid
        # Connection taken from the context, if any.
        self.connection: BaseConnection[Any] | None = (
            context.connection if context else None
        )

    @abstractmethod
    def load(self, data: Buffer) -> Any:
        """Convert a PostgreSQL value to a Python object."""
        ...
class RecursiveDumper(Dumper):
    """Dumper with a transformer to help dumping recursive types."""
    def __init__(self, cls: type, context: abc.AdaptContext | None = None):
        super().__init__(cls, context)
        # Transformer used to dump the inner items of composite values.
        self._tx = Transformer.from_context(context)
class RecursiveLoader(Loader):
    """Loader with a transformer to help loading recursive types."""
    def __init__(self, oid: int, context: abc.AdaptContext | None = None):
        super().__init__(oid, context)
        # Transformer used to load the inner items of composite values.
        self._tx = Transformer.from_context(context)

View File

@@ -0,0 +1,93 @@
"""
psycopg client-side binding cursors
"""
# Copyright (C) 2022 The Psycopg Team
from __future__ import annotations
from typing import TYPE_CHECKING
from functools import partial
from ._queries import PostgresQuery, PostgresClientQuery
from . import pq
from . import adapt
from . import errors as e
from .abc import ConnectionType, Query, Params
from .rows import Row
from .cursor import Cursor
from ._preparing import Prepare
from ._cursor_base import BaseCursor
from .cursor_async import AsyncCursor
if TYPE_CHECKING:
from typing import Any # noqa: F401
from .connection import Connection # noqa: F401
from .connection_async import AsyncConnection # noqa: F401
# Shortcuts for the libpq result formats.
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
    # Query implementation merging the parameters into the query client-side.
    _query_cls = PostgresClientQuery
    def mogrify(self, query: Query, params: Params | None = None) -> str:
        """
        Return the query and parameters merged.
        Parameters are adapted and merged to the query the same way that
        `!execute()` would do.
        """
        # Use a fresh transformer so mogrify() doesn't disturb any state
        # left by a previous execution.
        self._tx = adapt.Transformer(self)
        pgq = self._convert_query(query, params)
        return pgq.query.decode(self._tx.encoding)
    def _execute_send(
        self,
        query: PostgresQuery,
        *,
        force_extended: bool = False,
        binary: bool | None = None,
    ) -> None:
        # Resolve the result format: the explicit argument wins over the
        # cursor's configured default.
        if binary is None:
            fmt = self.format
        else:
            fmt = BINARY if binary else TEXT
        if fmt == BINARY:
            raise e.NotSupportedError(
                "client-side cursors don't support binary results"
            )
        self._query = query
        if self._conn._pipeline:
            # In pipeline mode always use PQsendQueryParams - see #314
            # Multiple statements in the same query are not allowed anyway.
            self._conn._pipeline.command_queue.append(
                partial(self._pgconn.send_query_params, query.query, None)
            )
        elif force_extended:
            self._pgconn.send_query_params(query.query, None)
        else:
            # If we can, let's use simple query protocol,
            # as it can execute more than one statement in a single query.
            self._pgconn.send_query(query.query)
    def _get_prepared(
        self, pgq: PostgresQuery, prepare: bool | None = None
    ) -> tuple[Prepare, bytes]:
        # Client-side cursors never use server-side prepared statements:
        # parameters are already merged into the query string.
        return (Prepare.NO, b"")
class ClientCursor(ClientCursorMixin["Connection[Any]", Row], Cursor[Row]):
    """Sync cursor binding the query parameters on the client."""
    __module__ = "psycopg"
class AsyncClientCursor(
    ClientCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]
):
    """Async cursor binding the query parameters on the client."""
    __module__ = "psycopg"

View File

@@ -0,0 +1,479 @@
# WARNING: this file is auto-generated by 'async_to_sync.py'
# from the original file 'connection_async.py'
# DO NOT CHANGE! Change the original file instead.
"""
Psycopg connection object (sync version)
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import logging
from time import monotonic
from types import TracebackType
from typing import Any, Generator, Iterator, cast, overload, TYPE_CHECKING
from contextlib import contextmanager
from . import pq
from . import errors as e
from . import waiting
from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV
from ._tpc import Xid
from .rows import Row, RowFactory, tuple_row, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import Self
from .conninfo import make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts, timeout_from_conninfo
from ._pipeline import Pipeline
from .generators import notifies
from .transaction import Transaction
from .cursor import Cursor
from ._capabilities import capabilities
from .server_cursor import ServerCursor
from ._connection_base import BaseConnection, CursorRow, Notify
from threading import Lock
if TYPE_CHECKING:
from .pq.abc import PGconn
# Interval (in seconds) between wake-ups while blocking on the libpq, so that
# signals such as Ctrl-C can be serviced.
_WAIT_INTERVAL = 0.1
# Shortcuts for libpq enums used throughout the module.
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
IDLE = pq.TransactionStatus.IDLE
ACTIVE = pq.TransactionStatus.ACTIVE
INTRANS = pq.TransactionStatus.INTRANS
# Exception(s) interrupting a blocking wait (sync version: Ctrl-C only).
_INTERRUPTED = KeyboardInterrupt
logger = logging.getLogger("psycopg")
class Connection(BaseConnection[Row]):
    """
    Wrapper for a connection to the database.
    """
    __module__ = "psycopg"
    # Factories used by cursor() to create the cursors returned.
    cursor_factory: type[Cursor[Row]]
    server_cursor_factory: type[ServerCursor[Row]]
    # Factory building a row object from a database record.
    row_factory: RowFactory[Row]
    # The pipeline currently active on this connection, if any.
    _pipeline: Pipeline | None
    def __init__(
        self,
        pgconn: PGconn,
        row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row),
    ):
        super().__init__(pgconn)
        self.row_factory = row_factory
        # Serialize access to the connection from different threads.
        self.lock = Lock()
        self.cursor_factory = Cursor
        self.server_cursor_factory = ServerCursor
    @classmethod
    def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: int | None = 5,
        context: AdaptContext | None = None,
        row_factory: RowFactory[Row] | None = None,
        cursor_factory: type[Cursor[Row]] | None = None,
        **kwargs: ConnParam,
    ) -> Self:
        """
        Connect to a database server and return a new `Connection` instance.

        :param conninfo: the connection string to use.
        :param autocommit: if `!True`, don't start transactions implicitly.
        :param prepare_threshold: number of executions before a query is
            prepared server-side (`!None` disables preparation).
        :param context: context providing the adapters configuration to copy.
        :param row_factory: factory building the rows returned by cursors.
        :param cursor_factory: class of the cursors returned by `cursor()`.
        :param kwargs: further connection parameters, merged into `!conninfo`.
        """
        params = cls._get_connection_params(conninfo, **kwargs)
        timeout = timeout_from_conninfo(params)
        rv = None
        # Each attempt targets one candidate host resolved from the conninfo.
        attempts = conninfo_attempts(params)
        for attempt in attempts:
            try:
                conninfo = make_conninfo("", **attempt)
                gen = cls._connect_gen(conninfo, timeout=timeout)
                rv = waiting.wait_conn(gen, interval=_WAIT_INTERVAL)
            except e._NO_TRACEBACK as ex:
                if len(attempts) > 1:
                    logger.debug(
                        "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
                        attempt.get("host"),
                        attempt.get("port"),
                        attempt.get("hostaddr"),
                        str(ex),
                    )
                # Remember the failure; re-raised if every attempt fails.
                last_ex = ex
            else:
                break
        if not rv:
            assert last_ex
            raise last_ex.with_traceback(None)
        rv._autocommit = bool(autocommit)
        if row_factory:
            rv.row_factory = row_factory
        if cursor_factory:
            rv.cursor_factory = cursor_factory
        if context:
            rv._adapters = AdaptersMap(context.adapters)
        rv.prepare_threshold = prepare_threshold
        return rv
    def __enter__(self) -> Self:
        return self
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Commit on clean exit, roll back on exception, then close.
        if self.closed:
            return
        if exc_type:
            # try to rollback, but if there are problems (connection in a bad
            # state) just warn without clobbering the exception bubbling up.
            try:
                self.rollback()
            except Exception as exc2:
                logger.warning("error ignored in rollback on %s: %s", self, exc2)
        else:
            self.commit()
        # Close the connection only if it doesn't belong to a pool.
        if not getattr(self, "_pool", None):
            self.close()
    @classmethod
    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
        """Manipulate connection parameters before connecting."""
        return conninfo_to_dict(conninfo, **kwargs)
    def close(self) -> None:
        """Close the database connection."""
        if self.closed:
            return
        self._closed = True
        # TODO: maybe send a cancel on close, if the connection is ACTIVE?
        self.pgconn.finish()
    @overload
    def cursor(self, *, binary: bool = False) -> Cursor[Row]: ...
    @overload
    def cursor(
        self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
    ) -> Cursor[CursorRow]: ...
    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> ServerCursor[Row]: ...
    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        row_factory: RowFactory[CursorRow],
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> ServerCursor[CursorRow]: ...
    def cursor(
        self,
        name: str = "",
        *,
        binary: bool = False,
        row_factory: RowFactory[Any] | None = None,
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> Cursor[Any] | ServerCursor[Any]:
        """
        Return a new `Cursor` to send commands and queries to the connection.

        :param name: if given, create a server-side (named) cursor.
        :param binary: if `!True` fetch results in binary format.
        :param row_factory: overrides the connection's row factory if given.
        :param scrollable: server-side cursors only: cursor scrollability.
        :param withhold: server-side cursors only: survive transaction end.
        """
        self._check_connection_ok()
        if not row_factory:
            row_factory = self.row_factory
        cur: Cursor[Any] | ServerCursor[Any]
        if name:
            cur = self.server_cursor_factory(
                self,
                name=name,
                row_factory=row_factory,
                scrollable=scrollable,
                withhold=withhold,
            )
        else:
            cur = self.cursor_factory(self, row_factory=row_factory)
        if binary:
            cur.format = BINARY
        return cur
    def execute(
        self,
        query: Query,
        params: Params | None = None,
        *,
        prepare: bool | None = None,
        binary: bool = False,
    ) -> Cursor[Row]:
        """Execute a query and return a cursor to read its results."""
        try:
            cur = self.cursor()
            if binary:
                cur.format = BINARY
            return cur.execute(query, params, prepare=prepare)
        except e._NO_TRACEBACK as ex:
            # Drop the internal frames from the traceback shown to the user.
            raise ex.with_traceback(None)
    def commit(self) -> None:
        """Commit any pending transaction to the database."""
        with self.lock:
            self.wait(self._commit_gen())
    def rollback(self) -> None:
        """Roll back to the start of any pending transaction."""
        with self.lock:
            self.wait(self._rollback_gen())
    def cancel_safe(self, *, timeout: float = 30.0) -> None:
        """Cancel the current operation on the connection.
        :param timeout: raise a `~errors.CancellationTimeout` if the
        cancellation request does not succeed within `timeout` seconds.
        Note that a successful cancel attempt on the client is not a guarantee
        that the server will successfully manage to cancel the operation.
        This is a non-blocking version of `~Connection.cancel()` which
        leverages a more secure and improved cancellation feature of the libpq,
        which is only available from version 17.
        If the underlying libpq is older than version 17, the method will fall
        back to using the same implementation of `!cancel()`.
        """
        if not self._should_cancel():
            return
        if capabilities.has_cancel_safe():
            waiting.wait_conn(
                self._cancel_gen(timeout=timeout), interval=_WAIT_INTERVAL
            )
        else:
            # Libpq < 17: blocking, less secure cancellation.
            self.cancel()
    def _try_cancel(self, *, timeout: float = 5.0) -> None:
        # Best-effort cancel: never raise (used while handling Ctrl-C).
        try:
            self.cancel_safe(timeout=timeout)
        except Exception as ex:
            logger.warning("query cancellation failed: %s", ex)
    @contextmanager
    def transaction(
        self, savepoint_name: str | None = None, force_rollback: bool = False
    ) -> Iterator[Transaction]:
        """
        Start a context block with a new transaction or nested transaction.
        :param savepoint_name: Name of the savepoint used to manage a nested
        transaction. If `!None`, one will be chosen automatically.
        :param force_rollback: Roll back the transaction at the end of the
        block even if there were no error (e.g. to try a no-op process).
        :rtype: Transaction
        """
        tx = Transaction(self, savepoint_name, force_rollback)
        if self._pipeline:
            # In pipeline mode, sync around the transaction block.
            with self.pipeline(), tx, self.pipeline():
                yield tx
        else:
            with tx:
                yield tx
    def notifies(
        self, *, timeout: float | None = None, stop_after: int | None = None
    ) -> Generator[Notify, None, None]:
        """
        Yield `Notify` objects as soon as they are received from the database.
        :param timeout: maximum amount of time to wait for notifications.
        `!None` means no timeout.
        :param stop_after: stop after receiving this number of notifications.
        You might actually receive more than this number if more than one
        notifications arrives in the same packet.
        """
        # Allow interrupting the wait with a signal by reducing a long timeout
        # into shorter intervals.
        if timeout is not None:
            deadline = monotonic() + timeout
            interval = min(timeout, _WAIT_INTERVAL)
        else:
            deadline = None
            interval = _WAIT_INTERVAL
        nreceived = 0
        with self.lock:
            enc = self.pgconn._encoding
        while True:
            try:
                ns = self.wait(notifies(self.pgconn), interval=interval)
            except e._NO_TRACEBACK as ex:
                raise ex.with_traceback(None)
            # Emit the notifications received.
            for pgn in ns:
                n = Notify(
                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
                )
                yield n
                nreceived += 1
            # Stop if we have received enough notifications.
            if stop_after is not None and nreceived >= stop_after:
                break
            # Check the deadline after the loop to ensure that timeout=0
            # polls at least once.
            if deadline:
                interval = min(_WAIT_INTERVAL, deadline - monotonic())
                if interval < 0.0:
                    break
    @contextmanager
    def pipeline(self) -> Iterator[Pipeline]:
        """Context manager to switch the connection into pipeline mode."""
        with self.lock:
            self._check_connection_ok()
            pipeline = self._pipeline
            if pipeline is None:
                # WARNING: reference loop, broken ahead.
                pipeline = self._pipeline = Pipeline(self)
        try:
            with pipeline:
                yield pipeline
        finally:
            # Exiting the outermost pipeline block: break the reference loop.
            if pipeline.level == 0:
                with self.lock:
                    assert pipeline is self._pipeline
                    self._pipeline = None
    def wait(self, gen: PQGen[RV], interval: float | None = _WAIT_INTERVAL) -> RV:
        """
        Consume a generator operating on the connection.
        The function must be used on generators that don't change connection
        fd (i.e. not on connect and reset).
        """
        try:
            return waiting.wait(gen, self.pgconn.socket, interval=interval)
        except _INTERRUPTED:
            if self.pgconn.transaction_status == ACTIVE:
                # On Ctrl-C, try to cancel the query in the server, otherwise
                # the connection will remain stuck in ACTIVE state.
                self._try_cancel(timeout=5.0)
                try:
                    waiting.wait(gen, self.pgconn.socket, interval=interval)
                except e.QueryCanceled:
                    pass  # as expected
            raise
    def _set_autocommit(self, value: bool) -> None:
        self.set_autocommit(value)
    def set_autocommit(self, value: bool) -> None:
        """Method version of the `~Connection.autocommit` setter."""
        with self.lock:
            self.wait(self._set_autocommit_gen(value))
    def _set_isolation_level(self, value: IsolationLevel | None) -> None:
        self.set_isolation_level(value)
    def set_isolation_level(self, value: IsolationLevel | None) -> None:
        """Method version of the `~Connection.isolation_level` setter."""
        with self.lock:
            self.wait(self._set_isolation_level_gen(value))
    def _set_read_only(self, value: bool | None) -> None:
        self.set_read_only(value)
    def set_read_only(self, value: bool | None) -> None:
        """Method version of the `~Connection.read_only` setter."""
        with self.lock:
            self.wait(self._set_read_only_gen(value))
    def _set_deferrable(self, value: bool | None) -> None:
        self.set_deferrable(value)
    def set_deferrable(self, value: bool | None) -> None:
        """Method version of the `~Connection.deferrable` setter."""
        with self.lock:
            self.wait(self._set_deferrable_gen(value))
    def tpc_begin(self, xid: Xid | str) -> None:
        """
        Begin a TPC transaction with the given transaction ID `!xid`.
        """
        with self.lock:
            self.wait(self._tpc_begin_gen(xid))
    def tpc_prepare(self) -> None:
        """
        Perform the first phase of a transaction started with `tpc_begin()`.
        """
        try:
            with self.lock:
                self.wait(self._tpc_prepare_gen())
        except e.ObjectNotInPrerequisiteState as ex:
            # DBAPI mandates NotSupportedError in this case.
            raise e.NotSupportedError(str(ex)) from None
    def tpc_commit(self, xid: Xid | str | None = None) -> None:
        """
        Commit a prepared two-phase transaction.
        """
        with self.lock:
            self.wait(self._tpc_finish_gen("COMMIT", xid))
    def tpc_rollback(self, xid: Xid | str | None = None) -> None:
        """
        Roll back a prepared two-phase transaction.
        """
        with self.lock:
            self.wait(self._tpc_finish_gen("ROLLBACK", xid))
    def tpc_recover(self) -> list[Xid]:
        # Return the list of prepared transactions available for recovery.
        self._check_tpc()
        status = self.info.transaction_status
        with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
            cur.execute(Xid._get_recover_query())
            res = cur.fetchall()
        # Don't leave a transaction open if the query implicitly started one.
        if status == IDLE and self.info.transaction_status == INTRANS:
            self.rollback()
        return res

View File

@@ -0,0 +1,519 @@
"""
Psycopg connection object (async version)
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import logging
from time import monotonic
from types import TracebackType
from typing import Any, AsyncGenerator, AsyncIterator, cast, overload, TYPE_CHECKING
from contextlib import asynccontextmanager
from . import pq
from . import errors as e
from . import waiting
from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV
from ._tpc import Xid
from .rows import Row, AsyncRowFactory, tuple_row, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from ._compat import Self
from .conninfo import make_conninfo, conninfo_to_dict
from .conninfo import conninfo_attempts_async, timeout_from_conninfo
from ._pipeline import AsyncPipeline
from .generators import notifies
from .transaction import AsyncTransaction
from .cursor_async import AsyncCursor
from ._capabilities import capabilities
from .server_cursor import AsyncServerCursor
from ._connection_base import BaseConnection, CursorRow, Notify
if True: # ASYNC
import sys
import asyncio
from asyncio import Lock
from ._compat import to_thread
else:
from threading import Lock
if TYPE_CHECKING:
from .pq.abc import PGconn
_WAIT_INTERVAL = 0.1
TEXT = pq.Format.TEXT
BINARY = pq.Format.BINARY
IDLE = pq.TransactionStatus.IDLE
ACTIVE = pq.TransactionStatus.ACTIVE
INTRANS = pq.TransactionStatus.INTRANS
if True: # ASYNC
_INTERRUPTED = (asyncio.CancelledError, KeyboardInterrupt)
else:
_INTERRUPTED = KeyboardInterrupt
logger = logging.getLogger("psycopg")
class AsyncConnection(BaseConnection[Row]):
"""
Wrapper for a connection to the database.
"""
__module__ = "psycopg"
cursor_factory: type[AsyncCursor[Row]]
server_cursor_factory: type[AsyncServerCursor[Row]]
row_factory: AsyncRowFactory[Row]
_pipeline: AsyncPipeline | None
def __init__(
    self,
    pgconn: PGconn,
    row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
):
    """Wrap an established libpq connection `!pgconn`.

    :param row_factory: factory building the rows returned by cursors;
        defaults to `tuple_row`.
    """
    super().__init__(pgconn)
    self.row_factory = row_factory
    # Serializes access to the connection between concurrent tasks.
    self.lock = Lock()
    self.cursor_factory = AsyncCursor
    self.server_cursor_factory = AsyncServerCursor
@classmethod
async def connect(
    cls,
    conninfo: str = "",
    *,
    autocommit: bool = False,
    prepare_threshold: int | None = 5,
    context: AdaptContext | None = None,
    row_factory: AsyncRowFactory[Row] | None = None,
    cursor_factory: type[AsyncCursor[Row]] | None = None,
    **kwargs: ConnParam,
) -> Self:
    """
    Connect to a database server and return a new `AsyncConnection` instance.

    :param conninfo: libpq connection string, merged with `!kwargs`.
    :param autocommit: if `!True`, set the connection in autocommit mode.
    :param prepare_threshold: initial value of the homonymous attribute.
    :param context: adaptation context providing the initial adapters map.
    :param row_factory: row factory for the connection's cursors.
    :param cursor_factory: class used by `cursor()` for unnamed cursors.
    """
    if True: # ASYNC
        if sys.platform == "win32":
            # The ProactorEventLoop doesn't support the socket readiness
            # callbacks psycopg needs: fail early with a clear message.
            loop = asyncio.get_running_loop()
            if isinstance(loop, asyncio.ProactorEventLoop):
                raise e.InterfaceError(
                    "Psycopg cannot use the 'ProactorEventLoop' to run in async"
                    " mode. Please use a compatible event loop, for instance by"
                    " setting 'asyncio.set_event_loop_policy"
                    "(WindowsSelectorEventLoopPolicy())'"
                )
    params = await cls._get_connection_params(conninfo, **kwargs)
    timeout = timeout_from_conninfo(params)
    rv = None
    # Multiple hosts in the conninfo expand into several attempts; stop at
    # the first one that succeeds.
    attempts = await conninfo_attempts_async(params)
    for attempt in attempts:
        try:
            conninfo = make_conninfo("", **attempt)
            gen = cls._connect_gen(conninfo, timeout=timeout)
            rv = await waiting.wait_conn_async(gen, interval=_WAIT_INTERVAL)
        except e._NO_TRACEBACK as ex:
            if len(attempts) > 1:
                logger.debug(
                    "connection attempt failed: host: %r port: %r, hostaddr %r: %s",
                    attempt.get("host"),
                    attempt.get("port"),
                    attempt.get("hostaddr"),
                    str(ex),
                )
            last_ex = ex
        else:
            break
    if not rv:
        # rv is falsy only if every attempt failed, so last_ex is bound.
        assert last_ex
        raise last_ex.with_traceback(None)
    rv._autocommit = bool(autocommit)
    if row_factory:
        rv.row_factory = row_factory
    if cursor_factory:
        rv.cursor_factory = cursor_factory
    if context:
        rv._adapters = AdaptersMap(context.adapters)
    rv.prepare_threshold = prepare_threshold
    return rv
async def __aenter__(self) -> Self:
    """Enter the context block: return the connection itself."""
    return self

async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> None:
    """Exit the context block: commit on success, roll back on error, close."""
    if self.closed:
        return
    if exc_type:
        # try to rollback, but if there are problems (connection in a bad
        # state) just warn without clobbering the exception bubbling up.
        try:
            await self.rollback()
        except Exception as exc2:
            logger.warning("error ignored in rollback on %s: %s", self, exc2)
    else:
        await self.commit()
    # Close the connection only if it doesn't belong to a pool.
    if not getattr(self, "_pool", None):
        await self.close()

@classmethod
async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
    """Manipulate connection parameters before connecting.

    Hook for subclasses; by default just merge `!kwargs` into `!conninfo`.
    """
    return conninfo_to_dict(conninfo, **kwargs)
async def close(self) -> None:
    """Close the database connection."""
    if self.closed:
        return
    self._closed = True
    # TODO: maybe send a cancel on close, if the connection is ACTIVE?
    self.pgconn.finish()

@overload
def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ...

@overload
def cursor(
    self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
) -> AsyncCursor[CursorRow]: ...

@overload
def cursor(
    self,
    name: str,
    *,
    binary: bool = False,
    scrollable: bool | None = None,
    withhold: bool = False,
) -> AsyncServerCursor[Row]: ...

@overload
def cursor(
    self,
    name: str,
    *,
    binary: bool = False,
    row_factory: AsyncRowFactory[CursorRow],
    scrollable: bool | None = None,
    withhold: bool = False,
) -> AsyncServerCursor[CursorRow]: ...

def cursor(
    self,
    name: str = "",
    *,
    binary: bool = False,
    row_factory: AsyncRowFactory[Any] | None = None,
    scrollable: bool | None = None,
    withhold: bool = False,
) -> AsyncCursor[Any] | AsyncServerCursor[Any]:
    """
    Return a new `AsyncCursor` to send commands and queries to the connection.

    :param name: if given, use the server-side cursor implementation
        (`!scrollable` and `!withhold` only apply in that case).
    :param binary: if `!True`, fetch results in binary format.
    :param row_factory: overrides the connection's row factory if given.
    """
    self._check_connection_ok()
    if not row_factory:
        row_factory = self.row_factory
    cur: AsyncCursor[Any] | AsyncServerCursor[Any]
    if name:
        cur = self.server_cursor_factory(
            self,
            name=name,
            row_factory=row_factory,
            scrollable=scrollable,
            withhold=withhold,
        )
    else:
        cur = self.cursor_factory(self, row_factory=row_factory)
    if binary:
        cur.format = BINARY
    return cur
async def execute(
    self,
    query: Query,
    params: Params | None = None,
    *,
    prepare: bool | None = None,
    binary: bool = False,
) -> AsyncCursor[Row]:
    """Execute a query and return a cursor to read its results.

    :param binary: if `!True`, fetch results in binary format.
    """
    try:
        cur = self.cursor()
        if binary:
            cur.format = BINARY
        return await cur.execute(query, params, prepare=prepare)
    except e._NO_TRACEBACK as ex:
        # Strip psycopg internals from the traceback for cleaner errors.
        raise ex.with_traceback(None)

async def commit(self) -> None:
    """Commit any pending transaction to the database."""
    async with self.lock:
        await self.wait(self._commit_gen())

async def rollback(self) -> None:
    """Roll back to the start of any pending transaction."""
    async with self.lock:
        await self.wait(self._rollback_gen())
async def cancel_safe(self, *, timeout: float = 30.0) -> None:
    """Cancel the current operation on the connection.

    :param timeout: raise a `~errors.CancellationTimeout` if the
        cancellation request does not succeed within `timeout` seconds.

    Note that a successful cancel attempt on the client is not a guarantee
    that the server will successfully manage to cancel the operation.

    This is a non-blocking version of `~Connection.cancel()` which
    leverages a more secure and improved cancellation feature of the libpq,
    which is only available from version 17.

    If the underlying libpq is older than version 17, the method will fall
    back to using the same implementation of `!cancel()`.
    """
    # Nothing to do if there is no operation worth cancelling.
    if not self._should_cancel():
        return
    if capabilities.has_cancel_safe():
        await waiting.wait_conn_async(
            self._cancel_gen(timeout=timeout), interval=_WAIT_INTERVAL
        )
    else:
        if True: # ASYNC
            # Legacy blocking cancel: run it in a worker thread so the
            # event loop is not blocked.
            await to_thread(self.cancel)
        else:
            self.cancel()

async def _try_cancel(self, *, timeout: float = 5.0) -> None:
    # Best-effort cancellation (used e.g. on interrupts): failures are
    # only logged, never propagated.
    try:
        await self.cancel_safe(timeout=timeout)
    except Exception as ex:
        logger.warning("query cancellation failed: %s", ex)
@asynccontextmanager
async def transaction(
    self, savepoint_name: str | None = None, force_rollback: bool = False
) -> AsyncIterator[AsyncTransaction]:
    """
    Start a context block with a new transaction or nested transaction.

    :param savepoint_name: Name of the savepoint used to manage a nested
        transaction. If `!None`, one will be chosen automatically.
    :param force_rollback: Roll back the transaction at the end of the
        block even if there were no error (e.g. to try a no-op process).
    :rtype: AsyncTransaction
    """
    tx = AsyncTransaction(self, savepoint_name, force_rollback)
    if self._pipeline:
        # NOTE(review): wrapping the transaction between two pipeline()
        # blocks appears intended to force pipeline syncs at the
        # transaction boundaries — confirm against _pipeline docs.
        async with self.pipeline(), tx, self.pipeline():
            yield tx
    else:
        async with tx:
            yield tx
async def notifies(
    self, *, timeout: float | None = None, stop_after: int | None = None
) -> AsyncGenerator[Notify, None]:
    """
    Yield `Notify` objects as soon as they are received from the database.

    :param timeout: maximum amount of time to wait for notifications.
        `!None` means no timeout.
    :param stop_after: stop after receiving this number of notifications.
        You might actually receive more than this number if more than one
        notifications arrives in the same packet.
    """
    # Allow interrupting the wait with a signal by reducing a long timeout
    # into shorter intervals.
    if timeout is not None:
        deadline = monotonic() + timeout
        interval = min(timeout, _WAIT_INTERVAL)
    else:
        deadline = None
        interval = _WAIT_INTERVAL
    nreceived = 0
    async with self.lock:
        # Cache the connection encoding used to decode channel and payload.
        enc = self.pgconn._encoding
        while True:
            try:
                ns = await self.wait(notifies(self.pgconn), interval=interval)
            except e._NO_TRACEBACK as ex:
                raise ex.with_traceback(None)
            # Emit the notifications received.
            for pgn in ns:
                n = Notify(
                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
                )
                yield n
                nreceived += 1
            # Stop if we have received enough notifications.
            if stop_after is not None and nreceived >= stop_after:
                break
            # Check the deadline after the loop to ensure that timeout=0
            # polls at least once.
            if deadline:
                interval = min(_WAIT_INTERVAL, deadline - monotonic())
                if interval < 0.0:
                    break
@asynccontextmanager
async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
    """Context manager to switch the connection into pipeline mode."""
    async with self.lock:
        self._check_connection_ok()
        pipeline = self._pipeline
        if pipeline is None:
            # WARNING: reference loop, broken ahead.
            pipeline = self._pipeline = AsyncPipeline(self)
    try:
        async with pipeline:
            yield pipeline
    finally:
        # Drop the pipeline (breaking the reference loop above) when the
        # outermost pipeline block exits.
        if pipeline.level == 0:
            async with self.lock:
                assert pipeline is self._pipeline
                self._pipeline = None

async def wait(self, gen: PQGen[RV], interval: float | None = _WAIT_INTERVAL) -> RV:
    """
    Consume a generator operating on the connection.

    The function must be used on generators that don't change connection
    fd (i.e. not on connect and reset).
    """
    try:
        return await waiting.wait_async(gen, self.pgconn.socket, interval=interval)
    except _INTERRUPTED:
        if self.pgconn.transaction_status == ACTIVE:
            # On Ctrl-C, try to cancel the query in the server, otherwise
            # the connection will remain stuck in ACTIVE state.
            await self._try_cancel(timeout=5.0)
            try:
                await waiting.wait_async(gen, self.pgconn.socket, interval=interval)
            except e.QueryCanceled:
                pass # as expected
        # Re-raise the interrupt in any case.
        raise
def _set_autocommit(self, value: bool) -> None:
    # Generated-code branch: in async mode the sync-style property setter
    # is unusable; point the user at the awaitable set_autocommit().
    if True: # ASYNC
        self._no_set_async("autocommit")
    else:
        self.set_autocommit(value)

async def set_autocommit(self, value: bool) -> None:
    """Method version of the `~Connection.autocommit` setter."""
    async with self.lock:
        await self.wait(self._set_autocommit_gen(value))

def _set_isolation_level(self, value: IsolationLevel | None) -> None:
    # See _set_autocommit for the rationale of this generated pattern.
    if True: # ASYNC
        self._no_set_async("isolation_level")
    else:
        self.set_isolation_level(value)

async def set_isolation_level(self, value: IsolationLevel | None) -> None:
    """Method version of the `~Connection.isolation_level` setter."""
    async with self.lock:
        await self.wait(self._set_isolation_level_gen(value))

def _set_read_only(self, value: bool | None) -> None:
    if True: # ASYNC
        self._no_set_async("read_only")
    else:
        self.set_read_only(value)

async def set_read_only(self, value: bool | None) -> None:
    """Method version of the `~Connection.read_only` setter."""
    async with self.lock:
        await self.wait(self._set_read_only_gen(value))

def _set_deferrable(self, value: bool | None) -> None:
    if True: # ASYNC
        self._no_set_async("deferrable")
    else:
        self.set_deferrable(value)

async def set_deferrable(self, value: bool | None) -> None:
    """Method version of the `~Connection.deferrable` setter."""
    async with self.lock:
        await self.wait(self._set_deferrable_gen(value))

if True: # ASYNC

    def _no_set_async(self, attribute: str) -> None:
        # Helpful error for users assigning the sync-style property on an
        # async connection.
        raise AttributeError(
            f"'the {attribute!r} property is read-only on async connections:"
            f" please use 'await .set_{attribute}()' instead."
        )
async def tpc_begin(self, xid: Xid | str) -> None:
    """
    Begin a TPC transaction with the given transaction ID `!xid`.
    """
    async with self.lock:
        await self.wait(self._tpc_begin_gen(xid))

async def tpc_prepare(self) -> None:
    """
    Perform the first phase of a transaction started with `tpc_begin()`.
    """
    try:
        async with self.lock:
            await self.wait(self._tpc_prepare_gen())
    except e.ObjectNotInPrerequisiteState as ex:
        # Surface the server's refusal as the DBAPI NotSupportedError,
        # dropping the chained traceback (message forwarded verbatim).
        raise e.NotSupportedError(str(ex)) from None

async def tpc_commit(self, xid: Xid | str | None = None) -> None:
    """
    Commit a prepared two-phase transaction.
    """
    async with self.lock:
        await self.wait(self._tpc_finish_gen("COMMIT", xid))

async def tpc_rollback(self, xid: Xid | str | None = None) -> None:
    """
    Roll back a prepared two-phase transaction.
    """
    async with self.lock:
        await self.wait(self._tpc_finish_gen("ROLLBACK", xid))

async def tpc_recover(self) -> list[Xid]:
    """Return a list of `Xid` built from the server's TPC recovery query."""
    self._check_tpc()
    # Remember the transaction status: the recovery query below may
    # implicitly open a transaction which we should not leave pending.
    status = self.info.transaction_status
    async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
        await cur.execute(Xid._get_recover_query())
        res = await cur.fetchall()
    if status == IDLE and self.info.transaction_status == INTRANS:
        await self.rollback()
    return res

View File

@@ -0,0 +1,154 @@
"""
Functions to manipulate conninfo strings
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import re
from . import pq
from . import errors as e
from . import _conninfo_utils
from . import _conninfo_attempts
from . import _conninfo_attempts_async
from .abc import ConnParam, ConnDict
# re-exports
conninfo_attempts = _conninfo_attempts.conninfo_attempts
conninfo_attempts_async = _conninfo_attempts_async.conninfo_attempts_async

# Default timeout for a connection attempt, in seconds.
# Arbitrary value: it is the timeout the libpq applied on my computer.
# Your mileage won't vary.
_DEFAULT_CONNECT_TIMEOUT = 130
def make_conninfo(conninfo: str = "", **kwargs: ConnParam) -> str:
    """
    Merge a string and keyword params into a single conninfo string.

    :param conninfo: A `connection string`__ as accepted by PostgreSQL.
    :param kwargs: Parameters overriding the ones specified in `!conninfo`.
    :return: A connection string valid for PostgreSQL, with the `!kwargs`
        parameters merged.

    Raise `~psycopg.ProgrammingError` if the input doesn't make a valid
    conninfo string.

    .. __: https://www.postgresql.org/docs/current/libpq-connect.html
           #LIBPQ-CONNSTRING
    """
    if not conninfo and not kwargs:
        return ""

    # Nothing to merge: just validate the string. Return a plain str (not
    # a possible subtype) to keep Liskov happy.
    if not kwargs:
        _parse_conninfo(conninfo)
        return str(conninfo)

    # None-valued keywords count as not passed at all.
    params = {key: val for key, val in kwargs.items() if val is not None}
    if conninfo:
        merged = conninfo_to_dict(conninfo)
        merged.update(params)
        params = merged

    out = " ".join(f"{key}={_param_escape(str(val))}" for key, val in params.items())

    # Make sure the string we built is valid before handing it back.
    _parse_conninfo(out)
    return out
def conninfo_to_dict(conninfo: str = "", **kwargs: ConnParam) -> ConnDict:
    """
    Convert the `!conninfo` string into a dictionary of parameters.

    :param conninfo: A `connection string`__ as accepted by PostgreSQL.
    :param kwargs: Parameters overriding the ones specified in `!conninfo`.
    :return: Dictionary with the parameters parsed from `!conninfo` and
        `!kwargs`.

    Raise `~psycopg.ProgrammingError` if `!conninfo` is not a valid connection
    string.

    .. __: https://www.postgresql.org/docs/current/libpq-connect.html
           #LIBPQ-CONNSTRING
    """
    rv: ConnDict = {}
    # Keep only the options actually set in the string.
    for opt in _parse_conninfo(conninfo):
        if opt.val is not None:
            rv[opt.keyword.decode()] = opt.val.decode()
    # Keyword arguments take precedence; None means "not specified".
    for key, val in kwargs.items():
        if val is not None:
            rv[key] = val
    return rv
def _parse_conninfo(conninfo: str) -> list[pq.ConninfoOption]:
    """
    Verify that `!conninfo` is a valid connection string.

    Raise ProgrammingError if the string is not valid.

    Return the result of pq.Conninfo.parse() on success.
    """
    try:
        return pq.Conninfo.parse(conninfo.encode())
    except e.OperationalError as ex:
        # Normalize the libpq parsing failure to a programming error:
        # a malformed string is a caller bug, not a server/network issue.
        raise e.ProgrammingError(str(ex)) from None
re_escape = re.compile(r"([\\'])")
re_space = re.compile(r"\s")
def _param_escape(s: str) -> str:
"""
Apply the escaping rule required by PQconnectdb
"""
if not s:
return "''"
s = re_escape.sub(r"\\\1", s)
if re_space.search(s):
s = "'" + s + "'"
return s
def timeout_from_conninfo(params: ConnDict) -> int:
    """
    Return the timeout in seconds from the connection parameters.
    """
    # Follow the libpq convention (see connectDBComplete in fe-connect.c):
    #
    # - 0 or less means no timeout, but we substitute a finite default
    #   because the async connection path must enforce the timeout itself;
    # - the minimum accepted value is 2 seconds.
    value: str | int | None = _conninfo_utils.get_param(params, "connect_timeout")
    if value is None:
        value = _DEFAULT_CONNECT_TIMEOUT

    try:
        timeout = int(float(value))
    except ValueError:
        raise e.ProgrammingError(f"bad value for connect_timeout: {value!r}") from None

    if timeout <= 0:
        return _DEFAULT_CONNECT_TIMEOUT
    return max(timeout, 2)

View File

@@ -0,0 +1,35 @@
"""
Module gathering the various parts of the copy subsystem.
"""
from typing import IO
from .abc import Buffer
from . import _copy, _copy_async
# re-exports
AsyncCopy = _copy_async.AsyncCopy
AsyncWriter = _copy_async.AsyncWriter
AsyncLibpqWriter = _copy_async.AsyncLibpqWriter
AsyncQueuedLibpqWriter = _copy_async.AsyncQueuedLibpqWriter
Copy = _copy.Copy
Writer = _copy.Writer
LibpqWriter = _copy.LibpqWriter
QueuedLibpqWriter = _copy.QueuedLibpqWriter
class FileWriter(Writer):
    """
    A `Writer` to write copy data to a file-like object.

    :param file: the file where to write copy data. It must be open for writing
        in binary mode.
    """

    def __init__(self, file: IO[bytes]):
        self.file = file

    def write(self, data: Buffer) -> None:
        # Delegate straight to the underlying binary file object.
        self.file.write(data)

View File

@@ -0,0 +1,20 @@
"""
CockroachDB support package.
"""
# Copyright (C) 2022 The Psycopg Team
from . import _types
from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
# Package-level adapters map, pre-configured for CockroachDB.
adapters = _types.adapters # exposed by the package
# Shortcut mirroring psycopg.connect() for CRDB connections.
connect = CrdbConnection.connect

# Populate the map with CRDB types and adapters at import time.
_types.register_crdb_types(adapters.types)
_types.register_crdb_adapters(adapters)

__all__ = [
    "AsyncCrdbConnection",
    "CrdbConnection",
    "CrdbConnectionInfo",
]

View File

@@ -0,0 +1,201 @@
"""
Types configuration specific for CockroachDB.
"""
# Copyright (C) 2022 The Psycopg Team
from enum import Enum
from .._typeinfo import TypeInfo, TypesRegistry
from ..abc import AdaptContext, NoneType
from .._oids import TEXT_OID
from .._typemod import BitTypeModifier, CharTypeModifier, NumericTypeModifier
from .._typemod import TimeTypeModifier
from .._adapters_map import AdaptersMap
from ..types.enum import EnumDumper, EnumBinaryDumper
from ..types.none import NoneDumper
# Registry of the types known to CockroachDB (populated by register_crdb_types).
types = TypesRegistry()

# Global adapter maps with PostgreSQL types configuration
adapters = AdaptersMap(types=types)


class CrdbEnumDumper(EnumDumper):
    # Use the text oid for dumped enum values.
    oid = TEXT_OID


class CrdbEnumBinaryDumper(EnumBinaryDumper):
    # Binary counterpart of CrdbEnumDumper: also uses the text oid.
    oid = TEXT_OID


class CrdbNoneDumper(NoneDumper):
    # Dump None with the text oid.
    oid = TEXT_OID
def register_crdb_adapters(context: AdaptContext) -> None:
    """Register on `!context` the adapters used by CockroachDB (the
    PostgreSQL ones plus the CRDB-specific overrides below)."""
    from .. import dbapi20
    from ..types import array

    _register_postgres_adapters(context)

    # String must come after enum and none to map text oid -> string dumper
    _register_crdb_none_adapters(context)
    _register_crdb_enum_adapters(context)
    _register_crdb_string_adapters(context)
    _register_crdb_json_adapters(context)
    _register_crdb_net_adapters(context)

    # NOTE(review): these two register on the module-level `adapters` map
    # rather than on `context` — confirm that's intentional.
    dbapi20.register_dbapi20_adapters(adapters)
    array.register_all_arrays(adapters)
def _register_postgres_adapters(context: AdaptContext) -> None:
    # Same adapters used by PostgreSQL, or a good starting point for customization
    from ..types import array, bool, composite, datetime
    from ..types import numeric, numpy, string, uuid

    array.register_default_adapters(context)
    composite.register_default_adapters(context)
    datetime.register_default_adapters(context)
    string.register_default_adapters(context)
    uuid.register_default_adapters(context)

    # Both numpy Decimal and uint64 dumpers use the numeric oid, but the former
    # covers the entire numeric domain, whereas the latter only deals with
    # integers. For this reason, if we specify dumpers by oid, we want to make
    # sure to get the Decimal dumper. We enforce that by registering the
    # numeric dumpers last.
    numpy.register_default_adapters(context)
    bool.register_default_adapters(context)
    numeric.register_default_adapters(context)
def _register_crdb_string_adapters(context: AdaptContext) -> None:
    from ..types import string

    # Dump strings with text oid instead of unknown.
    # Unlike PostgreSQL, CRDB seems able to cast text to most types.
    context.adapters.register_dumper(str, string.StrDumper)
    context.adapters.register_dumper(str, string.StrBinaryDumper)
def _register_crdb_enum_adapters(context: AdaptContext) -> None:
    # Enums are dumped as text on CRDB (see CrdbEnumDumper).
    # NOTE(review): the last registration appears to win as the default
    # format for the type — confirm against AdaptersMap semantics.
    context.adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
    context.adapters.register_dumper(Enum, CrdbEnumDumper)
def _register_crdb_json_adapters(context: AdaptContext) -> None:
    from ..types import json

    adapters = context.adapters

    # CRDB doesn't have json/jsonb: both names map to the jsonb oid,
    # so Json and Jsonb wrappers share the jsonb dumpers, and both type
    # names get text and binary loaders.
    adapters.register_dumper(json.Json, json.JsonbBinaryDumper)
    adapters.register_dumper(json.Json, json.JsonbDumper)
    adapters.register_dumper(json.Jsonb, json.JsonbBinaryDumper)
    adapters.register_dumper(json.Jsonb, json.JsonbDumper)

    adapters.register_loader("json", json.JsonLoader)
    adapters.register_loader("jsonb", json.JsonbLoader)
    adapters.register_loader("json", json.JsonBinaryLoader)
    adapters.register_loader("jsonb", json.JsonbBinaryLoader)
def _register_crdb_net_adapters(context: AdaptContext) -> None:
    from ..types import net

    adapters = context.adapters

    # Text dumpers for the ipaddress objects.
    adapters.register_dumper("ipaddress.IPv4Address", net.InterfaceDumper)
    adapters.register_dumper("ipaddress.IPv6Address", net.InterfaceDumper)
    adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceDumper)
    adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceDumper)
    # Binary dumpers: addresses and interfaces use different wire formats.
    adapters.register_dumper("ipaddress.IPv4Address", net.AddressBinaryDumper)
    adapters.register_dumper("ipaddress.IPv6Address", net.AddressBinaryDumper)
    adapters.register_dumper("ipaddress.IPv4Interface", net.InterfaceBinaryDumper)
    adapters.register_dumper("ipaddress.IPv6Interface", net.InterfaceBinaryDumper)
    adapters.register_dumper(None, net.InetBinaryDumper)
    adapters.register_loader("inet", net.InetLoader)
    adapters.register_loader("inet", net.InetBinaryLoader)
def _register_crdb_none_adapters(context: AdaptContext) -> None:
    # Dump None with the text oid (see CrdbNoneDumper).
    context.adapters.register_dumper(NoneType, CrdbNoneDumper)
def register_crdb_types(types: TypesRegistry) -> None:
    """Populate `!types` with the types and OIDs used by CockroachDB."""
    for t in [
        TypeInfo("json", 3802, 3807, regtype="jsonb"), # Alias json -> jsonb.
        TypeInfo("int8", 20, 1016, regtype="integer"), # Alias integer -> int8
        TypeInfo('"char"', 18, 1002), # special case, not generated
        # autogenerated: start
        # Generated from CockroachDB 23.1.10
        TypeInfo("bit", 1560, 1561, typemod=BitTypeModifier),
        TypeInfo("bool", 16, 1000, regtype="boolean"),
        TypeInfo("bpchar", 1042, 1014, regtype="character", typemod=CharTypeModifier),
        TypeInfo("bytea", 17, 1001),
        TypeInfo("date", 1082, 1182),
        TypeInfo("float4", 700, 1021, regtype="real"),
        TypeInfo("float8", 701, 1022, regtype="double precision"),
        TypeInfo("inet", 869, 1041),
        TypeInfo("int2", 21, 1005, regtype="smallint"),
        TypeInfo("int2vector", 22, 1006),
        TypeInfo("int4", 23, 1007),
        TypeInfo("int8", 20, 1016, regtype="bigint"),
        TypeInfo("interval", 1186, 1187, typemod=TimeTypeModifier),
        TypeInfo("jsonb", 3802, 3807),
        TypeInfo("name", 19, 1003),
        TypeInfo("numeric", 1700, 1231, typemod=NumericTypeModifier),
        TypeInfo("oid", 26, 1028),
        TypeInfo("oidvector", 30, 1013),
        TypeInfo("record", 2249, 2287),
        TypeInfo("regclass", 2205, 2210),
        TypeInfo("regnamespace", 4089, 4090),
        TypeInfo("regproc", 24, 1008),
        TypeInfo("regprocedure", 2202, 2207),
        TypeInfo("regrole", 4096, 4097),
        TypeInfo("regtype", 2206, 2211),
        TypeInfo("text", 25, 1009),
        TypeInfo(
            "time",
            1083,
            1183,
            regtype="time without time zone",
            typemod=TimeTypeModifier,
        ),
        TypeInfo(
            "timestamp",
            1114,
            1115,
            regtype="timestamp without time zone",
            typemod=TimeTypeModifier,
        ),
        TypeInfo(
            "timestamptz",
            1184,
            1185,
            regtype="timestamp with time zone",
            typemod=TimeTypeModifier,
        ),
        TypeInfo(
            "timetz",
            1266,
            1270,
            regtype="time with time zone",
            typemod=TimeTypeModifier,
        ),
        TypeInfo("tsquery", 3615, 3645),
        TypeInfo("tsvector", 3614, 3643),
        TypeInfo("unknown", 705, 0),
        TypeInfo("uuid", 2950, 2951),
        TypeInfo("varbit", 1562, 1563, regtype="bit varying", typemod=BitTypeModifier),
        TypeInfo(
            "varchar", 1043, 1015, regtype="character varying", typemod=CharTypeModifier
        ),
        # autogenerated: end
    ]:
        types.add(t)

View File

@@ -0,0 +1,105 @@
"""
CockroachDB-specific connections.
"""
# Copyright (C) 2022 The Psycopg Team
from __future__ import annotations
import re
from typing import Any, TYPE_CHECKING
from .. import errors as e
from ..rows import Row
from ..connection import Connection
from .._adapters_map import AdaptersMap
from .._connection_info import ConnectionInfo
from ..connection_async import AsyncConnection
from ._types import adapters
if TYPE_CHECKING:
from ..pq.abc import PGconn
class _CrdbConnectionMixin:
    """Behaviour shared by the sync and async CockroachDB connection classes."""

    _adapters: AdaptersMap | None
    pgconn: PGconn

    @classmethod
    def is_crdb(cls, conn: Connection[Any] | AsyncConnection[Any] | PGconn) -> bool:
        """
        Return `!True` if the server connected to `!conn` is CockroachDB.
        """
        if isinstance(conn, (Connection, AsyncConnection)):
            conn = conn.pgconn
        # CockroachDB advertises itself via the 'crdb_version' parameter status.
        return bool(conn.parameter_status(b"crdb_version"))

    @property
    def adapters(self) -> AdaptersMap:
        if not self._adapters:
            # By default, use CockroachDB adapters map
            self._adapters = AdaptersMap(adapters)
        return self._adapters

    @property
    def info(self) -> CrdbConnectionInfo:
        """A `CrdbConnectionInfo` wrapper around this connection's pgconn."""
        return CrdbConnectionInfo(self.pgconn)

    def _check_tpc(self) -> None:
        # This hook guards the tpc_*() (two-phase commit) methods.
        # Fix: the previous error message claimed CockroachDB "doesn't
        # support prepared statements", which names the wrong feature: the
        # missing capability is PREPARE TRANSACTION (two-phase commit).
        if self.is_crdb(self.pgconn):
            raise e.NotSupportedError(
                "CockroachDB doesn't support two-phase commit (PREPARE TRANSACTION)"
            )
class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
    """
    Wrapper for a connection to a CockroachDB database.
    """

    # Present the class as part of the public psycopg.crdb API.
    __module__ = "psycopg.crdb"


class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
    """
    Wrapper for an async connection to a CockroachDB database.
    """

    __module__ = "psycopg.crdb"
class CrdbConnectionInfo(ConnectionInfo):
    """
    `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
    """

    __module__ = "psycopg.crdb"

    @property
    def vendor(self) -> str:
        """The database vendor name: always "CockroachDB" here."""
        return "CockroachDB"

    @property
    def server_version(self) -> int:
        """
        Return the CockroachDB server version connected.

        Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).

        Raise `~psycopg.InternalError` if the 'crdb_version' parameter
        status is missing, `~psycopg.InterfaceError` if it can't be parsed.
        """
        sver = self.parameter_status("crdb_version")
        if not sver:
            raise e.InternalError("'crdb_version' parameter status not set")
        ver = self.parse_crdb_version(sver)
        if ver is None:
            raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
        return ver

    @classmethod
    def parse_crdb_version(cls, sver: str) -> int | None:
        # Fix: the first parameter of a classmethod was misleadingly named
        # 'self'; renamed to the conventional 'cls' (callers unaffected).
        # Extract "vX.Y.Z" and pack it in the PostgreSQL integer format.
        m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
        if not m:
            return None
        return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))

View File

@@ -0,0 +1,288 @@
# WARNING: this file is auto-generated by 'async_to_sync.py'
# from the original file 'cursor_async.py'
# DO NOT CHANGE! Change the original file instead.
"""
Psycopg Cursor object.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from types import TracebackType
from typing import Any, Iterator, Iterable, TYPE_CHECKING, overload
from contextlib import contextmanager
from . import pq
from . import errors as e
from .abc import Query, Params
from .copy import Copy, Writer
from .rows import Row, RowMaker, RowFactory
from ._compat import Self
from ._pipeline import Pipeline
from ._cursor_base import BaseCursor
if TYPE_CHECKING:
from .connection import Connection
ACTIVE = pq.TransactionStatus.ACTIVE
class Cursor(BaseCursor["Connection[Any]", Row]):
__module__ = "psycopg"
__slots__ = ()
@overload
def __init__(self, connection: Connection[Row]): ...

@overload
def __init__(
    self, connection: Connection[Any], *, row_factory: RowFactory[Row]
): ...

def __init__(
    self, connection: Connection[Any], *, row_factory: RowFactory[Row] | None = None
):
    """Create the cursor on `!connection`; rows are built with
    `!row_factory` or, by default, with the connection's one."""
    super().__init__(connection)
    self._row_factory = row_factory or connection.row_factory

def __enter__(self) -> Self:
    """Context manager entry: return the cursor itself."""
    return self

def __exit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> None:
    # Always close on context exit, with or without an exception.
    self.close()

def close(self) -> None:
    """
    Close the current cursor and free associated resources.
    """
    self._close()

@property
def row_factory(self) -> RowFactory[Row]:
    """Writable attribute to control how result rows are formed."""
    return self._row_factory

@row_factory.setter
def row_factory(self, row_factory: RowFactory[Row]) -> None:
    self._row_factory = row_factory
    # If a result is already available, rebuild the row maker immediately
    # so subsequent fetches use the new factory.
    if self.pgresult:
        self._make_row = row_factory(self)

def _make_row_maker(self) -> RowMaker[Row]:
    # Hook used by the base class to obtain the row maker for a new result.
    return self._row_factory(self)
def execute(
    self,
    query: Query,
    params: Params | None = None,
    *,
    prepare: bool | None = None,
    binary: bool | None = None,
) -> Self:
    """
    Execute a query or command to the database.

    Return the cursor itself, so calls can be chained with fetches.
    """
    try:
        # The cursor shares its connection's socket: take the lock.
        with self._conn.lock:
            self._conn.wait(
                self._execute_gen(query, params, prepare=prepare, binary=binary)
            )
    except e._NO_TRACEBACK as ex:
        # Hide psycopg internals from the user-visible traceback.
        raise ex.with_traceback(None)
    return self

def executemany(
    self, query: Query, params_seq: Iterable[Params], *, returning: bool = False
) -> None:
    """
    Execute the same command with a sequence of input data.
    """
    try:
        if Pipeline.is_supported():
            # If there is already a pipeline, ride it, in order to avoid
            # sending unnecessary Sync.
            with self._conn.lock:
                p = self._conn._pipeline
                if p:
                    self._conn.wait(
                        self._executemany_gen_pipeline(query, params_seq, returning)
                    )
            # Otherwise, make a new one
            if not p:
                with self._conn.pipeline(), self._conn.lock:
                    self._conn.wait(
                        self._executemany_gen_pipeline(query, params_seq, returning)
                    )
        else:
            # No pipeline support in the libpq: fall back on one
            # execution per parameters set.
            with self._conn.lock:
                self._conn.wait(
                    self._executemany_gen_no_pipeline(query, params_seq, returning)
                )
    except e._NO_TRACEBACK as ex:
        raise ex.with_traceback(None)
def stream(
    self,
    query: Query,
    params: Params | None = None,
    *,
    binary: bool | None = None,
    size: int = 1,
) -> Iterator[Row]:
    """
    Iterate row-by-row on a result from the database.

    :param size: if greater than 1, results will be retrieved by chunks of
        this size from the server (but still yielded row-by-row); this is only
        available from version 17 of the libpq.
    """
    if self._pgconn.pipeline_status:
        raise e.ProgrammingError("stream() cannot be used in pipeline mode")
    with self._conn.lock:
        try:
            self._conn.wait(
                self._stream_send_gen(query, params, binary=binary, size=size)
            )
            first = True
            while self._conn.wait(self._stream_fetchone_gen(first)):
                # Each fetch may carry up to `size` rows: yield them singly.
                for pos in range(size):
                    rec = self._tx.load_row(pos, self._make_row)
                    if rec is None:
                        break
                    yield rec
                first = False
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)
        finally:
            # The generator can exit early (break, exception in the caller):
            # make sure the connection doesn't stay stuck mid-stream.
            if self._pgconn.transaction_status == ACTIVE:
                # Try to cancel the query, then consume the results
                # already received.
                self._conn._try_cancel()
                try:
                    while self._conn.wait(self._stream_fetchone_gen(first=False)):
                        pass
                except Exception:
                    pass
                # Try to get out of ACTIVE state. Just do a single attempt, which
                # should work to recover from an error or query cancelled.
                try:
                    self._conn.wait(self._stream_fetchone_gen(first=False))
                except Exception:
                    pass
def fetchone(self) -> Row | None:
    """
    Return the next record from the current recordset.

    Return `!None` if the recordset is finished.

    :rtype: Row | None, with Row defined by `row_factory`
    """
    self._fetch_pipeline()
    self._check_result_for_fetch()
    record = self._tx.load_row(self._pos, self._make_row)
    # Only advance the position if a row was actually loaded.
    if record is not None:
        self._pos += 1
    return record

def fetchmany(self, size: int = 0) -> list[Row]:
    """
    Return the next `!size` records from the current recordset.

    `!size` defaults to `!self.arraysize` if not specified.

    :rtype: Sequence[Row], with Row defined by `row_factory`
    """
    self._fetch_pipeline()
    self._check_result_for_fetch()
    assert self.pgresult
    if not size:
        size = self.arraysize
    # Clamp the upper bound so we never read past the last row.
    records = self._tx.load_rows(
        self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
    )
    self._pos += len(records)
    return records

def fetchall(self) -> list[Row]:
    """
    Return all the remaining records from the current recordset.

    :rtype: Sequence[Row], with Row defined by `row_factory`
    """
    self._fetch_pipeline()
    self._check_result_for_fetch()
    assert self.pgresult
    records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
    self._pos = self.pgresult.ntuples
    return records
def __iter__(self) -> Iterator[Row]:
self._fetch_pipeline()
self._check_result_for_fetch()
def load(pos: int) -> Row | None:
return self._tx.load_row(pos, self._make_row)
while True:
row = load(self._pos)
if row is None:
break
self._pos += 1
yield row
    def scroll(self, value: int, mode: str = "relative") -> None:
        """
        Move the cursor in the result set to a new position according to mode.

        If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
        the current position in the result set; if set to ``'absolute'``,
        `!value` states an absolute target position.

        Raise `!IndexError` in case a scroll operation would leave the result
        set. In this case the position will not change.
        """
        # Make sure pending pipeline results are available before moving.
        self._fetch_pipeline()
        self._scroll(value, mode)
    @contextmanager
    def copy(
        self,
        statement: Query,
        params: Params | None = None,
        *,
        writer: Writer | None = None,
    ) -> Iterator[Copy]:
        """
        Initiate a :sql:`COPY` operation and return an object to manage it.

        :param statement: the :sql:`COPY` statement to execute.
        :param params: optional parameters for the statement.
        :param writer: optional writer object passed to the `Copy` object.
        """
        try:
            with self._conn.lock:
                self._conn.wait(self._start_copy_gen(statement, params))
            # Leaving the Copy block finishes (or aborts) the operation.
            with Copy(self, writer=writer) as copy:
                yield copy
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)
        # If a fresher result has been set on the cursor by the Copy object,
        # read its properties (especially rowcount).
        self._select_current_result(0)
def _fetch_pipeline(self) -> None:
if (
self._execmany_returning is not False
and (not self.pgresult)
and self._conn._pipeline
):
with self._conn.lock:
self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))

View File

@@ -0,0 +1,298 @@
"""
Psycopg AsyncCursor object.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
from types import TracebackType
from typing import Any, AsyncIterator, Iterable, TYPE_CHECKING, overload
from contextlib import asynccontextmanager
from . import pq
from . import errors as e
from .abc import Query, Params
from .copy import AsyncCopy, AsyncWriter
from .rows import Row, RowMaker, AsyncRowFactory
from ._compat import Self
from ._pipeline import Pipeline
from ._cursor_base import BaseCursor
if TYPE_CHECKING:
from .connection_async import AsyncConnection
ACTIVE = pq.TransactionStatus.ACTIVE
class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
    """
    Asynchronous cursor to execute queries and fetch results on an
    `AsyncConnection`.
    """

    __module__ = "psycopg"
    __slots__ = ()

    @overload
    def __init__(self, connection: AsyncConnection[Row]): ...

    @overload
    def __init__(
        self, connection: AsyncConnection[Any], *, row_factory: AsyncRowFactory[Row]
    ): ...

    def __init__(
        self,
        connection: AsyncConnection[Any],
        *,
        row_factory: AsyncRowFactory[Row] | None = None,
    ):
        super().__init__(connection)
        # If no row factory is specified, inherit the connection's one.
        self._row_factory = row_factory or connection.row_factory

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.close()

    async def close(self) -> None:
        """
        Close the current cursor and free associated resources.
        """
        self._close()

    @property
    def row_factory(self) -> AsyncRowFactory[Row]:
        """Writable attribute to control how result rows are formed."""
        return self._row_factory

    @row_factory.setter
    def row_factory(self, row_factory: AsyncRowFactory[Row]) -> None:
        self._row_factory = row_factory
        # If a result is already set, refresh the row maker immediately.
        if self.pgresult:
            self._make_row = row_factory(self)

    def _make_row_maker(self) -> RowMaker[Row]:
        return self._row_factory(self)

    async def execute(
        self,
        query: Query,
        params: Params | None = None,
        *,
        prepare: bool | None = None,
        binary: bool | None = None,
    ) -> Self:
        """
        Execute a query or command to the database.
        """
        try:
            async with self._conn.lock:
                await self._conn.wait(
                    self._execute_gen(query, params, prepare=prepare, binary=binary)
                )
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)
        return self

    async def executemany(
        self,
        query: Query,
        params_seq: Iterable[Params],
        *,
        returning: bool = False,
    ) -> None:
        """
        Execute the same command with a sequence of input data.
        """
        try:
            if Pipeline.is_supported():
                # If there is already a pipeline, ride it, in order to avoid
                # sending unnecessary Sync.
                async with self._conn.lock:
                    p = self._conn._pipeline
                    if p:
                        await self._conn.wait(
                            self._executemany_gen_pipeline(query, params_seq, returning)
                        )
                # Otherwise, make a new one
                if not p:
                    async with self._conn.pipeline(), self._conn.lock:
                        await self._conn.wait(
                            self._executemany_gen_pipeline(query, params_seq, returning)
                        )
            else:
                # Pipeline mode not available: send each query separately.
                async with self._conn.lock:
                    await self._conn.wait(
                        self._executemany_gen_no_pipeline(query, params_seq, returning)
                    )
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

    async def stream(
        self,
        query: Query,
        params: Params | None = None,
        *,
        binary: bool | None = None,
        size: int = 1,
    ) -> AsyncIterator[Row]:
        """
        Iterate row-by-row on a result from the database.

        :param size: if greater than 1, results will be retrieved by chunks of
        this size from the server (but still yielded row-by-row); this is only
        available from version 17 of the libpq.
        """
        if self._pgconn.pipeline_status:
            raise e.ProgrammingError("stream() cannot be used in pipeline mode")
        async with self._conn.lock:
            try:
                await self._conn.wait(
                    self._stream_send_gen(query, params, binary=binary, size=size)
                )
                first = True
                while await self._conn.wait(self._stream_fetchone_gen(first)):
                    # Yield the rows of the current chunk one by one.
                    for pos in range(size):
                        rec = self._tx.load_row(pos, self._make_row)
                        if rec is None:
                            break
                        yield rec
                    first = False
            except e._NO_TRACEBACK as ex:
                raise ex.with_traceback(None)
            finally:
                # The iteration may have been interrupted by the caller
                # (e.g. breaking out of the loop): clean up the connection.
                if self._pgconn.transaction_status == ACTIVE:
                    # Try to cancel the query, then consume the results
                    # already received.
                    await self._conn._try_cancel()
                    try:
                        while await self._conn.wait(
                            self._stream_fetchone_gen(first=False)
                        ):
                            pass
                    except Exception:
                        pass
                    # Try to get out of ACTIVE state. Just do a single attempt, which
                    # should work to recover from an error or query cancelled.
                    try:
                        await self._conn.wait(self._stream_fetchone_gen(first=False))
                    except Exception:
                        pass

    async def fetchone(self) -> Row | None:
        """
        Return the next record from the current recordset.

        Return `!None` when the recordset is finished.

        :rtype: Row | None, with Row defined by `row_factory`
        """
        await self._fetch_pipeline()
        self._check_result_for_fetch()
        record = self._tx.load_row(self._pos, self._make_row)
        # Advance the position only when a row was actually loaded.
        if record is not None:
            self._pos += 1
        return record

    async def fetchmany(self, size: int = 0) -> list[Row]:
        """
        Return the next `!size` records from the current recordset.

        `!size` default to `!self.arraysize` if not specified.

        :rtype: Sequence[Row], with Row defined by `row_factory`
        """
        await self._fetch_pipeline()
        self._check_result_for_fetch()
        assert self.pgresult
        if not size:
            size = self.arraysize
        records = self._tx.load_rows(
            self._pos,
            min(self._pos + size, self.pgresult.ntuples),
            self._make_row,
        )
        self._pos += len(records)
        return records

    async def fetchall(self) -> list[Row]:
        """
        Return all the remaining records from the current recordset.

        :rtype: Sequence[Row], with Row defined by `row_factory`
        """
        await self._fetch_pipeline()
        self._check_result_for_fetch()
        assert self.pgresult
        records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
        # The whole recordset has been consumed.
        self._pos = self.pgresult.ntuples
        return records

    async def __aiter__(self) -> AsyncIterator[Row]:
        await self._fetch_pipeline()
        self._check_result_for_fetch()

        def load(pos: int) -> Row | None:
            return self._tx.load_row(pos, self._make_row)

        while True:
            row = load(self._pos)
            if row is None:
                break
            self._pos += 1
            yield row

    async def scroll(self, value: int, mode: str = "relative") -> None:
        """
        Move the cursor in the result set to a new position according to mode.

        If `!mode` is ``'relative'`` (default), `!value` is taken as offset to
        the current position in the result set; if set to ``'absolute'``,
        `!value` states an absolute target position.

        Raise `!IndexError` in case a scroll operation would leave the result
        set. In this case the position will not change.
        """
        await self._fetch_pipeline()
        self._scroll(value, mode)

    @asynccontextmanager
    async def copy(
        self,
        statement: Query,
        params: Params | None = None,
        *,
        writer: AsyncWriter | None = None,
    ) -> AsyncIterator[AsyncCopy]:
        """
        Initiate a :sql:`COPY` operation and return an object to manage it.
        """
        try:
            async with self._conn.lock:
                await self._conn.wait(self._start_copy_gen(statement, params))
            # Leaving the AsyncCopy block finishes (or aborts) the operation.
            async with AsyncCopy(self, writer=writer) as copy:
                yield copy
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)
        # If a fresher result has been set on the cursor by the Copy object,
        # read its properties (especially rowcount).
        self._select_current_result(0)

    async def _fetch_pipeline(self) -> None:
        # Fetch the results of commands previously sent in pipeline mode,
        # if there is a pipeline and no result is available yet.
        if (
            self._execmany_returning is not False
            and not self.pgresult
            and self._conn._pipeline
        ):
            async with self._conn.lock:
                await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))

View File

@@ -0,0 +1,134 @@
"""
Compatibility objects with DBAPI 2.0
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import time
import datetime as dt
from math import floor
from typing import Any, Sequence
from . import _oids
from .abc import AdaptContext, Buffer
from .types.string import BytesDumper, BytesBinaryDumper
class DBAPITypeObject:
    """
    Named group of type OIDs, comparing equal to any of its member OIDs.

    Implements the DBAPI 2.0 type objects (`BINARY`, `NUMBER`, ...) which can
    be compared to a column type code.
    """

    def __init__(self, name: str, oids: Sequence[int]):
        self.name = name
        self.values = tuple(oids)

    def __repr__(self) -> str:
        return f"psycopg.{self.name}"

    def __eq__(self, other: Any) -> bool:
        # Only compare with plain OIDs; defer to the other operand otherwise.
        return other in self.values if isinstance(other, int) else NotImplemented

    def __ne__(self, other: Any) -> bool:
        return other not in self.values if isinstance(other, int) else NotImplemented
# DBAPI 2.0 type objects: each compares equal to the OIDs of the PostgreSQL
# types it groups, so they can be matched against `description` type codes.
BINARY = DBAPITypeObject("BINARY", (_oids.BYTEA_OID,))
DATETIME = DBAPITypeObject(
    "DATETIME",
    (
        _oids.TIMESTAMP_OID,
        _oids.TIMESTAMPTZ_OID,
        _oids.DATE_OID,
        _oids.TIME_OID,
        _oids.TIMETZ_OID,
        _oids.INTERVAL_OID,
    ),
)
NUMBER = DBAPITypeObject(
    "NUMBER",
    (
        _oids.INT2_OID,
        _oids.INT4_OID,
        _oids.INT8_OID,
        _oids.FLOAT4_OID,
        _oids.FLOAT8_OID,
        _oids.NUMERIC_OID,
    ),
)
ROWID = DBAPITypeObject("ROWID", (_oids.OID_OID,))
STRING = DBAPITypeObject(
    "STRING", (_oids.TEXT_OID, _oids.VARCHAR_OID, _oids.BPCHAR_OID)
)
class Binary:
    """
    DBAPI 2.0 wrapper marking an object to be passed to the database as binary.
    """

    def __init__(self, obj: Any):
        self.obj = obj

    def __repr__(self) -> str:
        sobj = repr(self.obj)
        if len(sobj) > 40:
            # Truncate a long repr, reporting the length of the full repr.
            # (Fixed the garbled "byteschars" unit: len(sobj) counts the
            # characters of the repr string, not bytes.)
            sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
        return f"{self.__class__.__name__}({sobj})"
class BinaryBinaryDumper(BytesBinaryDumper):
    """Binary-format dumper unwrapping `Binary` objects before dumping."""

    def dump(self, obj: Buffer | Binary) -> Buffer | None:
        # Unwrap the Binary envelope, if present, and dump its payload.
        return super().dump(obj.obj if isinstance(obj, Binary) else obj)
class BinaryTextDumper(BytesDumper):
    """Text-format dumper unwrapping `Binary` objects before dumping."""

    def dump(self, obj: Buffer | Binary) -> Buffer | None:
        # Unwrap the Binary envelope, if present, and dump its payload.
        return super().dump(obj.obj if isinstance(obj, Binary) else obj)
def Date(year: int, month: int, day: int) -> dt.date:
    """Create a `~datetime.date` object from its components (DBAPI 2.0)."""
    return dt.date(year=year, month=month, day=day)
def DateFromTicks(ticks: float) -> dt.date:
    """Create a `~datetime.date` from a POSIX timestamp (DBAPI 2.0)."""
    # Reuse the local-time timestamp conversion and drop the time part.
    return TimestampFromTicks(ticks).date()
def Time(hour: int, minute: int, second: int) -> dt.time:
    """Create a `~datetime.time` object from its components (DBAPI 2.0)."""
    return dt.time(hour=hour, minute=minute, second=second)
def TimeFromTicks(ticks: float) -> dt.time:
    """Create a `~datetime.time` from a POSIX timestamp (DBAPI 2.0)."""
    # Reuse the local-time timestamp conversion and drop the date part.
    return TimestampFromTicks(ticks).time()
def Timestamp(
    year: int, month: int, day: int, hour: int, minute: int, second: int
) -> dt.datetime:
    """Create a naive `~datetime.datetime` from its components (DBAPI 2.0)."""
    return dt.datetime(
        year=year, month=month, day=day, hour=hour, minute=minute, second=second
    )
def TimestampFromTicks(ticks: float) -> dt.datetime:
    """
    Create an aware `~datetime.datetime`, in the local timezone, from a POSIX
    timestamp (DBAPI 2.0).
    """
    secs = floor(ticks)
    frac = ticks - secs
    us = round(frac * 1_000_000)
    if us >= 1_000_000:
        # The fraction rounded up to a whole second (e.g. ticks=0.9999999):
        # carry it into the seconds, otherwise datetime() would reject
        # microsecond=1000000 with a ValueError.
        secs += 1
        us = 0
    # localtime() ignores the fractional part, so using the (possibly
    # carried) integer seconds is equivalent to the original float.
    t = time.localtime(secs)
    tzinfo = dt.timezone(dt.timedelta(seconds=t.tm_gmtoff))
    return dt.datetime(*t[:6], us, tzinfo=tzinfo)
def register_dbapi20_adapters(context: AdaptContext) -> None:
    """Register dumpers handling `Binary` wrappers on *context*."""
    adapters = context.adapters
    # Register for the Binary class first, then make the same dumpers the
    # default ones when dumping by bytea oid (preserving registration order).
    for cls in (Binary, None):
        adapters.register_dumper(cls, BinaryTextDumper)
        adapters.register_dumper(cls, BinaryBinaryDumper)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,407 @@
"""
Generators implementing communication protocols with the libpq
Certain operations (connection, querying) are an interleave of libpq calls and
waiting for the socket to be ready. This module contains the code to execute
the operations, yielding a polling state whenever there is to wait. The
functions in the `waiting` module are the ones who wait more or less
cooperatively for the socket to be ready and make these generators continue.
These generators yield `Wait` objects whenever an operation would block. These
generators assume the connection fileno will not change. In case of the
connection function, where the fileno may change, the generators yield pairs
(fileno, `Wait`).
The generator can be restarted sending the appropriate `Ready` state when the
file descriptor is ready. If a None value is sent, it means that the wait
function timed out without any file descriptor becoming ready; in this case the
generator should probably yield the same value again in order to wait more.
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import logging
from time import monotonic
from . import pq
from . import errors as e
from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
from .pq.abc import PGcancelConn, PGconn, PGresult
from .waiting import Wait, Ready
from ._compat import Deque
from ._cmodule import _psycopg
from ._encodings import conninfo_encoding
# Shortcuts to the pq enum values used in the hot loops below.
OK = pq.ConnStatus.OK
BAD = pq.ConnStatus.BAD

POLL_OK = pq.PollingStatus.OK
POLL_READING = pq.PollingStatus.READING
POLL_WRITING = pq.PollingStatus.WRITING
POLL_FAILED = pq.PollingStatus.FAILED

COMMAND_OK = pq.ExecStatus.COMMAND_OK
COPY_OUT = pq.ExecStatus.COPY_OUT
COPY_IN = pq.ExecStatus.COPY_IN
COPY_BOTH = pq.ExecStatus.COPY_BOTH
PIPELINE_SYNC = pq.ExecStatus.PIPELINE_SYNC

# Wait states yielded by the generators / Ready states sent back to them.
WAIT_R = Wait.R
WAIT_W = Wait.W
WAIT_RW = Wait.RW
READY_R = Ready.R
READY_W = Ready.W
READY_RW = Ready.RW

logger = logging.getLogger(__name__)
def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]:
    """
    Generator to create a database connection without blocking.

    :param conninfo: the connection string to connect to.
    :param timeout: raise `ConnectionTimeout` if the connection is not
        established within this number of seconds (0 = no timeout).

    Yield pairs (fileno, `Wait`), because the connection file descriptor
    may change during the handshake.
    """
    deadline = monotonic() + timeout if timeout else 0.0
    conn = pq.PGconn.connect_start(conninfo.encode())
    while True:
        if conn.status == BAD:
            encoding = conninfo_encoding(conninfo)
            raise e.OperationalError(
                f"connection is bad: {conn.get_error_message(encoding)}",
                pgconn=conn,
            )
        status = conn.connect_poll()
        if status == POLL_READING or status == POLL_WRITING:
            wait = WAIT_R if status == POLL_READING else WAIT_W
            while True:
                # A falsy ready means the wait function timed out: check the
                # deadline and, if not expired, wait on the same state again.
                ready = yield conn.socket, wait
                if deadline and monotonic() > deadline:
                    raise e.ConnectionTimeout("connection timeout expired")
                if ready:
                    break
        elif status == POLL_OK:
            break
        elif status == POLL_FAILED:
            encoding = conninfo_encoding(conninfo)
            raise e.OperationalError(
                f"connection failed: {conn.get_error_message(encoding)}",
                pgconn=e.finish_pgconn(conn),
            )
        else:
            raise e.InternalError(
                f"unexpected poll status: {status}", pgconn=e.finish_pgconn(conn)
            )
    # The rest of the connection's life is spent in nonblocking mode.
    conn.nonblocking = 1
    return conn
def _cancel(cancel_conn: PGcancelConn, *, timeout: float = 0.0) -> PQGenConn[None]:
    """
    Generator to poll a cancellation request to completion without blocking.

    :param timeout: raise `CancellationTimeout` if the cancellation doesn't
        complete within this number of seconds (0 = no timeout).
    """
    deadline = monotonic() + timeout if timeout else 0.0
    while True:
        if deadline and monotonic() > deadline:
            raise e.CancellationTimeout("cancellation timeout expired")
        status = cancel_conn.poll()
        if status == POLL_OK:
            return
        if status == POLL_READING:
            yield cancel_conn.socket, WAIT_R
        elif status == POLL_WRITING:
            yield cancel_conn.socket, WAIT_W
        elif status == POLL_FAILED:
            raise e.OperationalError(
                f"cancellation failed: {cancel_conn.get_error_message()}"
            )
        else:
            raise e.InternalError(f"unexpected poll status: {status}")
def _execute(pgconn: PGconn) -> PQGen[list[PGresult]]:
    """
    Generator sending a query and returning results without blocking.

    The query must have already been sent using `pgconn.send_query()` or
    similar. Flush the query and then return the result using nonblocking
    functions.

    Return the list of results returned by the database (whether success
    or error).
    """
    # Flush the outgoing query, then gather all its results.
    yield from _send(pgconn)
    return (yield from _fetch_many(pgconn))
def _send(pgconn: PGconn) -> PQGen[None]:
    """
    Generator to send a query to the server without blocking.

    The query must have already been sent using `pgconn.send_query()` or
    similar. Flush the query and then return the result using nonblocking
    functions.

    After this generator has finished you may want to cycle using `fetch()`
    to retrieve the results available.
    """
    while pgconn.flush() != 0:
        # Output buffer not empty yet: wait until the socket is ready for
        # more writing, or for reading a server response.
        while not (ready := (yield WAIT_RW)):
            pass  # the wait timed out: just wait again
        if ready & READY_R:
            # This call may read notifies: they will be saved in the
            # PGconn buffer and passed to Python later, in `fetch()`.
            pgconn.consume_input()
def _fetch_many(pgconn: PGconn) -> PQGen[list[PGresult]]:
    """
    Generator retrieving results from the database without blocking.

    The query must have already been sent to the server, so pgconn.flush() has
    already returned 0.

    Return the list of results returned by the database (whether success
    or error).
    """
    results: list[PGresult] = []
    while res := (yield from _fetch(pgconn)):
        results.append(res)
        status = res.status
        if status in (COPY_IN, COPY_OUT, COPY_BOTH):
            # After entering copy mode the libpq will create a phony result
            # for every request so let's break the endless loop.
            break
        if status == PIPELINE_SYNC:
            # PIPELINE_SYNC is not followed by a NULL, but we return it alone
            # similarly to other result sets.
            assert len(results) == 1, results
            break
    return results
def _fetch(pgconn: PGconn) -> PQGen[PGresult | None]:
    """
    Generator retrieving a single result from the database without blocking.

    The query must have already been sent to the server, so pgconn.flush() has
    already returned 0.

    Return a result from the database (whether success or error).
    """
    if pgconn.is_busy():
        # Wait until the socket is readable, then keep consuming input
        # (waiting again as needed) until a whole result is available.
        while not (yield WAIT_R):
            pass
        while True:
            pgconn.consume_input()
            if not pgconn.is_busy():
                break
            while not (yield WAIT_R):
                pass
    _consume_notifies(pgconn)
    return pgconn.get_result()
def _pipeline_communicate(
    pgconn: PGconn, commands: Deque[PipelineCommand]
) -> PQGen[list[list[PGresult]]]:
    """Generator to send queries from a connection in pipeline mode while also
    receiving results.

    Return a list results, including single PIPELINE_SYNC elements.
    """
    results = []
    while True:
        # Wait until the socket allows reading and/or writing.
        while True:
            ready = yield WAIT_RW
            if ready:
                break
        if ready & READY_R:
            pgconn.consume_input()
            _consume_notifies(pgconn)
            # Drain every result already available, grouping them by query:
            # a None result marks the end of the current result set.
            res: list[PGresult] = []
            while not pgconn.is_busy():
                r = pgconn.get_result()
                if r is None:
                    if not res:
                        break
                    results.append(res)
                    res = []
                else:
                    status = r.status
                    if status == PIPELINE_SYNC:
                        # A sync result is returned as its own single-element list.
                        assert not res
                        results.append([r])
                    elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
                        # This shouldn't happen, but insisting hard enough, it will.
                        # For instance, in test_executemany_badquery(), with the COPY
                        # statement and the AsyncClientCursor, which disables
                        # prepared statements).
                        # Bail out from the resulting infinite loop.
                        raise e.NotSupportedError(
                            "COPY cannot be used in pipeline mode"
                        )
                    else:
                        res.append(r)
        if ready & READY_W:
            pgconn.flush()
            if not commands:
                break
            # Send the next enqueued pipeline command.
            commands.popleft()()
    return results
def _consume_notifies(pgconn: PGconn) -> None:
    """Deliver every notification buffered in *pgconn* to its handler, if any."""
    while n := pgconn.notifies():
        if pgconn.notify_handler:
            pgconn.notify_handler(n)
def notifies(pgconn: PGconn) -> PQGen[list[pq.PGnotify]]:
    """
    Generator returning the notifications available on the connection.

    Wait for the socket to become readable, then drain the notify queue.
    """
    yield WAIT_R
    pgconn.consume_input()
    ns = []
    while n := pgconn.notifies():
        ns.append(n)
    return ns
def copy_from(pgconn: PGconn) -> PQGen[memoryview | PGresult]:
    """
    Generator reading a chunk of data from a COPY-out operation.

    Return a `memoryview` with data, or the final `PGresult` once the copy
    operation is over.
    """
    while True:
        nbytes, data = pgconn.get_copy_data(1)
        if nbytes != 0:
            break
        # would block
        while True:
            ready = yield WAIT_R
            if ready:
                break
        pgconn.consume_input()
    if nbytes > 0:
        # some data
        return data
    # Retrieve the final result of copy
    results = yield from _fetch_many(pgconn)
    if len(results) > 1:
        # TODO: too brutal? Copy worked.
        raise e.ProgrammingError("you cannot mix COPY with other operations")
    result = results[0]
    if result.status != COMMAND_OK:
        raise e.error_from_result(result, encoding=pgconn._encoding)
    return result
def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]:
    """
    Generator to send a chunk of COPY data to the server without blocking.
    """
    # Retry enqueuing data until successful.
    #
    # WARNING! This can cause an infinite loop if the buffer is too large. (see
    # ticket #255). We avoid it in the Copy object by splitting a large buffer
    # into smaller ones. We prefer to do it there instead of here in order to
    # do it upstream the queue decoupling the writer task from the producer one.
    while pgconn.put_copy_data(buffer) == 0:
        while not (yield WAIT_W):
            pass
    # Flushing often has a good effect on macOS because memcpy operations
    # seem expensive on this platform so accumulating a large buffer has a
    # bad effect (see #745).
    if flush:
        # Repeat until the message is flushed to the server.
        while True:
            while not (yield WAIT_W):
                pass
            if pgconn.flush() == 0:
                break
def copy_end(pgconn: PGconn, error: bytes | None) -> PQGen[PGresult]:
    """
    Generator to terminate a COPY operation and retrieve its final result.

    Raise an error derived from the result if the copy didn't complete
    successfully.
    """
    # Retry enqueuing the end-of-copy message until successful.
    while pgconn.put_copy_end(error) == 0:
        while not (yield WAIT_W):
            pass
    # Repeat until the message is flushed to the server.
    while True:
        while not (yield WAIT_W):
            pass
        if pgconn.flush() == 0:
            break
    # Retrieve the final result of copy
    (result,) = yield from _fetch_many(pgconn)
    if result.status != COMMAND_OK:
        raise e.error_from_result(result, encoding=pgconn._encoding)
    return result
# Export the fast C implementations when the _psycopg extension module is
# available; otherwise fall back on the pure Python generators above.
connect = _psycopg.connect if _psycopg else _connect
cancel = _psycopg.cancel if _psycopg else _cancel
execute = _psycopg.execute if _psycopg else _execute
send = _psycopg.send if _psycopg else _send
fetch_many = _psycopg.fetch_many if _psycopg else _fetch_many
fetch = _psycopg.fetch if _psycopg else _fetch
pipeline_communicate = (
    _psycopg.pipeline_communicate if _psycopg else _pipeline_communicate
)

Some files were not shown because too many files have changed in this diff Show More