Source code for zensols.db.bean

"""Basic CRUD utility classes

"""
__author__ = 'Paul Landes'
from typing import Dict, Any, Tuple, Union, Callable, Iterable, Type, List
from dataclasses import dataclass, field, fields
from abc import ABCMeta
import logging
from pathlib import Path
import pandas as pd
from zensols.persist import chunks
from zensols.db import DynamicDataParser
from .conn import connection, DBError, AbstractDbPersister

logger = logging.getLogger(__name__)


@dataclass
class DbPersister(AbstractDbPersister):
    """CRUDs data to/from a DB-API connection.

    SQL statements are looked up by name in the sections parsed from
    :obj:`sql_file`.

    """
    sql_file: Path = field(default=None)
    """The text file containing the SQL statements (see
    :class:`DynamicDataParser`).

    """
    row_factory: Union[str, Type] = field(default='tuple')
    """The default method by which data is returned from ``execute_*``
    methods.

    :see: :meth:`execute`.

    """
    def __post_init__(self):
        self.parser = self._create_parser(self.sql_file)
        self.conn_manager.register_persister(self)

    def _create_parser(self, sql_file: Path) -> DynamicDataParser:
        """Factory method that creates the parser of the SQL file."""
        return DynamicDataParser(sql_file)

    @property
    def _sql_file(self) -> Path:
        return self._sql_file_val

    @_sql_file.setter
    def _sql_file(self, sql_file: Path):
        self._sql_file_val = sql_file
        # the parser is created in ``__post_init__``, so it is not yet
        # available the first time this setter is invoked
        if hasattr(self, 'parser'):
            self.parser.dd_path = sql_file

    @property
    def sql_entries(self) -> Dict[str, str]:
        """A dictionary of names -> SQL statements from the SQL file.

        """
        return self.parser.sections

    @property
    def metadata(self) -> Dict[str, str]:
        """The metadata associated with the SQL file.

        """
        return self.parser.metadata

    def _create_connection(self):
        """Create a connection to the database using the connection manager.

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('creating connection')
        conn = self.conn_manager.create()
        return conn

    def _dispose_connection(self, conn: Any):
        """Close the connection to the database.

        :param conn: the connection to release

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'closing connection {conn}')
        self.conn_manager.dispose(conn)

    def _check_entry(self, name: str):
        """Raise an error if ``name`` can not be used to look up an SQL entry.

        :param name: the name of the SQL entry to validate

        :raises DBError: if ``name`` is ``None``, empty or not a key in
                         :obj:`sql_entries`

        """
        problem: str = None
        if name is None:
            problem = 'No defined SQL entry for persist function '
        elif len(name) == 0:
            problem = 'Non-optional entry not provided '
        elif name not in self.sql_entries:
            problem = f"No entry '{name}' found in SQL configuration "
        if problem is not None:
            raise DBError(problem + f"in SQL file '{self.sql_file}'")

    def _get_entry(self, name: str):
        """Return the SQL statement of entry ``name`` after validating it.

        :raises DBError: see :meth:`_check_entry`

        """
        self._check_entry(name)
        return self.sql_entries[name]

    def execute_by_name(self, name: str, params: Tuple[Any] = (),
                        row_factory: Union[str, Callable] = None,
                        map_fn: Callable = None):
        """Just like :meth:`execute` but look up the SQL statement to execute
        on the database connection.

        The ``row_factory`` tells the method how to interpret the row data in
        to an object that's returned.  It can be one of:

            * ``tuple``: tuples (the default)
            * ``dict``: for dictionaries
            * ``pandas``: for a :class:`pandas.DataFrame`
            * otherwise: a function or class

        Compare this with ``map_fn``, which transforms the data that's given
        to the ``row_factory``.

        :param name: the named SQL query in the :obj:`sql_file`

        :param params: the parameters given to the SQL statement (populated
                       with ``?``) in the statement

        :param row_factory: ``tuple``, ``dict``, ``pandas`` or a function

        :param map_fn: a function that transforms row data given to the
                       ``row_factory``

        :raises DBError: if ``name`` is not a usable SQL entry

        :see: :meth:`execute`

        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'name: <{name}>, params: {params}')
        self._check_entry(name)
        statement = self.sql_entries[name]
        return self.execute(statement, params, row_factory, map_fn)

    def execute_singleton_by_name(self, *args, **kwargs):
        """Just like :meth:`execute_by_name` except return only the first item
        or ``None`` if no results.

        """
        results = self.execute_by_name(*args, **kwargs)
        return results[0] if len(results) > 0 else None

    @connection()
    def execute_sql_no_read(self, conn: Any, sql: str,
                            params: Tuple[Any] = ()) -> int:
        """Execute SQL and return the database level information such as row
        IDs rather than the results of a query.  Use this when inserting data
        to get a row ID.

        :param sql: the SQL statement to execute

        :param params: the parameters given to the SQL statement

        """
        return self.conn_manager.execute_no_read(conn, sql, params)

    @connection()
    def execute_no_read(self, conn: Any, entry_name: str,
                        params: Tuple[Any] = ()) -> int:
        """Just like :meth:`execute_by_name`, but return database level
        information such as row IDs rather than the results of a query.  Use
        this when inserting data to get a row ID.

        :param entry_name: the key in the SQL file whose value is used as the
                           statement

        :param params: the parameters given to the SQL statement

        :raises DBError: if ``entry_name`` is not a usable SQL entry

        :see: :meth:`execute_sql_no_read`

        """
        self._check_entry(entry_name)
        statement = self.sql_entries[entry_name]
        return self.conn_manager.execute_no_read(conn, statement, params)


# keep the dataclass semantics, but allow for a setter
DbPersister.sql_file = DbPersister._sql_file
@dataclass
class Bean(object, metaclass=ABCMeta):
    """A container class like a Java *bean*.

    Attribute access, equality, hashing and string formatting are all driven
    by the dataclass fields of the instance in their declared order.

    """
    def get_attr_names(self) -> Tuple[str]:
        """Return the names of the dataclass fields as a tuple of strings.

        """
        return tuple(fld.name for fld in fields(self))

    def get_attrs(self) -> Dict[str, Any]:
        """Return a dict of attributes that are meant to be persisted.

        """
        return {name: getattr(self, name) for name in self.get_attr_names()}

    def get_row(self) -> Tuple[Any]:
        """Return a row of data meant to be printed.  This includes the unique
        ID of the bean (see :meth:`get_insert_row`).

        """
        return tuple(getattr(self, name) for name in self.get_attr_names())

    def get_insert_row(self) -> Tuple[Any]:
        """Return a row of data meant to be inserted into the database.  This
        method implementation leaves off the first attribute assuming it
        contains a unique (i.e. row ID) of the object.  See :meth:`get_row`.

        """
        insert_names = self.get_attr_names()[1:]
        return tuple(getattr(self, name) for name in insert_names)

    def __eq__(self, other):
        if self is other:
            return True
        if other is None or self.__class__ != other.__class__:
            return False
        # equal only when no attribute differs
        return not any(getattr(self, name) != getattr(other, name)
                       for name in self.get_attr_names())

    def __hash__(self):
        return hash(tuple(getattr(self, name)
                          for name in self.get_attr_names()))

    def __str__(self):
        return ', '.join(f'{name}: {getattr(self, name)}'
                         for name in self.get_attr_names())

    def __repr__(self):
        return self.__str__()
@dataclass
class ReadOnlyBeanDbPersister(DbPersister):
    """A read-only persister that CRUDs data based on predefined SQL given in
    the configuration.  The class optionally works with instances of
    :class:`.Bean` when :obj:`row_factory` is set to the target bean class.

    """
    select_name: str = field(default=None)
    """The name of the SQL entry used to select data/class."""

    select_by_id_name: str = field(default=None)
    """The name of the SQL entry used to select a single row by unique ID."""

    select_exists_name: str = field(default=None)
    """The name of the SQL entry used to determine if a row exists by unique
    ID.

    """
    def get(self) -> list:
        """Return the results of the SQL provided by the entry identified by
        :obj:`select_name`.

        """
        entry_name = self.select_name
        return self.execute_by_name(entry_name, row_factory=self.row_factory)

    def get_by_id(self, id: int):
        """Return an object using its unique ID, which could be the row ID in
        SQLite.

        :param id: the unique ID given as the only parameter to the SQL entry
                   :obj:`select_by_id_name`

        :return: the first matching row or ``None`` if there is none

        """
        matches = self.execute_by_name(
            self.select_by_id_name, params=(id,),
            row_factory=self.row_factory)
        return matches[0] if len(matches) > 0 else None

    def exists(self, id: int) -> bool:
        """Return ``True`` if there is an object with unique ID (or row ID) in
        the database.  Otherwise return ``False``.

        When :obj:`select_exists_name` is not set this falls back to
        :meth:`get_by_id`.

        """
        if self.select_exists_name is not None:
            cnt = self.execute_by_name(
                self.select_exists_name, params=(id,), row_factory='tuple')
            # first column of the first row is compared against one
            return cnt[0][0] == 1
        return self.get_by_id(id) is not None
@dataclass
class InsertableBeanDbPersister(ReadOnlyBeanDbPersister):
    """A class that contains insert functionality.

    """
    insert_name: str = field(default=None)
    """The name of the SQL entry used to insert data/class instance."""

    def insert_row(self, *row) -> int:
        """Insert a row in the database and return the current row ID.

        :param row: a sequence of data in column order of the SQL provided by
                    the entry :obj:`insert_name`

        """
        return self.execute_no_read(self.insert_name, params=row)

    @connection()
    def insert_rows(self, conn: Any, rows: Iterable[Any],
                    errors: str = 'raise', set_id_fn: Callable = None,
                    map_fn: Callable = None) -> int:
        """Insert a tuple of rows in the database and return the current row
        ID.

        :param rows: a sequence of tuples of data (or an object to be
                     transformed, see ``map_fn``) in column order of the SQL
                     provided by the entry :obj:`insert_name`

        :param errors: if this is the string ``raise`` then raise an error on
                       any exception when invoking the database execute

        :param set_id_fn: a callable that is given the data to be inserted and
                          the row ID returned from the row insert as
                          parameters

        :param map_fn: if not ``None``, used to transform the given row in to
                       a tuple that is used for the insertion

        :return: the ``rowid`` of the last row inserted

        :raises DBError: if :obj:`insert_name` is not a usable SQL entry

        """
        sql: str = self._get_entry(self.insert_name)
        return self.conn_manager.insert_rows(
            conn, sql, rows, errors, set_id_fn, map_fn)

    @connection()
    def insert_dataframe(self, conn: Any, df: pd.DataFrame,
                         errors: str = 'raise', set_id_fn: Callable = None,
                         map_fn: Callable = None,
                         chunk_size: int = 100) -> Union[int, None]:
        """Like :meth:`insert_rows` but the data is taken from a Pandas
        dataframe.

        :param df: the dataframe from which the rows are drawn

        :param errors: if this is the string ``raise`` then raise an error on
                       any exception when invoking the database execute

        :param set_id_fn: a callable that is given the data to be inserted and
                          the row ID returned from the row insert as
                          parameters

        :param map_fn: if not ``None``, used to transform the given row in to
                       a tuple that is used for the insertion

        :param chunk_size: the number of rows inserted at a time, so the
                           number of interactions with the database are at
                           most the row count of the dataframe /
                           ``chunk_size``

        :return: the ``rowid`` of the last row inserted, or ``None`` if no
                 chunk of rows was inserted (i.e. an empty dataframe)

        :raises DBError: if :obj:`insert_name` is not a usable SQL entry

        """
        sql: str = self._get_entry(self.insert_name)
        # must be bound before the loop: when no chunk is produced the loop
        # body never runs and returning an unassigned local would raise
        # ``UnboundLocalError``
        row_id: Union[int, None] = None
        rows: List[Tuple[Any, ...]]
        for rows in chunks(df.itertuples(name=None, index=False), chunk_size):
            row_id = self.conn_manager.insert_rows(
                conn, sql, rows, errors, set_id_fn, map_fn)
        return row_id

    def _get_insert_row(self, bean: Bean) -> Tuple[Any]:
        """Factory method to return the bean's insert row parameters."""
        return bean.get_insert_row()

    def insert(self, bean: Bean) -> int:
        """Insert a bean using the order of the values given in
        :meth:`Bean.get_insert_row` as that of the SQL defined with entry
        :obj:`insert_name` given in the initializer.

        The returned row ID is also set on the bean as attribute ``id``.

        :param bean: the bean to insert

        :return: the row ID returned from the insert

        """
        row = self._get_insert_row(bean)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'inserting row: {row}')
        curid = self.insert_row(*row)
        bean.id = curid
        return curid

    def insert_beans(self, beans: Iterable[Any], errors: str = 'raise') -> int:
        """Insert several beans using the order of the values given in
        :meth:`Bean.get_insert_row` as that of the SQL defined with entry
        :obj:`insert_name` given in the initializer.

        Unlike :meth:`insert`, the ``id`` attribute of the beans is left
        untouched.

        :param beans: the beans to insert

        :param errors: if this is the string ``raise`` then raise an error on
                       any exception when invoking the database execute

        :return: the value returned by :meth:`insert_rows`

        """
        def map_fn(bean):
            return self._get_insert_row(bean)

        def set_id_fn(bean, id):
            # intentionally a no-op: row IDs are not written back to beans
            pass

        return self.insert_rows(beans, errors, set_id_fn, map_fn)
@dataclass
class UpdatableBeanDbPersister(InsertableBeanDbPersister):
    """A class that adds update and delete functionality to the insert and
    read functionality of the super class.

    """
    update_name: str = field(default=None)
    """The name of the SQL entry used to update data/class instance(s)."""

    delete_name: str = field(default=None)
    """The name of the SQL entry used to delete data/class instance(s)."""

    def update_row(self, *row: Tuple[Any]) -> int:
        """Update a row using the values of the row with the current unique ID
        as the first element in ``*row``.

        The unique ID is moved to the end of the parameters given to the SQL
        entry :obj:`update_name`.

        """
        params = row[1:] + (row[0],)
        return self.execute_no_read(self.update_name, params=params)

    def update(self, bean: Bean) -> int:
        """Update a bean using the row returned by :meth:`Bean.get_row`, whose
        first element is used as the unique ID.

        """
        row = bean.get_row()
        return self.update_row(*row)

    def delete(self, id) -> int:
        """Delete a row by ID.

        :param id: the only parameter given to the SQL entry
                   :obj:`delete_name`

        """
        params = (id,)
        return self.execute_no_read(self.delete_name, params=params)
@dataclass
class BeanDbPersister(UpdatableBeanDbPersister):
    """A class that adds key enumeration and row counting to the CRUD
    functionality of the super class.

    """
    keys_name: str = field(default=None)
    """The name of the SQL entry used to fetch all keys."""

    count_name: str = field(default=None)
    """The name of the SQL entry used to get a row count."""

    def get_keys(self) -> Iterable[Any]:
        """Return the unique keys from the bean table, taken from the first
        column of each row returned by the SQL entry :obj:`keys_name`.

        """
        rows = self.execute_by_name(self.keys_name, row_factory='tuple')
        return map(lambda row: row[0], rows)

    def get_count(self) -> int:
        """Return the number of rows in the bean table.

        The keys are counted when :obj:`count_name` is not set.

        """
        if self.count_name is None:
            # SQLite has a bug that returns one row with all null values
            return sum(1 for _ in self.get_keys())
        cnt = self.execute_by_name(self.count_name, row_factory='tuple')
        return cnt[0][0]