diff --git a/devine/vaults/MySQL.py b/devine/vaults/MySQL.py index c082e9e..c57ebe5 100644 --- a/devine/vaults/MySQL.py +++ b/devine/vaults/MySQL.py @@ -1,8 +1,6 @@ from __future__ import annotations -import time -from queue import Empty, Queue -from threading import Lock +import threading from typing import Iterator, Optional, Union from uuid import UUID @@ -23,13 +21,13 @@ class MySQL(Vault): """ super().__init__(name) self.slug = f"{host}:{database}:{username}" - self.con_pool = ConnectionPool(dict( + self.conn_factory = ConnectionFactory(dict( host=host, db=database, user=username, cursorclass=DictCursor, **kwargs - ), 5) + )) self.permissions = self.get_permissions() if not self.has_permission("SELECT"): @@ -43,7 +41,7 @@ class MySQL(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -58,14 +56,13 @@ class MySQL(Vault): return cek["key_"] finally: cursor.close() - self.con_pool.put(conn) def get_keys(self, service: str) -> Iterator[tuple[str, str]]: if not self.has_table(service): # no table, no keys, simple return None - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -78,7 +75,6 @@ class MySQL(Vault): yield row["kid"], row["key_"] finally: cursor.close() - self.con_pool.put(conn) def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: if not key or key.count("0") == len(key): @@ -96,7 +92,7 @@ class MySQL(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -116,7 +112,6 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) return True @@ -145,7 +140,7 @@ class MySQL(Vault): for kid, key_ in kid_keys.items() } - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -158,10 +153,9 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def get_services(self) -> Iterator[str]: - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -171,11 +165,10 @@ class MySQL(Vault): yield Services.get_tag(list(table.values())[0]) finally: cursor.close() - self.con_pool.put(conn) def has_table(self, name: str) -> bool: """Check if the Vault has a Table with the specified name.""" - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -186,7 +179,6 @@ class MySQL(Vault): return list(cursor.fetchone().values())[0] == 1 finally: cursor.close() - self.con_pool.put(conn) def create_table(self, name: str): """Create a Table with the specified name if not yet created.""" @@ -196,7 +188,7 @@ class MySQL(Vault): if not self.has_permission("CREATE"): raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.") - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -214,11 +206,10 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def get_permissions(self) -> list: """Get and parse Grants to a more easily usable list tuple array.""" - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -234,7 +225,6 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool: """Check if the current connection has a specific permission.""" @@ -246,27 +236,15 @@ class MySQL(Vault): return bool(grants) -class ConnectionPool: - def __init__(self, con: dict, size: int): +class ConnectionFactory: + def __init__(self, con: dict): self._con = con - self._size = size - self._pool = Queue(self._size) - self._lock = Lock() + self._store = threading.local() - def _create_connection(self): + def _create_connection(self) -> pymysql.Connection: return pymysql.connect(**self._con) def get(self) -> pymysql.Connection: - while True: - try: - return self._pool.get(block=False) - except Empty: - with self._lock: - if self._pool.qsize() < self._size: - return self._create_connection() - else: - # pool full, wait before retrying - time.sleep(0.1) - - def put(self, conn: pymysql.Connection): - self._pool.put(conn) + if not hasattr(self._store, "conn"): + self._store.conn = self._create_connection() + return self._store.conn diff --git a/devine/vaults/SQLite.py b/devine/vaults/SQLite.py index a864b13..20fbf8d 100644 --- a/devine/vaults/SQLite.py +++ b/devine/vaults/SQLite.py @@ -1,11 +1,9 @@ from __future__ import annotations import sqlite3 -import time +import threading from pathlib import Path -from queue import Empty, Queue from sqlite3 import Connection -from threading import Lock from typing import Iterator, Optional, Union from uuid import UUID @@ -20,7 +18,7 @@ class SQLite(Vault): super().__init__(name) self.path = Path(path).expanduser() # TODO: Use a DictCursor or such to get fetches as dict? - self.con_pool = ConnectionPool(self.path, 5) + self.conn_factory = ConnectionFactory(self.path) def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]: if not self.has_table(service): @@ -30,7 +28,7 @@ class SQLite(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -44,14 +42,13 @@ class SQLite(Vault): return cek[1] finally: cursor.close() - self.con_pool.put(conn) def get_keys(self, service: str) -> Iterator[tuple[str, str]]: if not self.has_table(service): # no table, no keys, simple return None - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -63,7 +60,6 @@ class SQLite(Vault): yield kid, key_ finally: cursor.close() - self.con_pool.put(conn) def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: if not key or key.count("0") == len(key): @@ -75,7 +71,7 @@ class SQLite(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -95,7 +91,6 @@ class SQLite(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) return True @@ -118,7 +113,7 @@ class SQLite(Vault): for kid, key_ in kid_keys.items() } - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -131,10 +126,9 @@ class SQLite(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def get_services(self) -> Iterator[str]: - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -144,11 +138,10 @@ class SQLite(Vault): yield Services.get_tag(name) finally: cursor.close() - self.con_pool.put(conn) def has_table(self, name: str) -> bool: """Check if the Vault has a Table with the specified name.""" - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -159,14 +152,13 @@ class SQLite(Vault): return cursor.fetchone()[0] == 1 finally: cursor.close() - self.con_pool.put(conn) def create_table(self, name: str): """Create a Table with the specified name if not yet created.""" if self.has_table(name): return - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -185,30 +177,17 @@ class SQLite(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) -class ConnectionPool: - def __init__(self, path: Union[str, Path], size: int): +class ConnectionFactory: + def __init__(self, path: Union[str, Path]): self._path = path - self._size = size - self._pool = Queue(self._size) - self._lock = Lock() + self._store = threading.local() - def _create_connection(self): + def _create_connection(self) -> Connection: return sqlite3.connect(self._path) def get(self) -> Connection: - while True: - try: - return self._pool.get(block=False) - except Empty: - with self._lock: - if self._pool.qsize() < self._size: - return self._create_connection() - else: - # pool full, wait before retrying - time.sleep(0.1) - - def put(self, conn: Connection): - self._pool.put(conn) + if not hasattr(self._store, "conn"): + self._store.conn = self._create_connection() + return self._store.conn