Replace Threading Pool system with Thread Storage in Vaults

This fixes the usage of vaults across different threads. It now makes a truly unique connection for each thread. The previous code did this as well, but put back the connection from x thread, to re-use in y thread. Now it simply creates and reuses the connection on their own thread.

Once the thread is closed, the data is now also garbage collected. This now reduces the risk of filling up memory over time.
This commit is contained in:
rlaphoenix 2023-02-22 01:31:03 +00:00
parent 4e875f5ffc
commit 1443dfaecc
2 changed files with 34 additions and 77 deletions

View File

@ -1,8 +1,6 @@
from __future__ import annotations from __future__ import annotations
import time import threading
from queue import Empty, Queue
from threading import Lock
from typing import Iterator, Optional, Union from typing import Iterator, Optional, Union
from uuid import UUID from uuid import UUID
@ -23,13 +21,13 @@ class MySQL(Vault):
""" """
super().__init__(name) super().__init__(name)
self.slug = f"{host}:{database}:{username}" self.slug = f"{host}:{database}:{username}"
self.con_pool = ConnectionPool(dict( self.conn_factory = ConnectionFactory(dict(
host=host, host=host,
db=database, db=database,
user=username, user=username,
cursorclass=DictCursor, cursorclass=DictCursor,
**kwargs **kwargs
), 5) ))
self.permissions = self.get_permissions() self.permissions = self.get_permissions()
if not self.has_permission("SELECT"): if not self.has_permission("SELECT"):
@ -43,7 +41,7 @@ class MySQL(Vault):
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -58,14 +56,13 @@ class MySQL(Vault):
return cek["key_"] return cek["key_"]
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]: def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service): if not self.has_table(service):
# no table, no keys, simple # no table, no keys, simple
return None return None
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -78,7 +75,6 @@ class MySQL(Vault):
yield row["kid"], row["key_"] yield row["kid"], row["key_"]
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key): if not key or key.count("0") == len(key):
@ -96,7 +92,7 @@ class MySQL(Vault):
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -116,7 +112,6 @@ class MySQL(Vault):
finally: finally:
conn.commit() conn.commit()
cursor.close() cursor.close()
self.con_pool.put(conn)
return True return True
@ -145,7 +140,7 @@ class MySQL(Vault):
for kid, key_ in kid_keys.items() for kid, key_ in kid_keys.items()
} }
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -158,10 +153,9 @@ class MySQL(Vault):
finally: finally:
conn.commit() conn.commit()
cursor.close() cursor.close()
self.con_pool.put(conn)
def get_services(self) -> Iterator[str]: def get_services(self) -> Iterator[str]:
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -171,11 +165,10 @@ class MySQL(Vault):
yield Services.get_tag(list(table.values())[0]) yield Services.get_tag(list(table.values())[0])
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool: def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name.""" """Check if the Vault has a Table with the specified name."""
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -186,7 +179,6 @@ class MySQL(Vault):
return list(cursor.fetchone().values())[0] == 1 return list(cursor.fetchone().values())[0] == 1
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str): def create_table(self, name: str):
"""Create a Table with the specified name if not yet created.""" """Create a Table with the specified name if not yet created."""
@ -196,7 +188,7 @@ class MySQL(Vault):
if not self.has_permission("CREATE"): if not self.has_permission("CREATE"):
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.") raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -214,11 +206,10 @@ class MySQL(Vault):
finally: finally:
conn.commit() conn.commit()
cursor.close() cursor.close()
self.con_pool.put(conn)
def get_permissions(self) -> list: def get_permissions(self) -> list:
"""Get and parse Grants to a more easily usable list tuple array.""" """Get and parse Grants to a more easily usable list tuple array."""
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -234,7 +225,6 @@ class MySQL(Vault):
finally: finally:
conn.commit() conn.commit()
cursor.close() cursor.close()
self.con_pool.put(conn)
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool: def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
"""Check if the current connection has a specific permission.""" """Check if the current connection has a specific permission."""
@ -246,27 +236,15 @@ class MySQL(Vault):
return bool(grants) return bool(grants)
class ConnectionPool: class ConnectionFactory:
def __init__(self, con: dict, size: int): def __init__(self, con: dict):
self._con = con self._con = con
self._size = size self._store = threading.local()
self._pool = Queue(self._size)
self._lock = Lock()
def _create_connection(self): def _create_connection(self) -> pymysql.Connection:
return pymysql.connect(**self._con) return pymysql.connect(**self._con)
def get(self) -> pymysql.Connection: def get(self) -> pymysql.Connection:
while True: if not hasattr(self._store, "conn"):
try: self._store.conn = self._create_connection()
return self._pool.get(block=False) return self._store.conn
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: pymysql.Connection):
self._pool.put(conn)

View File

@ -1,11 +1,9 @@
from __future__ import annotations from __future__ import annotations
import sqlite3 import sqlite3
import time import threading
from pathlib import Path from pathlib import Path
from queue import Empty, Queue
from sqlite3 import Connection from sqlite3 import Connection
from threading import Lock
from typing import Iterator, Optional, Union from typing import Iterator, Optional, Union
from uuid import UUID from uuid import UUID
@ -20,7 +18,7 @@ class SQLite(Vault):
super().__init__(name) super().__init__(name)
self.path = Path(path).expanduser() self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict? # TODO: Use a DictCursor or such to get fetches as dict?
self.con_pool = ConnectionPool(self.path, 5) self.conn_factory = ConnectionFactory(self.path)
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]: def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
if not self.has_table(service): if not self.has_table(service):
@ -30,7 +28,7 @@ class SQLite(Vault):
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -44,14 +42,13 @@ class SQLite(Vault):
return cek[1] return cek[1]
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]: def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service): if not self.has_table(service):
# no table, no keys, simple # no table, no keys, simple
return None return None
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -63,7 +60,6 @@ class SQLite(Vault):
yield kid, key_ yield kid, key_
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool: def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key): if not key or key.count("0") == len(key):
@ -75,7 +71,7 @@ class SQLite(Vault):
if isinstance(kid, UUID): if isinstance(kid, UUID):
kid = kid.hex kid = kid.hex
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -95,7 +91,6 @@ class SQLite(Vault):
finally: finally:
conn.commit() conn.commit()
cursor.close() cursor.close()
self.con_pool.put(conn)
return True return True
@ -118,7 +113,7 @@ class SQLite(Vault):
for kid, key_ in kid_keys.items() for kid, key_ in kid_keys.items()
} }
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -131,10 +126,9 @@ class SQLite(Vault):
finally: finally:
conn.commit() conn.commit()
cursor.close() cursor.close()
self.con_pool.put(conn)
def get_services(self) -> Iterator[str]: def get_services(self) -> Iterator[str]:
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -144,11 +138,10 @@ class SQLite(Vault):
yield Services.get_tag(name) yield Services.get_tag(name)
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool: def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name.""" """Check if the Vault has a Table with the specified name."""
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -159,14 +152,13 @@ class SQLite(Vault):
return cursor.fetchone()[0] == 1 return cursor.fetchone()[0] == 1
finally: finally:
cursor.close() cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str): def create_table(self, name: str):
"""Create a Table with the specified name if not yet created.""" """Create a Table with the specified name if not yet created."""
if self.has_table(name): if self.has_table(name):
return return
conn = self.con_pool.get() conn = self.conn_factory.get()
cursor = conn.cursor() cursor = conn.cursor()
try: try:
@ -185,30 +177,17 @@ class SQLite(Vault):
finally: finally:
conn.commit() conn.commit()
cursor.close() cursor.close()
self.con_pool.put(conn)
class ConnectionPool: class ConnectionFactory:
def __init__(self, path: Union[str, Path], size: int): def __init__(self, path: Union[str, Path]):
self._path = path self._path = path
self._size = size self._store = threading.local()
self._pool = Queue(self._size)
self._lock = Lock()
def _create_connection(self): def _create_connection(self) -> Connection:
return sqlite3.connect(self._path) return sqlite3.connect(self._path)
def get(self) -> Connection: def get(self) -> Connection:
while True: if not hasattr(self._store, "conn"):
try: self._store.conn = self._create_connection()
return self._pool.get(block=False) return self._store.conn
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: Connection):
self._pool.put(conn)