From 1443dfaeccf28244ff6bf64c6864301a26252aa8 Mon Sep 17 00:00:00 2001 From: rlaphoenix Date: Wed, 22 Feb 2023 01:31:03 +0000 Subject: [PATCH] Replace Threading Pool system with Thread Storage in Vaults This fixes the usage of vaults across different threads. It now makes a truly unique connection for each thread. The previous code did this as well, but put back the connection from x thread, to re-use in y thread. Now it simply creates and reuses the connection on their own thread. Once the thread is closed, the data is now also garbage collected. This now reduces the risk of filling up memory over time. --- devine/vaults/MySQL.py | 58 +++++++++++++---------------------------- devine/vaults/SQLite.py | 53 ++++++++++++------------------------- 2 files changed, 34 insertions(+), 77 deletions(-) diff --git a/devine/vaults/MySQL.py b/devine/vaults/MySQL.py index c082e9e..c57ebe5 100644 --- a/devine/vaults/MySQL.py +++ b/devine/vaults/MySQL.py @@ -1,8 +1,6 @@ from __future__ import annotations -import time -from queue import Empty, Queue -from threading import Lock +import threading from typing import Iterator, Optional, Union from uuid import UUID @@ -23,13 +21,13 @@ class MySQL(Vault): """ super().__init__(name) self.slug = f"{host}:{database}:{username}" - self.con_pool = ConnectionPool(dict( + self.conn_factory = ConnectionFactory(dict( host=host, db=database, user=username, cursorclass=DictCursor, **kwargs - ), 5) + )) self.permissions = self.get_permissions() if not self.has_permission("SELECT"): @@ -43,7 +41,7 @@ class MySQL(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -58,14 +56,13 @@ class MySQL(Vault): return cek["key_"] finally: cursor.close() - self.con_pool.put(conn) def get_keys(self, service: str) -> Iterator[tuple[str, str]]: if not self.has_table(service): # no table, no keys, simple return None - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -78,7 +75,6 @@ class MySQL(Vault): yield row["kid"], row["key_"] finally: cursor.close() - self.con_pool.put(conn) def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: if not key or key.count("0") == len(key): @@ -96,7 +92,7 @@ class MySQL(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -116,7 +112,6 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) return True @@ -145,7 +140,7 @@ class MySQL(Vault): for kid, key_ in kid_keys.items() } - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -158,10 +153,9 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def get_services(self) -> Iterator[str]: - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -171,11 +165,10 @@ class MySQL(Vault): yield Services.get_tag(list(table.values())[0]) finally: cursor.close() - self.con_pool.put(conn) def has_table(self, name: str) -> bool: """Check if the Vault has a Table with the specified name.""" - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -186,7 +179,6 @@ class MySQL(Vault): return list(cursor.fetchone().values())[0] == 1 finally: cursor.close() - self.con_pool.put(conn) def create_table(self, name: str): """Create a Table with the specified name if not yet created.""" @@ -196,7 +188,7 @@ class MySQL(Vault): if not self.has_permission("CREATE"): raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.") - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -214,11 +206,10 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def get_permissions(self) -> list: """Get and parse Grants to a more easily usable list tuple array.""" - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -234,7 +225,6 @@ class MySQL(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool: """Check if the current connection has a specific permission.""" @@ -246,27 +236,15 @@ class MySQL(Vault): return bool(grants) -class ConnectionPool: - def __init__(self, con: dict, size: int): +class ConnectionFactory: + def __init__(self, con: dict): self._con = con - self._size = size - self._pool = Queue(self._size) - self._lock = Lock() + self._store = threading.local() - def _create_connection(self): + def _create_connection(self) -> pymysql.Connection: return pymysql.connect(**self._con) def get(self) -> pymysql.Connection: - while True: - try: - return self._pool.get(block=False) - except Empty: - with self._lock: - if self._pool.qsize() < self._size: - return self._create_connection() - else: - # pool full, wait before retrying - time.sleep(0.1) - - def put(self, conn: pymysql.Connection): - self._pool.put(conn) + if not hasattr(self._store, "conn"): + self._store.conn = self._create_connection() + return self._store.conn diff --git a/devine/vaults/SQLite.py b/devine/vaults/SQLite.py index a864b13..20fbf8d 100644 --- a/devine/vaults/SQLite.py +++ b/devine/vaults/SQLite.py @@ -1,11 +1,9 @@ from __future__ import annotations import sqlite3 -import time +import threading from pathlib import Path -from queue import Empty, Queue from sqlite3 import Connection -from threading import Lock from typing import Iterator, Optional, Union from uuid import UUID @@ -20,7 +18,7 @@ class SQLite(Vault): super().__init__(name) self.path = Path(path).expanduser() # TODO: Use a DictCursor or such to get fetches as dict? - self.con_pool = ConnectionPool(self.path, 5) + self.conn_factory = ConnectionFactory(self.path) def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]: if not self.has_table(service): @@ -30,7 +28,7 @@ class SQLite(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -44,14 +42,13 @@ class SQLite(Vault): return cek[1] finally: cursor.close() - self.con_pool.put(conn) def get_keys(self, service: str) -> Iterator[tuple[str, str]]: if not self.has_table(service): # no table, no keys, simple return None - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -63,7 +60,6 @@ class SQLite(Vault): yield kid, key_ finally: cursor.close() - self.con_pool.put(conn) def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: if not key or key.count("0") == len(key): @@ -75,7 +71,7 @@ class SQLite(Vault): if isinstance(kid, UUID): kid = kid.hex - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -95,7 +91,6 @@ class SQLite(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) return True @@ -118,7 +113,7 @@ class SQLite(Vault): for kid, key_ in kid_keys.items() } - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -131,10 +126,9 @@ class SQLite(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) def get_services(self) -> Iterator[str]: - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -144,11 +138,10 @@ class SQLite(Vault): yield Services.get_tag(name) finally: cursor.close() - self.con_pool.put(conn) def has_table(self, name: str) -> bool: """Check if the Vault has a Table with the specified name.""" - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -159,14 +152,13 @@ class SQLite(Vault): return cursor.fetchone()[0] == 1 finally: cursor.close() - self.con_pool.put(conn) def create_table(self, name: str): """Create a Table with the specified name if not yet created.""" if self.has_table(name): return - conn = self.con_pool.get() + conn = self.conn_factory.get() cursor = conn.cursor() try: @@ -185,30 +177,17 @@ class SQLite(Vault): finally: conn.commit() cursor.close() - self.con_pool.put(conn) -class ConnectionPool: - def __init__(self, path: Union[str, Path], size: int): +class ConnectionFactory: + def __init__(self, path: Union[str, Path]): self._path = path - self._size = size - self._pool = Queue(self._size) - self._lock = Lock() + self._store = threading.local() - def _create_connection(self): + def _create_connection(self) -> Connection: return sqlite3.connect(self._path) def get(self) -> Connection: - while True: - try: - return self._pool.get(block=False) - except Empty: - with self._lock: - if self._pool.qsize() < self._size: - return self._create_connection() - else: - # pool full, wait before retrying - time.sleep(0.1) - - def put(self, conn: Connection): - self._pool.put(conn) + if not hasattr(self._store, "conn"): + self._store.conn = self._create_connection() + return self._store.conn