""" psycopg connection objects """ # Copyright (C) 2020 The Psycopg Team from __future__ import annotations import sys import logging from typing import Callable, Generic, NamedTuple, TYPE_CHECKING from weakref import ref, ReferenceType from warnings import warn from functools import partial from . import pq from . import errors as e from . import postgres from . import generators from .abc import PQGen, PQGenConn, Query from .sql import Composable, SQL from ._tpc import Xid from .rows import Row from .adapt import AdaptersMap from ._enums import IsolationLevel from ._compat import LiteralString, Self, TypeAlias, TypeVar from .pq.misc import connection_summary from ._pipeline import BasePipeline from ._preparing import PrepareManager from ._capabilities import capabilities from ._connection_info import ConnectionInfo if TYPE_CHECKING: from .pq.abc import PGconn, PGresult from psycopg_pool.base import BasePool # Row Type variable for Cursor (when it needs to be distinguished from the # connection's one) CursorRow = TypeVar("CursorRow") TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY OK = pq.ConnStatus.OK BAD = pq.ConnStatus.BAD COMMAND_OK = pq.ExecStatus.COMMAND_OK TUPLES_OK = pq.ExecStatus.TUPLES_OK FATAL_ERROR = pq.ExecStatus.FATAL_ERROR IDLE = pq.TransactionStatus.IDLE INTRANS = pq.TransactionStatus.INTRANS _HAS_SEND_CLOSE = capabilities.has_send_close_prepared() logger = logging.getLogger("psycopg") class Notify(NamedTuple): """An asynchronous notification received from the database.""" channel: str """The name of the channel on which the notification was received.""" payload: str """The message attached to the notification.""" pid: int """The PID of the backend process which sent the notification.""" Notify.__module__ = "psycopg" NoticeHandler: TypeAlias = Callable[[e.Diagnostic], None] NotifyHandler: TypeAlias = Callable[[Notify], None] class BaseConnection(Generic[Row]): """ Base class for different types of connections. Share common functionalities such as access to the wrapped PGconn, but allow different interfaces (sync/async). """ # DBAPI2 exposed exceptions Warning = e.Warning Error = e.Error InterfaceError = e.InterfaceError DatabaseError = e.DatabaseError DataError = e.DataError OperationalError = e.OperationalError IntegrityError = e.IntegrityError InternalError = e.InternalError ProgrammingError = e.ProgrammingError NotSupportedError = e.NotSupportedError def __init__(self, pgconn: PGconn): self.pgconn = pgconn self._autocommit = False # None, but set to a copy of the global adapters map as soon as requested. self._adapters: AdaptersMap | None = None self._notice_handlers: list[NoticeHandler] = [] self._notify_handlers: list[NotifyHandler] = [] # Number of transaction blocks currently entered self._num_transactions = 0 self._closed = False # closed by an explicit close() self._prepared: PrepareManager = PrepareManager() self._tpc: tuple[Xid, bool] | None = None # xid, prepared wself = ref(self) pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) pgconn.notify_handler = partial(BaseConnection._notify_handler, wself) # Attribute is only set if the connection is from a pool so we can tell # apart a connection in the pool too (when _pool = None) self._pool: BasePool | None self._pipeline: BasePipeline | None = None # Time after which the connection should be closed self._expire_at: float self._isolation_level: IsolationLevel | None = None self._read_only: bool | None = None self._deferrable: bool | None = None self._begin_statement = b"" def __del__(self) -> None: # If fails on connection we might not have this attribute yet if not hasattr(self, "pgconn"): return # Connection correctly closed if self.closed: return # Connection in a pool so terminating with the program is normal if hasattr(self, "_pool"): return warn( f"connection {self} was deleted while still open." " Please use 'with' or '.close()' to close the connection", ResourceWarning, ) def __repr__(self) -> str: cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" info = connection_summary(self.pgconn) return f"<{cls} {info} at 0x{id(self):x}>" @property def closed(self) -> bool: """`!True` if the connection is closed.""" return self.pgconn.status == BAD @property def broken(self) -> bool: """ `!True` if the connection was interrupted. A broken connection is always `closed`, but wasn't closed in a clean way, such as using `close()` or a `!with` block. """ return self.pgconn.status == BAD and not self._closed @property def autocommit(self) -> bool: """The autocommit state of the connection.""" return self._autocommit @autocommit.setter def autocommit(self, value: bool) -> None: self._set_autocommit(value) def _set_autocommit(self, value: bool) -> None: raise NotImplementedError def _set_autocommit_gen(self, value: bool) -> PQGen[None]: yield from self._check_intrans_gen("autocommit") self._autocommit = bool(value) @property def isolation_level(self) -> IsolationLevel | None: """ The isolation level of the new transactions started on the connection. """ return self._isolation_level @isolation_level.setter def isolation_level(self, value: IsolationLevel | None) -> None: self._set_isolation_level(value) def _set_isolation_level(self, value: IsolationLevel | None) -> None: raise NotImplementedError def _set_isolation_level_gen(self, value: IsolationLevel | None) -> PQGen[None]: yield from self._check_intrans_gen("isolation_level") self._isolation_level = IsolationLevel(value) if value is not None else None self._begin_statement = b"" @property def read_only(self) -> bool | None: """ The read-only state of the new transactions started on the connection. """ return self._read_only @read_only.setter def read_only(self, value: bool | None) -> None: self._set_read_only(value) def _set_read_only(self, value: bool | None) -> None: raise NotImplementedError def _set_read_only_gen(self, value: bool | None) -> PQGen[None]: yield from self._check_intrans_gen("read_only") self._read_only = bool(value) if value is not None else None self._begin_statement = b"" @property def deferrable(self) -> bool | None: """ The deferrable state of the new transactions started on the connection. """ return self._deferrable @deferrable.setter def deferrable(self, value: bool | None) -> None: self._set_deferrable(value) def _set_deferrable(self, value: bool | None) -> None: raise NotImplementedError def _set_deferrable_gen(self, value: bool | None) -> PQGen[None]: yield from self._check_intrans_gen("deferrable") self._deferrable = bool(value) if value is not None else None self._begin_statement = b"" def _check_intrans_gen(self, attribute: str) -> PQGen[None]: # Raise an exception if we are in a transaction status = self.pgconn.transaction_status if status == IDLE and self._pipeline: yield from self._pipeline._sync_gen() status = self.pgconn.transaction_status if status != IDLE: if self._num_transactions: raise e.ProgrammingError( f"can't change {attribute!r} now: " "connection.transaction() context in progress" ) else: raise e.ProgrammingError( f"can't change {attribute!r} now: " "connection in transaction status " f"{pq.TransactionStatus(status).name}" ) @property def info(self) -> ConnectionInfo: """A `ConnectionInfo` attribute to inspect connection properties.""" return ConnectionInfo(self.pgconn) @property def adapters(self) -> AdaptersMap: if not self._adapters: self._adapters = AdaptersMap(postgres.adapters) return self._adapters @property def connection(self) -> BaseConnection[Row]: # implement the AdaptContext protocol return self def fileno(self) -> int: """Return the file descriptor of the connection. This function allows to use the connection as file-like object in functions waiting for readiness, such as the ones defined in the `selectors` module. """ return self.pgconn.socket def cancel(self) -> None: """Cancel the current operation on the connection.""" if self._should_cancel(): c = self.pgconn.get_cancel() c.cancel() def _should_cancel(self) -> bool: """Check whether the current command should actually be cancelled when invoking cancel*(). """ # cancel() is a no-op if the connection is closed; # this allows to use the method as callback handler without caring # about its life. if self.closed: return False if self._tpc and self._tpc[1]: raise e.ProgrammingError( "cancel() cannot be used with a prepared two-phase transaction" ) return True def _cancel_gen(self, *, timeout: float) -> PQGenConn[None]: cancel_conn = self.pgconn.cancel_conn() cancel_conn.start() yield from generators.cancel(cancel_conn, timeout=timeout) def add_notice_handler(self, callback: NoticeHandler) -> None: """ Register a callable to be invoked when a notice message is received. :param callback: the callback to call upon message received. :type callback: Callable[[~psycopg.errors.Diagnostic], None] """ self._notice_handlers.append(callback) def remove_notice_handler(self, callback: NoticeHandler) -> None: """ Unregister a notice message callable previously registered. :param callback: the callback to remove. :type callback: Callable[[~psycopg.errors.Diagnostic], None] """ self._notice_handlers.remove(callback) @staticmethod def _notice_handler( wself: ReferenceType[BaseConnection[Row]], res: PGresult ) -> None: self = wself() if not (self and self._notice_handlers): return diag = e.Diagnostic(res, self.pgconn._encoding) for cb in self._notice_handlers: try: cb(diag) except Exception as ex: logger.exception("error processing notice callback '%s': %s", cb, ex) def add_notify_handler(self, callback: NotifyHandler) -> None: """ Register a callable to be invoked whenever a notification is received. :param callback: the callback to call upon notification received. :type callback: Callable[[~psycopg.Notify], None] """ self._notify_handlers.append(callback) def remove_notify_handler(self, callback: NotifyHandler) -> None: """ Unregister a notification callable previously registered. :param callback: the callback to remove. :type callback: Callable[[~psycopg.Notify], None] """ self._notify_handlers.remove(callback) @staticmethod def _notify_handler( wself: ReferenceType[BaseConnection[Row]], pgn: pq.PGnotify ) -> None: self = wself() if not (self and self._notify_handlers): return enc = self.pgconn._encoding n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) for cb in self._notify_handlers: cb(n) @property def prepare_threshold(self) -> int | None: """ Number of times a query is executed before it is prepared. - If it is set to 0, every query is prepared the first time it is executed. - If it is set to `!None`, prepared statements are disabled on the connection. Default value: 5 """ return self._prepared.prepare_threshold @prepare_threshold.setter def prepare_threshold(self, value: int | None) -> None: self._prepared.prepare_threshold = value @property def prepared_max(self) -> int | None: """ Maximum number of prepared statements on the connection. `!None` means no max number of prepared statements. The default value is 100. """ rv = self._prepared.prepared_max return rv if rv != sys.maxsize else None @prepared_max.setter def prepared_max(self, value: int | None) -> None: if value is None: value = sys.maxsize self._prepared.prepared_max = value # Generators to perform high-level operations on the connection # # These operations are expressed in terms of non-blocking generators # and the task of waiting when needed (when the generators yield) is left # to the connections subclass, which might wait either in blocking mode # or through asyncio. # # All these generators assume exclusive access to the connection: subclasses # should have a lock and hold it before calling and consuming them. @classmethod def _connect_gen( cls, conninfo: str = "", *, timeout: float = 0.0 ) -> PQGenConn[Self]: """Generator to connect to the database and create a new instance.""" pgconn = yield from generators.connect(conninfo, timeout=timeout) conn = cls(pgconn) return conn def _exec_command( self, command: Query, result_format: pq.Format = TEXT ) -> PQGen[PGresult | None]: """ Generator to send a command and receive the result to the backend. Only used to implement internal commands such as "commit", with eventual arguments bound client-side. The cursor can do more complex stuff. """ self._check_connection_ok() if isinstance(command, str): command = command.encode(self.pgconn._encoding) elif isinstance(command, Composable): command = command.as_bytes(self) if self._pipeline: cmd = partial( self.pgconn.send_query_params, command, None, result_format=result_format, ) self._pipeline.command_queue.append(cmd) self._pipeline.result_queue.append(None) return None # Unless needed, use the simple query protocol, e.g. to interact with # pgbouncer. In pipeline mode we always use the advanced query protocol # instead, see #350 if result_format == TEXT: self.pgconn.send_query(command) else: self.pgconn.send_query_params(command, None, result_format=result_format) result = (yield from generators.execute(self.pgconn))[-1] if result.status != COMMAND_OK and result.status != TUPLES_OK: if result.status == FATAL_ERROR: raise e.error_from_result(result, encoding=self.pgconn._encoding) else: raise e.InterfaceError( f"unexpected result {pq.ExecStatus(result.status).name}" f" from command {command.decode()!r}" ) return result def _deallocate(self, name: bytes | None) -> PQGen[None]: """ Deallocate one, or all, prepared statement in the session. ``name == None`` stands for DEALLOCATE ALL. If possible, use protocol-level commands; otherwise use SQL statements. Note that PgBouncer doesn't support DEALLOCATE name, but it supports protocol-level Close from 1.21 and DEALLOCATE ALL from 1.22. """ if name is None or not _HAS_SEND_CLOSE: stmt = b"DEALLOCATE " + name if name is not None else b"DEALLOCATE ALL" yield from self._exec_command(stmt) return self._check_connection_ok() if self._pipeline: cmd = partial( self.pgconn.send_close_prepared, name, ) self._pipeline.command_queue.append(cmd) self._pipeline.result_queue.append(None) return self.pgconn.send_close_prepared(name) result = (yield from generators.execute(self.pgconn))[-1] if result.status != COMMAND_OK: if result.status == FATAL_ERROR: raise e.error_from_result(result, encoding=self.pgconn._encoding) else: raise e.InterfaceError( f"unexpected result {pq.ExecStatus(result.status).name}" " from sending closing prepared statement message" ) def _check_connection_ok(self) -> None: if self.pgconn.status == OK: return if self.pgconn.status == BAD: raise e.OperationalError("the connection is closed") raise e.InterfaceError( "cannot execute operations: the connection is" f" in status {self.pgconn.status}" ) def _start_query(self) -> PQGen[None]: """Generator to start a transaction if necessary.""" if self._autocommit: return if self.pgconn.transaction_status != IDLE: return yield from self._exec_command(self._get_tx_start_command()) if self._pipeline: yield from self._pipeline._sync_gen() def _get_tx_start_command(self) -> bytes: if self._begin_statement: return self._begin_statement parts = [b"BEGIN"] if self.isolation_level is not None: val = IsolationLevel(self.isolation_level) parts.append(b"ISOLATION LEVEL") parts.append(val.name.replace("_", " ").encode()) if self.read_only is not None: parts.append(b"READ ONLY" if self.read_only else b"READ WRITE") if self.deferrable is not None: parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE") self._begin_statement = b" ".join(parts) return self._begin_statement def _commit_gen(self) -> PQGen[None]: """Generator implementing `Connection.commit()`.""" if self._num_transactions: raise e.ProgrammingError( "Explicit commit() forbidden within a Transaction " "context. (Transaction will be automatically committed " "on successful exit from context.)" ) if self._tpc: raise e.ProgrammingError( "commit() cannot be used during a two-phase transaction" ) if self.pgconn.transaction_status == IDLE: return yield from self._exec_command(b"COMMIT") if self._pipeline: yield from self._pipeline._sync_gen() def _rollback_gen(self) -> PQGen[None]: """Generator implementing `Connection.rollback()`.""" if self._num_transactions: raise e.ProgrammingError( "Explicit rollback() forbidden within a Transaction " "context. (Either raise Rollback() or allow " "an exception to propagate out of the context.)" ) if self._tpc: raise e.ProgrammingError( "rollback() cannot be used during a two-phase transaction" ) # Get out of a "pipeline aborted" state if self._pipeline: yield from self._pipeline._sync_gen() if self.pgconn.transaction_status == IDLE: return yield from self._exec_command(b"ROLLBACK") self._prepared.clear() yield from self._prepared.maintain_gen(self) if self._pipeline: yield from self._pipeline._sync_gen() def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid: """ Returns a `Xid` to pass to the `!tpc_*()` methods of this connection. The argument types and constraints are explained in :ref:`two-phase-commit`. The values passed to the method will be available on the returned object as the members `~Xid.format_id`, `~Xid.gtrid`, `~Xid.bqual`. """ self._check_tpc() return Xid.from_parts(format_id, gtrid, bqual) def _tpc_begin_gen(self, xid: Xid | str) -> PQGen[None]: self._check_tpc() if not isinstance(xid, Xid): xid = Xid.from_string(xid) if self.pgconn.transaction_status != IDLE: raise e.ProgrammingError( "can't start two-phase transaction: connection in status" f" {pq.TransactionStatus(self.pgconn.transaction_status).name}" ) if self._autocommit: raise e.ProgrammingError( "can't use two-phase transactions in autocommit mode" ) self._tpc = (xid, False) yield from self._exec_command(self._get_tx_start_command()) def _tpc_prepare_gen(self) -> PQGen[None]: if not self._tpc: raise e.ProgrammingError( "'tpc_prepare()' must be called inside a two-phase transaction" ) if self._tpc[1]: raise e.ProgrammingError( "'tpc_prepare()' cannot be used during a prepared two-phase transaction" ) xid = self._tpc[0] self._tpc = (xid, True) yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid))) if self._pipeline: yield from self._pipeline._sync_gen() def _tpc_finish_gen( self, action: LiteralString, xid: Xid | str | None ) -> PQGen[None]: fname = f"tpc_{action.lower()}()" if xid is None: if not self._tpc: raise e.ProgrammingError( f"{fname} without xid must must be" " called inside a two-phase transaction" ) xid = self._tpc[0] else: if self._tpc: raise e.ProgrammingError( f"{fname} with xid must must be called" " outside a two-phase transaction" ) if not isinstance(xid, Xid): xid = Xid.from_string(xid) if self._tpc and not self._tpc[1]: meth: Callable[[], PQGen[None]] meth = getattr(self, f"_{action.lower()}_gen") self._tpc = None yield from meth() else: yield from self._exec_command( SQL("{} PREPARED {}").format(SQL(action), str(xid)) ) self._tpc = None def _check_tpc(self) -> None: """Raise NotSupportedError if TPC is not supported.""" # TPC supported on every supported PostgreSQL version. pass