Replace Threading Pool system with Thread Storage in Vaults

This fixes the usage of vaults across different threads. It now makes a truly unique connection for each thread. The previous code did this as well, but put back the connection from x thread, to re-use in y thread. Now it simply creates and reuses the connection on their own thread.

Once the thread is closed, the data is now also garbage collected. This now reduces the risk of filling up memory over time.
This commit is contained in:
rlaphoenix 2023-02-22 01:31:03 +00:00
parent 4e875f5ffc
commit 1443dfaecc
2 changed files with 34 additions and 77 deletions

View File

@ -1,8 +1,6 @@
from __future__ import annotations
import time
from queue import Empty, Queue
from threading import Lock
import threading
from typing import Iterator, Optional, Union
from uuid import UUID
@ -23,13 +21,13 @@ class MySQL(Vault):
"""
super().__init__(name)
self.slug = f"{host}:{database}:{username}"
self.con_pool = ConnectionPool(dict(
self.conn_factory = ConnectionFactory(dict(
host=host,
db=database,
user=username,
cursorclass=DictCursor,
**kwargs
), 5)
))
self.permissions = self.get_permissions()
if not self.has_permission("SELECT"):
@ -43,7 +41,7 @@ class MySQL(Vault):
if isinstance(kid, UUID):
kid = kid.hex
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -58,14 +56,13 @@ class MySQL(Vault):
return cek["key_"]
finally:
cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service):
# no table, no keys, simple
return None
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -78,7 +75,6 @@ class MySQL(Vault):
yield row["kid"], row["key_"]
finally:
cursor.close()
self.con_pool.put(conn)
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key):
@ -96,7 +92,7 @@ class MySQL(Vault):
if isinstance(kid, UUID):
kid = kid.hex
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -116,7 +112,6 @@ class MySQL(Vault):
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
return True
@ -145,7 +140,7 @@ class MySQL(Vault):
for kid, key_ in kid_keys.items()
}
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -158,10 +153,9 @@ class MySQL(Vault):
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def get_services(self) -> Iterator[str]:
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -171,11 +165,10 @@ class MySQL(Vault):
yield Services.get_tag(list(table.values())[0])
finally:
cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name."""
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -186,7 +179,6 @@ class MySQL(Vault):
return list(cursor.fetchone().values())[0] == 1
finally:
cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str):
"""Create a Table with the specified name if not yet created."""
@ -196,7 +188,7 @@ class MySQL(Vault):
if not self.has_permission("CREATE"):
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -214,11 +206,10 @@ class MySQL(Vault):
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def get_permissions(self) -> list:
"""Get and parse Grants to a more easily usable list tuple array."""
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -234,7 +225,6 @@ class MySQL(Vault):
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
"""Check if the current connection has a specific permission."""
@ -246,27 +236,15 @@ class MySQL(Vault):
return bool(grants)
class ConnectionPool:
def __init__(self, con: dict, size: int):
class ConnectionFactory:
def __init__(self, con: dict):
self._con = con
self._size = size
self._pool = Queue(self._size)
self._lock = Lock()
self._store = threading.local()
def _create_connection(self):
def _create_connection(self) -> pymysql.Connection:
return pymysql.connect(**self._con)
def get(self) -> pymysql.Connection:
while True:
try:
return self._pool.get(block=False)
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: pymysql.Connection):
self._pool.put(conn)
if not hasattr(self._store, "conn"):
self._store.conn = self._create_connection()
return self._store.conn

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import sqlite3
import time
import threading
from pathlib import Path
from queue import Empty, Queue
from sqlite3 import Connection
from threading import Lock
from typing import Iterator, Optional, Union
from uuid import UUID
@ -20,7 +18,7 @@ class SQLite(Vault):
super().__init__(name)
self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict?
self.con_pool = ConnectionPool(self.path, 5)
self.conn_factory = ConnectionFactory(self.path)
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
if not self.has_table(service):
@ -30,7 +28,7 @@ class SQLite(Vault):
if isinstance(kid, UUID):
kid = kid.hex
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -44,14 +42,13 @@ class SQLite(Vault):
return cek[1]
finally:
cursor.close()
self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service):
# no table, no keys, simple
return None
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -63,7 +60,6 @@ class SQLite(Vault):
yield kid, key_
finally:
cursor.close()
self.con_pool.put(conn)
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key):
@ -75,7 +71,7 @@ class SQLite(Vault):
if isinstance(kid, UUID):
kid = kid.hex
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -95,7 +91,6 @@ class SQLite(Vault):
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
return True
@ -118,7 +113,7 @@ class SQLite(Vault):
for kid, key_ in kid_keys.items()
}
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -131,10 +126,9 @@ class SQLite(Vault):
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
def get_services(self) -> Iterator[str]:
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -144,11 +138,10 @@ class SQLite(Vault):
yield Services.get_tag(name)
finally:
cursor.close()
self.con_pool.put(conn)
def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name."""
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -159,14 +152,13 @@ class SQLite(Vault):
return cursor.fetchone()[0] == 1
finally:
cursor.close()
self.con_pool.put(conn)
def create_table(self, name: str):
"""Create a Table with the specified name if not yet created."""
if self.has_table(name):
return
conn = self.con_pool.get()
conn = self.conn_factory.get()
cursor = conn.cursor()
try:
@ -185,30 +177,17 @@ class SQLite(Vault):
finally:
conn.commit()
cursor.close()
self.con_pool.put(conn)
class ConnectionPool:
def __init__(self, path: Union[str, Path], size: int):
class ConnectionFactory:
def __init__(self, path: Union[str, Path]):
self._path = path
self._size = size
self._pool = Queue(self._size)
self._lock = Lock()
self._store = threading.local()
def _create_connection(self):
def _create_connection(self) -> Connection:
return sqlite3.connect(self._path)
def get(self) -> Connection:
while True:
try:
return self._pool.get(block=False)
except Empty:
with self._lock:
if self._pool.qsize() < self._size:
return self._create_connection()
else:
# pool full, wait before retrying
time.sleep(0.1)
def put(self, conn: Connection):
self._pool.put(conn)
if not hasattr(self._store, "conn"):
self._store.conn = self._create_connection()
return self._store.conn