""" psycopg copy support """ # Copyright (C) 2020 The Psycopg Team from __future__ import annotations import re import sys import struct from abc import ABC, abstractmethod from typing import Any, Generic, Match, Sequence, TYPE_CHECKING from . import pq from . import adapt from . import errors as e from .abc import Buffer, ConnectionType, PQGen, Transformer from .pq.misc import connection_summary from ._cmodule import _psycopg from .generators import copy_from if TYPE_CHECKING: from ._cursor_base import BaseCursor PY_TEXT = adapt.PyFormat.TEXT PY_BINARY = adapt.PyFormat.BINARY TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY COPY_IN = pq.ExecStatus.COPY_IN COPY_OUT = pq.ExecStatus.COPY_OUT # Size of data to accumulate before sending it down the network. We fill a # buffer this size field by field, and when it passes the threshold size # we ship it, so it may end up being bigger than this. BUFFER_SIZE = 32 * 1024 # Maximum data size we want to queue to send to the libpq copy. Sending a # buffer too big to be handled can cause an infinite loop in the libpq # (#255) so we want to split it in more digestable chunks. MAX_BUFFER_SIZE = 4 * BUFFER_SIZE # Note: making this buffer too large, e.g. # MAX_BUFFER_SIZE = 1024 * 1024 # makes operations *way* slower! Probably triggering some quadraticity # in the libpq memory management and data sending. # Max size of the write queue of buffers. More than that copy will block # Each buffer should be around BUFFER_SIZE size. QUEUE_SIZE = 1024 # On certain systems, memmove seems particularly slow and flushing often is # more performing than accumulating a larger buffer. See #746 for details. PREFER_FLUSH = sys.platform == "darwin" class BaseCopy(Generic[ConnectionType]): """ Base implementation for the copy user interface. Two subclasses expose real methods with the sync/async differences. The difference between the text and binary format is managed by two different `Formatter` subclasses. Writing (the I/O part) is implemented in the subclasses by a `Writer` or `AsyncWriter` instance. Normally writing implies sending copy data to a database, but a different writer might be chosen, e.g. to stream data into a file for later use. """ formatter: Formatter def __init__( self, cursor: BaseCursor[ConnectionType, Any], *, binary: bool | None = None, ): self.cursor = cursor self.connection = cursor.connection self._pgconn = self.connection.pgconn result = cursor.pgresult if result: self._direction = result.status if self._direction != COPY_IN and self._direction != COPY_OUT: raise e.ProgrammingError( "the cursor should have performed a COPY operation;" f" its status is {pq.ExecStatus(self._direction).name} instead" ) else: self._direction = COPY_IN if binary is None: binary = bool(result and result.binary_tuples) tx: Transformer = getattr(cursor, "_tx", None) or adapt.Transformer(cursor) if binary: self.formatter = BinaryFormatter(tx) else: self.formatter = TextFormatter(tx, encoding=self._pgconn._encoding) self._finished = False def __repr__(self) -> str: cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" info = connection_summary(self._pgconn) return f"<{cls} {info} at 0x{id(self):x}>" def _enter(self) -> None: if self._finished: raise TypeError("copy blocks can be used only once") def set_types(self, types: Sequence[int | str]) -> None: """ Set the types expected in a COPY operation. The types must be specified as a sequence of oid or PostgreSQL type names (e.g. ``int4``, ``timestamptz[]``). This operation overcomes the lack of metadata returned by PostgreSQL when a COPY operation begins: - On :sql:`COPY TO`, `!set_types()` allows to specify what types the operation returns. If `!set_types()` is not used, the data will be returned as unparsed strings or bytes instead of Python objects. - On :sql:`COPY FROM`, `!set_types()` allows to choose what type the database expects. This is especially useful in binary copy, because PostgreSQL will apply no cast rule. """ registry = self.cursor.adapters.types oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types] if self._direction == COPY_IN: self.formatter.transformer.set_dumper_types(oids, self.formatter.format) else: self.formatter.transformer.set_loader_types(oids, self.formatter.format) # High level copy protocol generators (state change of the Copy object) def _read_gen(self) -> PQGen[Buffer]: if self._finished: return memoryview(b"") res = yield from copy_from(self._pgconn) if isinstance(res, memoryview): return res # res is the final PGresult self._finished = True # This result is a COMMAND_OK which has info about the number of rows # returned, but not about the columns, which is instead an information # that was received on the COPY_OUT result at the beginning of COPY. # So, don't replace the results in the cursor, just update the rowcount. nrows = res.command_tuples self.cursor._rowcount = nrows if nrows is not None else -1 return memoryview(b"") def _read_row_gen(self) -> PQGen[tuple[Any, ...] | None]: data = yield from self._read_gen() if not data: return None row = self.formatter.parse_row(data) if row is None: # Get the final result to finish the copy operation yield from self._read_gen() self._finished = True return None return row def _end_copy_out_gen(self) -> PQGen[None]: try: while (yield from self._read_gen()): pass except e.QueryCanceled: pass class Formatter(ABC): """ A class which understand a copy format (text, binary). """ format: pq.Format def __init__(self, transformer: Transformer): self.transformer = transformer self._write_buffer = bytearray() self._row_mode = False # true if the user is using write_row() @abstractmethod def parse_row(self, data: Buffer) -> tuple[Any, ...] | None: ... @abstractmethod def write(self, buffer: Buffer | str) -> Buffer: ... @abstractmethod def write_row(self, row: Sequence[Any]) -> Buffer: ... @abstractmethod def end(self) -> Buffer: ... class TextFormatter(Formatter): format = TEXT def __init__(self, transformer: Transformer, encoding: str = "utf-8"): super().__init__(transformer) self._encoding = encoding def parse_row(self, data: Buffer) -> tuple[Any, ...] | None: if data: return parse_row_text(data, self.transformer) else: return None def write(self, buffer: Buffer | str) -> Buffer: data = self._ensure_bytes(buffer) self._signature_sent = True return data def write_row(self, row: Sequence[Any]) -> Buffer: # Note down that we are writing in row mode: it means we will have # to take care of the end-of-copy marker too self._row_mode = True format_row_text(row, self.transformer, self._write_buffer) if len(self._write_buffer) > BUFFER_SIZE: buffer, self._write_buffer = self._write_buffer, bytearray() return buffer else: return b"" def end(self) -> Buffer: buffer, self._write_buffer = self._write_buffer, bytearray() return buffer def _ensure_bytes(self, data: Buffer | str) -> Buffer: if isinstance(data, str): return data.encode(self._encoding) else: # Assume, for simplicity, that the user is not passing stupid # things to the write function. If that's the case, things # will fail downstream. return data class BinaryFormatter(Formatter): format = BINARY def __init__(self, transformer: Transformer): super().__init__(transformer) self._signature_sent = False def parse_row(self, data: Buffer) -> tuple[Any, ...] | None: if not self._signature_sent: if data[: len(_binary_signature)] != _binary_signature: raise e.DataError( "binary copy doesn't start with the expected signature" ) self._signature_sent = True data = data[len(_binary_signature) :] elif data == _binary_trailer: return None return parse_row_binary(data, self.transformer) def write(self, buffer: Buffer | str) -> Buffer: data = self._ensure_bytes(buffer) self._signature_sent = True return data def write_row(self, row: Sequence[Any]) -> Buffer: # Note down that we are writing in row mode: it means we will have # to take care of the end-of-copy marker too self._row_mode = True if not self._signature_sent: self._write_buffer += _binary_signature self._signature_sent = True format_row_binary(row, self.transformer, self._write_buffer) if len(self._write_buffer) > BUFFER_SIZE: buffer, self._write_buffer = self._write_buffer, bytearray() return buffer else: return b"" def end(self) -> Buffer: # If we have sent no data we need to send the signature # and the trailer if not self._signature_sent: self._write_buffer += _binary_signature self._write_buffer += _binary_trailer elif self._row_mode: # if we have sent data already, we have sent the signature # too (either with the first row, or we assume that in # block mode the signature is included). # Write the trailer only if we are sending rows (with the # assumption that who is copying binary data is sending the # whole format). self._write_buffer += _binary_trailer buffer, self._write_buffer = self._write_buffer, bytearray() return buffer def _ensure_bytes(self, data: Buffer | str) -> Buffer: if isinstance(data, str): raise TypeError("cannot copy str data in binary mode: use bytes instead") else: # Assume, for simplicity, that the user is not passing stupid # things to the write function. If that's the case, things # will fail downstream. return data def _format_row_text( row: Sequence[Any], tx: Transformer, out: bytearray | None = None ) -> bytearray: """Convert a row of objects to the data to send for copy.""" if out is None: out = bytearray() if not row: out += b"\n" return out adapted = tx.dump_sequence(row, [PY_TEXT] * len(row)) for b in adapted: out += _dump_re.sub(_dump_sub, b) if b is not None else rb"\N" out += b"\t" out[-1:] = b"\n" return out def _format_row_binary( row: Sequence[Any], tx: Transformer, out: bytearray | None = None ) -> bytearray: """Convert a row of objects to the data to send for binary copy.""" if out is None: out = bytearray() out += _pack_int2(len(row)) adapted = tx.dump_sequence(row, [PY_BINARY] * len(row)) for b in adapted: if b is not None: out += _pack_int4(len(b)) out += b else: out += _binary_null return out def _parse_row_text(data: Buffer, tx: Transformer) -> tuple[Any, ...]: if not isinstance(data, bytes): data = bytes(data) fields = data.split(b"\t") fields[-1] = fields[-1][:-1] # drop \n row = [None if f == b"\\N" else _load_re.sub(_load_sub, f) for f in fields] return tx.load_sequence(row) def _parse_row_binary(data: Buffer, tx: Transformer) -> tuple[Any, ...]: row: list[Buffer | None] = [] nfields = _unpack_int2(data, 0)[0] pos = 2 for i in range(nfields): length = _unpack_int4(data, pos)[0] pos += 4 if length >= 0: row.append(data[pos : pos + length]) pos += length else: row.append(None) return tx.load_sequence(row) _pack_int2 = struct.Struct("!h").pack _pack_int4 = struct.Struct("!i").pack _unpack_int2 = struct.Struct("!h").unpack_from _unpack_int4 = struct.Struct("!i").unpack_from _binary_signature = ( b"PGCOPY\n\xff\r\n\0" # Signature b"\x00\x00\x00\x00" # flags b"\x00\x00\x00\x00" # extra length ) _binary_trailer = b"\xff\xff" _binary_null = b"\xff\xff\xff\xff" _dump_re = re.compile(b"[\b\t\n\v\f\r\\\\]") _dump_repl = { b"\b": b"\\b", b"\t": b"\\t", b"\n": b"\\n", b"\v": b"\\v", b"\f": b"\\f", b"\r": b"\\r", b"\\": b"\\\\", } def _dump_sub(m: Match[bytes], __map: dict[bytes, bytes] = _dump_repl) -> bytes: return __map[m.group(0)] _load_re = re.compile(b"\\\\[btnvfr\\\\]") _load_repl = {v: k for k, v in _dump_repl.items()} def _load_sub(m: Match[bytes], __map: dict[bytes, bytes] = _load_repl) -> bytes: return __map[m.group(0)] # Override functions with fast versions if available if _psycopg: format_row_text = _psycopg.format_row_text format_row_binary = _psycopg.format_row_binary parse_row_text = _psycopg.parse_row_text parse_row_binary = _psycopg.parse_row_binary else: format_row_text = _format_row_text format_row_binary = _format_row_binary parse_row_text = _parse_row_text parse_row_binary = _parse_row_binary