diff --git a/devine/commands/kv.py b/devine/commands/kv.py index 61e3f34..f237a37 100644 --- a/devine/commands/kv.py +++ b/devine/commands/kv.py @@ -89,7 +89,7 @@ def copy(to_vault: str, from_vaults: list[str], service: Optional[str] = None) - log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}") try: - added = to_vault.add_keys(service_, content_keys, commit=True) + added = to_vault.add_keys(service_, content_keys) except PermissionError: log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...") continue @@ -171,7 +171,7 @@ def add(file: Path, service: str, vaults: list[str]) -> None: for vault in vaults_: log.info(f"Adding {total_count} Content Keys to {vault}") - added_count = vault.add_keys(service, kid_keys, commit=True) + added_count = vault.add_keys(service, kid_keys) existed_count = total_count - added_count log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)") diff --git a/devine/core/utils/atomicsql.py b/devine/core/utils/atomicsql.py deleted file mode 100644 index dcee82d..0000000 --- a/devine/core/utils/atomicsql.py +++ /dev/null @@ -1,105 +0,0 @@ -""" -AtomicSQL - Race-condition and Threading safe SQL Database Interface. -Copyright (C) 2020-2023 rlaphoenix - -This program is free software: you can redistribute it and/or modify -it under the terms of the GNU General Public License as published by -the Free Software Foundation, either version 3 of the License, or -(at your option) any later version. - -This program is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU General Public License for more details. - -You should have received a copy of the GNU General Public License -along with this program. If not, see . -""" - -import os -import sqlite3 -import time -from threading import Lock -from typing import Any, Callable, Union - -import pymysql.cursors - -Connections = Union[sqlite3.Connection, pymysql.connections.Connection] -Cursors = Union[sqlite3.Cursor, pymysql.cursors.Cursor] - - -class AtomicSQL: - """ - Race-condition and Threading safe SQL Database Interface. - """ - - def __init__(self) -> None: - self.master_lock = Lock() # prevents race condition - self.db: dict[bytes, Connections] = {} # used to hold the database connections and commit changes and such - self.cursor: dict[bytes, Cursors] = {} # used to execute queries and receive results - self.session_lock: dict[bytes, Lock] = {} # like master_lock, but per-session - - def load(self, connection: Connections) -> bytes: - """ - Store SQL Connection object and return a reference ticket. - :param connection: SQLite3 or pymysql Connection object. - :returns: Session ID in which the database connection is referenced with. - """ - self.master_lock.acquire() - try: - # obtain a unique cryptographically random session_id - session_id = None - while not session_id or session_id in self.db: - session_id = os.urandom(16) - self.db[session_id] = connection - self.cursor[session_id] = self.db[session_id].cursor() - self.session_lock[session_id] = Lock() - return session_id - finally: - self.master_lock.release() - - def safe_execute(self, session_id: bytes, action: Callable) -> Any: - """ - Execute code on the Database Connection in a race-condition safe way. - :param session_id: Database Connection's Session ID. - :param action: Function or lambda in which to execute, it's provided `db` and `cursor` arguments. - :returns: Whatever `action` returns. - """ - if session_id not in self.db: - raise ValueError(f"Session ID {session_id!r} is invalid.") - self.master_lock.acquire() - self.session_lock[session_id].acquire() - try: - failures = 0 - while True: - try: - action( - db=self.db[session_id], - cursor=self.cursor[session_id] - ) - break - except sqlite3.OperationalError as e: - failures += 1 - delay = 3 * failures - print(f"AtomicSQL.safe_execute failed, {e}, retrying in {delay} seconds...") - time.sleep(delay) - if failures == 10: - raise ValueError("AtomicSQL.safe_execute failed too many time's. Aborting.") - return self.cursor[session_id] - finally: - self.session_lock[session_id].release() - self.master_lock.release() - - def commit(self, session_id: bytes) -> bool: - """ - Commit changes to the Database Connection immediately. - This isn't necessary to be run every time you make changes, just ensure it's run - at least before termination. - :param session_id: Database Connection's Session ID. - :returns: True if it committed. - """ - self.safe_execute( - session_id, - lambda db, cursor: db.commit() - ) - return True # todo ; actually check if db.commit worked diff --git a/devine/core/vault.py b/devine/core/vault.py index 01a7d71..bcefa07 100644 --- a/devine/core/vault.py +++ b/devine/core/vault.py @@ -31,11 +31,11 @@ class Vault(metaclass=ABCMeta): """Get All Keys from Vault by Service.""" @abstractmethod - def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool: + def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: """Add KID:KEY to the Vault.""" @abstractmethod - def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int: + def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int: """ Add Multiple Content Keys with Key IDs for Service to the Vault. Pre-existing Content Keys are ignored/skipped. diff --git a/devine/core/vaults.py b/devine/core/vaults.py index 49ea542..c94f73d 100644 --- a/devine/core/vaults.py +++ b/devine/core/vaults.py @@ -57,7 +57,7 @@ class Vaults: for vault in self.vaults: if vault != excluding: try: - success += vault.add_key(self.service, kid, key, commit=True) + success += vault.add_key(self.service, kid, key) except (PermissionError, NotImplementedError): pass return success @@ -70,7 +70,7 @@ class Vaults: success = 0 for vault in self.vaults: try: - success += bool(vault.add_keys(self.service, kid_keys, commit=True)) + success += bool(vault.add_keys(self.service, kid_keys)) except (PermissionError, NotImplementedError): pass return success diff --git a/devine/vaults/MySQL.py b/devine/vaults/MySQL.py index 221d3f6..c082e9e 100644 --- a/devine/vaults/MySQL.py +++ b/devine/vaults/MySQL.py @@ -1,5 +1,8 @@ from __future__ import annotations +import time +from queue import Empty, Queue +from threading import Lock from typing import Iterator, Optional, Union from uuid import UUID @@ -7,7 +10,6 @@ import pymysql from pymysql.cursors import DictCursor from devine.core.services import Services -from devine.core.utils.atomicsql import AtomicSQL from devine.core.vault import Vault @@ -21,15 +23,13 @@ class MySQL(Vault): """ super().__init__(name) self.slug = f"{host}:{database}:{username}" - self.con = pymysql.connect( + self.con_pool = ConnectionPool(dict( host=host, db=database, user=username, cursorclass=DictCursor, **kwargs - ) - self.adb = AtomicSQL() - self.ticket = self.adb.load(self.con) + ), 5) self.permissions = self.get_permissions() if not self.has_permission("SELECT"): @@ -43,37 +43,44 @@ class MySQL(Vault): if isinstance(kid, UUID): kid = kid.hex - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute( # TODO: SQL injection risk f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s", - [kid, "0" * 32] + (kid, "0" * 32) ) - ).fetchone() - if not c: - return None - - return c["key_"] + cek = cursor.fetchone() + if not cek: + return None + return cek["key_"] + finally: + cursor.close() + self.con_pool.put(conn) def get_keys(self, service: str) -> Iterator[tuple[str, str]]: if not self.has_table(service): # no table, no keys, simple return None - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute( # TODO: SQL injection risk f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s", - ["0" * 32] + ("0" * 32,) ) - ) + for row in cursor.fetchall(): + yield row["kid"], row["key_"] + finally: + cursor.close() + self.con_pool.put(conn) - for row in c.fetchall(): - yield row["kid"], row["key_"] - - def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool: + def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: if not key or key.count("0") == len(key): raise ValueError("You cannot add a NULL Content Key to a Vault.") @@ -82,39 +89,38 @@ class MySQL(Vault): if not self.has_table(service): try: - self.create_table(service, commit) + self.create_table(service) except PermissionError: return False if isinstance(kid, UUID): kid = kid.hex - if self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute( # TODO: SQL injection risk f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s", - [kid, key] + (kid, key) ) - ).fetchone(): - # table already has this exact KID:KEY stored - return True - - self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + if cursor.fetchone(): + # table already has this exact KID:KEY stored + return True + cursor.execute( # TODO: SQL injection risk f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)", (kid, key) ) - ) - - if commit: - self.commit() + finally: + conn.commit() + cursor.close() + self.con_pool.put(conn) return True - def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int: + def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int: for kid, key in kid_keys.items(): if not key or key.count("0") == len(key): raise ValueError("You cannot add a NULL Content Key to a Vault.") @@ -124,7 +130,7 @@ class MySQL(Vault): if not self.has_table(service): try: - self.create_table(service, commit) + self.create_table(service) except PermissionError: return 0 @@ -139,40 +145,50 @@ class MySQL(Vault): for kid, key_ in kid_keys.items() } - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.executemany( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.executemany( # TODO: SQL injection risk f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)", kid_keys.items() ) - ) - - if commit: - self.commit() - - return c.rowcount + return cursor.rowcount + finally: + conn.commit() + cursor.close() + self.con_pool.put(conn) def get_services(self) -> Iterator[str]: - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute("SHOW TABLES") - ) - for table in c.fetchall(): - # each entry has a key named `Tables_in_` - yield Services.get_tag(list(table.values())[0]) + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute("SHOW TABLES") + for table in cursor.fetchall(): + # each entry has a key named `Tables_in_` + yield Services.get_tag(list(table.values())[0]) + finally: + cursor.close() + self.con_pool.put(conn) def has_table(self, name: str) -> bool: """Check if the Vault has a Table with the specified name.""" - return list(self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( - "SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s", - [self.con.db, name] - ) - ).fetchone().values())[0] == 1 + conn = self.con_pool.get() + cursor = conn.cursor() - def create_table(self, name: str, commit: bool = False): + try: + cursor.execute( + "SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s", + (conn.db, name) + ) + return list(cursor.fetchone().values())[0] == 1 + finally: + cursor.close() + self.con_pool.put(conn) + + def create_table(self, name: str): """Create a Table with the specified name if not yet created.""" if self.has_table(name): return @@ -180,9 +196,11 @@ class MySQL(Vault): if not self.has_permission("CREATE"): raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.") - self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute( # TODO: SQL injection risk f""" CREATE TABLE IF NOT EXISTS {name} ( @@ -193,23 +211,30 @@ class MySQL(Vault): ); """ ) - ) - - if commit: - self.commit() + finally: + conn.commit() + cursor.close() + self.con_pool.put(conn) def get_permissions(self) -> list: """Get and parse Grants to a more easily usable list tuple array.""" - with self.con.cursor() as c: - c.execute("SHOW GRANTS") - grants = c.fetchall() + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute("SHOW GRANTS") + grants = cursor.fetchall() grants = [next(iter(x.values())) for x in grants] - grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)] - grants = [( - list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))), - location.replace("`", "").split(".") - ) for perms, location in grants] - return grants + grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)] + grants = [( + list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))), + location.replace("`", "").split(".") + ) for perms, location in grants] + return grants + finally: + conn.commit() + cursor.close() + self.con_pool.put(conn) def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool: """Check if the current connection has a specific permission.""" @@ -220,6 +245,28 @@ class MySQL(Vault): grants = [x for x in grants if x[1][1] in (table, "*")] return bool(grants) - def commit(self): - """Commit any changes made that has not been written to db.""" - self.adb.commit(self.ticket) + +class ConnectionPool: + def __init__(self, con: dict, size: int): + self._con = con + self._size = size + self._pool = Queue(self._size) + self._lock = Lock() + + def _create_connection(self): + return pymysql.connect(**self._con) + + def get(self) -> pymysql.Connection: + while True: + try: + return self._pool.get(block=False) + except Empty: + with self._lock: + if self._pool.qsize() < self._size: + return self._create_connection() + else: + # pool full, wait before retrying + time.sleep(0.1) + + def put(self, conn: pymysql.Connection): + self._pool.put(conn) diff --git a/devine/vaults/SQLite.py b/devine/vaults/SQLite.py index 02c5307..a864b13 100644 --- a/devine/vaults/SQLite.py +++ b/devine/vaults/SQLite.py @@ -1,12 +1,15 @@ from __future__ import annotations import sqlite3 +import time from pathlib import Path +from queue import Empty, Queue +from sqlite3 import Connection +from threading import Lock from typing import Iterator, Optional, Union from uuid import UUID from devine.core.services import Services -from devine.core.utils.atomicsql import AtomicSQL from devine.core.vault import Vault @@ -17,9 +20,7 @@ class SQLite(Vault): super().__init__(name) self.path = Path(path).expanduser() # TODO: Use a DictCursor or such to get fetches as dict? - self.con = sqlite3.connect(self.path) - self.adb = AtomicSQL() - self.ticket = self.adb.load(self.con) + self.con_pool = ConnectionPool(self.path, 5) def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]: if not self.has_table(service): @@ -29,76 +30,82 @@ class SQLite(Vault): if isinstance(kid, UUID): kid = kid.hex - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( - # TODO: SQL injection risk - f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?", - [kid, "0" * 32] - ) - ).fetchone() - if not c: - return None + conn = self.con_pool.get() + cursor = conn.cursor() - return c[1] # `key_` + try: + cursor.execute( + f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?", + (kid, "0" * 32) + ) + cek = cursor.fetchone() + if not cek: + return None + return cek[1] + finally: + cursor.close() + self.con_pool.put(conn) def get_keys(self, service: str) -> Iterator[tuple[str, str]]: if not self.has_table(service): # no table, no keys, simple return None - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( - # TODO: SQL injection risk - f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?", - ["0" * 32] - ) - ) - for (kid, key_) in c.fetchall(): - yield kid, key_ - def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool: + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute( + f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?", + ("0" * 32,) + ) + for (kid, key_) in cursor.fetchall(): + yield kid, key_ + finally: + cursor.close() + self.con_pool.put(conn) + + def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: if not key or key.count("0") == len(key): raise ValueError("You cannot add a NULL Content Key to a Vault.") if not self.has_table(service): - self.create_table(service, commit) + self.create_table(service) if isinstance(kid, UUID): kid = kid.hex - if self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute( # TODO: SQL injection risk f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?", - [kid, key] + (kid, key) ) - ).fetchone(): - # table already has this exact KID:KEY stored - return True - - self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + if cursor.fetchone(): + # table already has this exact KID:KEY stored + return True + cursor.execute( # TODO: SQL injection risk f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)", (kid, key) ) - ) - - if commit: - self.commit() + finally: + conn.commit() + cursor.close() + self.con_pool.put(conn) return True - def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int: + def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int: for kid, key in kid_keys.items(): if not key or key.count("0") == len(key): raise ValueError("You cannot add a NULL Content Key to a Vault.") if not self.has_table(service): - self.create_table(service, commit) + self.create_table(service) if not isinstance(kid_keys, dict): raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}") @@ -111,47 +118,59 @@ class SQLite(Vault): for kid, key_ in kid_keys.items() } - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.executemany( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.executemany( # TODO: SQL injection risk f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)", kid_keys.items() ) - ) - - if commit: - self.commit() - - return c.rowcount + return cursor.rowcount + finally: + conn.commit() + cursor.close() + self.con_pool.put(conn) def get_services(self) -> Iterator[str]: - c = self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") - ) - for (name,) in c.fetchall(): - if name != "sqlite_sequence": - yield Services.get_tag(name) + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + for (name,) in cursor.fetchall(): + if name != "sqlite_sequence": + yield Services.get_tag(name) + finally: + cursor.close() + self.con_pool.put(conn) def has_table(self, name: str) -> bool: """Check if the Vault has a Table with the specified name.""" - return self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( - "SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?", - [name] - ) - ).fetchone()[0] == 1 + conn = self.con_pool.get() + cursor = conn.cursor() - def create_table(self, name: str, commit: bool = False): + try: + cursor.execute( + "SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?", + (name,) + ) + return cursor.fetchone()[0] == 1 + finally: + cursor.close() + self.con_pool.put(conn) + + def create_table(self, name: str): """Create a Table with the specified name if not yet created.""" if self.has_table(name): return - self.adb.safe_execute( - self.ticket, - lambda db, cursor: cursor.execute( + conn = self.con_pool.get() + cursor = conn.cursor() + + try: + cursor.execute( # TODO: SQL injection risk f""" CREATE TABLE IF NOT EXISTS {name} ( @@ -163,11 +182,33 @@ class SQLite(Vault): ); """ ) - ) + finally: + conn.commit() + cursor.close() + self.con_pool.put(conn) - if commit: - self.commit() - def commit(self): - """Commit any changes made that has not been written to db.""" - self.adb.commit(self.ticket) +class ConnectionPool: + def __init__(self, path: Union[str, Path], size: int): + self._path = path + self._size = size + self._pool = Queue(self._size) + self._lock = Lock() + + def _create_connection(self): + return sqlite3.connect(self._path) + + def get(self) -> Connection: + while True: + try: + return self._pool.get(block=False) + except Empty: + with self._lock: + if self._pool.qsize() < self._size: + return self._create_connection() + else: + # pool full, wait before retrying + time.sleep(0.1) + + def put(self, conn: Connection): + self._pool.put(conn)