diff --git a/devine/commands/kv.py b/devine/commands/kv.py
index 61e3f34..f237a37 100644
--- a/devine/commands/kv.py
+++ b/devine/commands/kv.py
@@ -89,7 +89,7 @@ def copy(to_vault: str, from_vaults: list[str], service: Optional[str] = None) -
log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}")
try:
- added = to_vault.add_keys(service_, content_keys, commit=True)
+ added = to_vault.add_keys(service_, content_keys)
except PermissionError:
log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...")
continue
@@ -171,7 +171,7 @@ def add(file: Path, service: str, vaults: list[str]) -> None:
for vault in vaults_:
log.info(f"Adding {total_count} Content Keys to {vault}")
- added_count = vault.add_keys(service, kid_keys, commit=True)
+ added_count = vault.add_keys(service, kid_keys)
existed_count = total_count - added_count
log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)")
diff --git a/devine/core/utils/atomicsql.py b/devine/core/utils/atomicsql.py
deleted file mode 100644
index dcee82d..0000000
--- a/devine/core/utils/atomicsql.py
+++ /dev/null
@@ -1,105 +0,0 @@
-"""
-AtomicSQL - Race-condition and Threading safe SQL Database Interface.
-Copyright (C) 2020-2023 rlaphoenix
-
-This program is free software: you can redistribute it and/or modify
-it under the terms of the GNU General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
-
-This program is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-GNU General Public License for more details.
-
-You should have received a copy of the GNU General Public License
-along with this program. If not, see .
-"""
-
-import os
-import sqlite3
-import time
-from threading import Lock
-from typing import Any, Callable, Union
-
-import pymysql.cursors
-
-Connections = Union[sqlite3.Connection, pymysql.connections.Connection]
-Cursors = Union[sqlite3.Cursor, pymysql.cursors.Cursor]
-
-
-class AtomicSQL:
- """
- Race-condition and Threading safe SQL Database Interface.
- """
-
- def __init__(self) -> None:
- self.master_lock = Lock() # prevents race condition
- self.db: dict[bytes, Connections] = {} # used to hold the database connections and commit changes and such
- self.cursor: dict[bytes, Cursors] = {} # used to execute queries and receive results
- self.session_lock: dict[bytes, Lock] = {} # like master_lock, but per-session
-
- def load(self, connection: Connections) -> bytes:
- """
- Store SQL Connection object and return a reference ticket.
- :param connection: SQLite3 or pymysql Connection object.
- :returns: Session ID in which the database connection is referenced with.
- """
- self.master_lock.acquire()
- try:
- # obtain a unique cryptographically random session_id
- session_id = None
- while not session_id or session_id in self.db:
- session_id = os.urandom(16)
- self.db[session_id] = connection
- self.cursor[session_id] = self.db[session_id].cursor()
- self.session_lock[session_id] = Lock()
- return session_id
- finally:
- self.master_lock.release()
-
- def safe_execute(self, session_id: bytes, action: Callable) -> Any:
- """
- Execute code on the Database Connection in a race-condition safe way.
- :param session_id: Database Connection's Session ID.
- :param action: Function or lambda in which to execute, it's provided `db` and `cursor` arguments.
- :returns: Whatever `action` returns.
- """
- if session_id not in self.db:
- raise ValueError(f"Session ID {session_id!r} is invalid.")
- self.master_lock.acquire()
- self.session_lock[session_id].acquire()
- try:
- failures = 0
- while True:
- try:
- action(
- db=self.db[session_id],
- cursor=self.cursor[session_id]
- )
- break
- except sqlite3.OperationalError as e:
- failures += 1
- delay = 3 * failures
- print(f"AtomicSQL.safe_execute failed, {e}, retrying in {delay} seconds...")
- time.sleep(delay)
- if failures == 10:
- raise ValueError("AtomicSQL.safe_execute failed too many time's. Aborting.")
- return self.cursor[session_id]
- finally:
- self.session_lock[session_id].release()
- self.master_lock.release()
-
- def commit(self, session_id: bytes) -> bool:
- """
- Commit changes to the Database Connection immediately.
- This isn't necessary to be run every time you make changes, just ensure it's run
- at least before termination.
- :param session_id: Database Connection's Session ID.
- :returns: True if it committed.
- """
- self.safe_execute(
- session_id,
- lambda db, cursor: db.commit()
- )
- return True # todo ; actually check if db.commit worked
diff --git a/devine/core/vault.py b/devine/core/vault.py
index 01a7d71..bcefa07 100644
--- a/devine/core/vault.py
+++ b/devine/core/vault.py
@@ -31,11 +31,11 @@ class Vault(metaclass=ABCMeta):
"""Get All Keys from Vault by Service."""
@abstractmethod
- def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
+ def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
"""Add KID:KEY to the Vault."""
@abstractmethod
- def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
+ def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
"""
Add Multiple Content Keys with Key IDs for Service to the Vault.
Pre-existing Content Keys are ignored/skipped.
diff --git a/devine/core/vaults.py b/devine/core/vaults.py
index 49ea542..c94f73d 100644
--- a/devine/core/vaults.py
+++ b/devine/core/vaults.py
@@ -57,7 +57,7 @@ class Vaults:
for vault in self.vaults:
if vault != excluding:
try:
- success += vault.add_key(self.service, kid, key, commit=True)
+ success += vault.add_key(self.service, kid, key)
except (PermissionError, NotImplementedError):
pass
return success
@@ -70,7 +70,7 @@ class Vaults:
success = 0
for vault in self.vaults:
try:
- success += bool(vault.add_keys(self.service, kid_keys, commit=True))
+ success += bool(vault.add_keys(self.service, kid_keys))
except (PermissionError, NotImplementedError):
pass
return success
diff --git a/devine/vaults/MySQL.py b/devine/vaults/MySQL.py
index 221d3f6..c082e9e 100644
--- a/devine/vaults/MySQL.py
+++ b/devine/vaults/MySQL.py
@@ -1,5 +1,8 @@
from __future__ import annotations
+import time
+from queue import Empty, Queue
+from threading import Lock
from typing import Iterator, Optional, Union
from uuid import UUID
@@ -7,7 +10,6 @@ import pymysql
from pymysql.cursors import DictCursor
from devine.core.services import Services
-from devine.core.utils.atomicsql import AtomicSQL
from devine.core.vault import Vault
@@ -21,15 +23,13 @@ class MySQL(Vault):
"""
super().__init__(name)
self.slug = f"{host}:{database}:{username}"
- self.con = pymysql.connect(
+ self.con_pool = ConnectionPool(dict(
host=host,
db=database,
user=username,
cursorclass=DictCursor,
**kwargs
- )
- self.adb = AtomicSQL()
- self.ticket = self.adb.load(self.con)
+ ), 5)
self.permissions = self.get_permissions()
if not self.has_permission("SELECT"):
@@ -43,37 +43,44 @@ class MySQL(Vault):
if isinstance(kid, UUID):
kid = kid.hex
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
# TODO: SQL injection risk
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
- [kid, "0" * 32]
+ (kid, "0" * 32)
)
- ).fetchone()
- if not c:
- return None
-
- return c["key_"]
+ cek = cursor.fetchone()
+ if not cek:
+ return None
+ return cek["key_"]
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service):
# no table, no keys, simple
return None
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
# TODO: SQL injection risk
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
- ["0" * 32]
+ ("0" * 32,)
)
- )
+ for row in cursor.fetchall():
+ yield row["kid"], row["key_"]
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
- for row in c.fetchall():
- yield row["kid"], row["key_"]
-
- def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
+ def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
@@ -82,39 +89,38 @@ class MySQL(Vault):
if not self.has_table(service):
try:
- self.create_table(service, commit)
+ self.create_table(service)
except PermissionError:
return False
if isinstance(kid, UUID):
kid = kid.hex
- if self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
# TODO: SQL injection risk
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
- [kid, key]
+ (kid, key)
)
- ).fetchone():
- # table already has this exact KID:KEY stored
- return True
-
- self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ if cursor.fetchone():
+ # table already has this exact KID:KEY stored
+ return True
+ cursor.execute(
# TODO: SQL injection risk
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
(kid, key)
)
- )
-
- if commit:
- self.commit()
+ finally:
+ conn.commit()
+ cursor.close()
+ self.con_pool.put(conn)
return True
- def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
+ def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
for kid, key in kid_keys.items():
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
@@ -124,7 +130,7 @@ class MySQL(Vault):
if not self.has_table(service):
try:
- self.create_table(service, commit)
+ self.create_table(service)
except PermissionError:
return 0
@@ -139,40 +145,50 @@ class MySQL(Vault):
for kid, key_ in kid_keys.items()
}
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.executemany(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.executemany(
# TODO: SQL injection risk
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
kid_keys.items()
)
- )
-
- if commit:
- self.commit()
-
- return c.rowcount
+ return cursor.rowcount
+ finally:
+ conn.commit()
+ cursor.close()
+ self.con_pool.put(conn)
def get_services(self) -> Iterator[str]:
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute("SHOW TABLES")
- )
- for table in c.fetchall():
- # each entry has a key named `Tables_in_`
- yield Services.get_tag(list(table.values())[0])
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute("SHOW TABLES")
+ for table in cursor.fetchall():
+ # each entry has a key named `Tables_in_`
+ yield Services.get_tag(list(table.values())[0])
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name."""
- return list(self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
- "SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
- [self.con.db, name]
- )
- ).fetchone().values())[0] == 1
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
- def create_table(self, name: str, commit: bool = False):
+ try:
+ cursor.execute(
+ "SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
+ (conn.db, name)
+ )
+ return list(cursor.fetchone().values())[0] == 1
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
+
+ def create_table(self, name: str):
"""Create a Table with the specified name if not yet created."""
if self.has_table(name):
return
@@ -180,9 +196,11 @@ class MySQL(Vault):
if not self.has_permission("CREATE"):
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
- self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
# TODO: SQL injection risk
f"""
CREATE TABLE IF NOT EXISTS {name} (
@@ -193,23 +211,30 @@ class MySQL(Vault):
);
"""
)
- )
-
- if commit:
- self.commit()
+ finally:
+ conn.commit()
+ cursor.close()
+ self.con_pool.put(conn)
def get_permissions(self) -> list:
"""Get and parse Grants to a more easily usable list tuple array."""
- with self.con.cursor() as c:
- c.execute("SHOW GRANTS")
- grants = c.fetchall()
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute("SHOW GRANTS")
+ grants = cursor.fetchall()
grants = [next(iter(x.values())) for x in grants]
- grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
- grants = [(
- list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
- location.replace("`", "").split(".")
- ) for perms, location in grants]
- return grants
+ grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
+ grants = [(
+ list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
+ location.replace("`", "").split(".")
+ ) for perms, location in grants]
+ return grants
+ finally:
+ conn.commit()
+ cursor.close()
+ self.con_pool.put(conn)
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
"""Check if the current connection has a specific permission."""
@@ -220,6 +245,28 @@ class MySQL(Vault):
grants = [x for x in grants if x[1][1] in (table, "*")]
return bool(grants)
- def commit(self):
- """Commit any changes made that has not been written to db."""
- self.adb.commit(self.ticket)
+
+class ConnectionPool:
+ def __init__(self, con: dict, size: int):
+ self._con = con
+ self._size = size
+ self._pool = Queue(self._size)
+ self._lock = Lock()
+
+ def _create_connection(self):
+ return pymysql.connect(**self._con)
+
+ def get(self) -> pymysql.Connection:
+ while True:
+ try:
+ return self._pool.get(block=False)
+ except Empty:
+ with self._lock:
+ if self._pool.qsize() < self._size:
+ return self._create_connection()
+ else:
+ # pool full, wait before retrying
+ time.sleep(0.1)
+
+ def put(self, conn: pymysql.Connection):
+ self._pool.put(conn)
diff --git a/devine/vaults/SQLite.py b/devine/vaults/SQLite.py
index 02c5307..a864b13 100644
--- a/devine/vaults/SQLite.py
+++ b/devine/vaults/SQLite.py
@@ -1,12 +1,15 @@
from __future__ import annotations
import sqlite3
+import time
from pathlib import Path
+from queue import Empty, Queue
+from sqlite3 import Connection
+from threading import Lock
from typing import Iterator, Optional, Union
from uuid import UUID
from devine.core.services import Services
-from devine.core.utils.atomicsql import AtomicSQL
from devine.core.vault import Vault
@@ -17,9 +20,7 @@ class SQLite(Vault):
super().__init__(name)
self.path = Path(path).expanduser()
# TODO: Use a DictCursor or such to get fetches as dict?
- self.con = sqlite3.connect(self.path)
- self.adb = AtomicSQL()
- self.ticket = self.adb.load(self.con)
+ self.con_pool = ConnectionPool(self.path, 5)
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
if not self.has_table(service):
@@ -29,76 +30,82 @@ class SQLite(Vault):
if isinstance(kid, UUID):
kid = kid.hex
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
- # TODO: SQL injection risk
- f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
- [kid, "0" * 32]
- )
- ).fetchone()
- if not c:
- return None
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
- return c[1] # `key_`
+ try:
+ cursor.execute(
+ f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
+ (kid, "0" * 32)
+ )
+ cek = cursor.fetchone()
+ if not cek:
+ return None
+ return cek[1]
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
if not self.has_table(service):
# no table, no keys, simple
return None
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
- # TODO: SQL injection risk
- f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
- ["0" * 32]
- )
- )
- for (kid, key_) in c.fetchall():
- yield kid, key_
- def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
+ f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
+ ("0" * 32,)
+ )
+ for (kid, key_) in cursor.fetchall():
+ yield kid, key_
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
+
+ def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_table(service):
- self.create_table(service, commit)
+ self.create_table(service)
if isinstance(kid, UUID):
kid = kid.hex
- if self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
# TODO: SQL injection risk
f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?",
- [kid, key]
+ (kid, key)
)
- ).fetchone():
- # table already has this exact KID:KEY stored
- return True
-
- self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ if cursor.fetchone():
+ # table already has this exact KID:KEY stored
+ return True
+ cursor.execute(
# TODO: SQL injection risk
f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)",
(kid, key)
)
- )
-
- if commit:
- self.commit()
+ finally:
+ conn.commit()
+ cursor.close()
+ self.con_pool.put(conn)
return True
- def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
+ def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
for kid, key in kid_keys.items():
if not key or key.count("0") == len(key):
raise ValueError("You cannot add a NULL Content Key to a Vault.")
if not self.has_table(service):
- self.create_table(service, commit)
+ self.create_table(service)
if not isinstance(kid_keys, dict):
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
@@ -111,47 +118,59 @@ class SQLite(Vault):
for kid, key_ in kid_keys.items()
}
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.executemany(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.executemany(
# TODO: SQL injection risk
f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)",
kid_keys.items()
)
- )
-
- if commit:
- self.commit()
-
- return c.rowcount
+ return cursor.rowcount
+ finally:
+ conn.commit()
+ cursor.close()
+ self.con_pool.put(conn)
def get_services(self) -> Iterator[str]:
- c = self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
- )
- for (name,) in c.fetchall():
- if name != "sqlite_sequence":
- yield Services.get_tag(name)
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+ for (name,) in cursor.fetchall():
+ if name != "sqlite_sequence":
+ yield Services.get_tag(name)
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
def has_table(self, name: str) -> bool:
"""Check if the Vault has a Table with the specified name."""
- return self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
- "SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
- [name]
- )
- ).fetchone()[0] == 1
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
- def create_table(self, name: str, commit: bool = False):
+ try:
+ cursor.execute(
+ "SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
+ (name,)
+ )
+ return cursor.fetchone()[0] == 1
+ finally:
+ cursor.close()
+ self.con_pool.put(conn)
+
+ def create_table(self, name: str):
"""Create a Table with the specified name if not yet created."""
if self.has_table(name):
return
- self.adb.safe_execute(
- self.ticket,
- lambda db, cursor: cursor.execute(
+ conn = self.con_pool.get()
+ cursor = conn.cursor()
+
+ try:
+ cursor.execute(
# TODO: SQL injection risk
f"""
CREATE TABLE IF NOT EXISTS {name} (
@@ -163,11 +182,33 @@ class SQLite(Vault):
);
"""
)
- )
+ finally:
+ conn.commit()
+ cursor.close()
+ self.con_pool.put(conn)
- if commit:
- self.commit()
- def commit(self):
- """Commit any changes made that has not been written to db."""
- self.adb.commit(self.ticket)
+class ConnectionPool:
+ def __init__(self, path: Union[str, Path], size: int):
+ self._path = path
+ self._size = size
+ self._pool = Queue(self._size)
+ self._lock = Lock()
+
+ def _create_connection(self):
+ return sqlite3.connect(self._path)
+
+ def get(self) -> Connection:
+ while True:
+ try:
+ return self._pool.get(block=False)
+ except Empty:
+ with self._lock:
+ if self._pool.qsize() < self._size:
+ return self._create_connection()
+ else:
+ # pool full, wait before retrying
+ time.sleep(0.1)
+
+ def put(self, conn: Connection):
+ self._pool.put(conn)