forked from DRMTalks/devine
Remove AtomicSQL, use Connection Pools in Vaults
This allows the use of vaults in any thread, while keeping database-file level thread safety. AtomicSQL doesn't actually do anything that useful.
This commit is contained in:
parent
707469d252
commit
c925cb8af9
|
@ -89,7 +89,7 @@ def copy(to_vault: str, from_vaults: list[str], service: Optional[str] = None) -
|
|||
log.info(f"Adding {total_count} Content Keys to {to_vault} for {service_}")
|
||||
|
||||
try:
|
||||
added = to_vault.add_keys(service_, content_keys, commit=True)
|
||||
added = to_vault.add_keys(service_, content_keys)
|
||||
except PermissionError:
|
||||
log.warning(f" - No permission to create table ({service_}) in {to_vault}, skipping...")
|
||||
continue
|
||||
|
@ -171,7 +171,7 @@ def add(file: Path, service: str, vaults: list[str]) -> None:
|
|||
|
||||
for vault in vaults_:
|
||||
log.info(f"Adding {total_count} Content Keys to {vault}")
|
||||
added_count = vault.add_keys(service, kid_keys, commit=True)
|
||||
added_count = vault.add_keys(service, kid_keys)
|
||||
existed_count = total_count - added_count
|
||||
log.info(f"{vault}: {added_count} newly added, {existed_count} already existed (skipped)")
|
||||
|
||||
|
|
|
@ -1,105 +0,0 @@
|
|||
"""
|
||||
AtomicSQL - Race-condition and Threading safe SQL Database Interface.
|
||||
Copyright (C) 2020-2023 rlaphoenix
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import pymysql.cursors
|
||||
|
||||
Connections = Union[sqlite3.Connection, pymysql.connections.Connection]
|
||||
Cursors = Union[sqlite3.Cursor, pymysql.cursors.Cursor]
|
||||
|
||||
|
||||
class AtomicSQL:
|
||||
"""
|
||||
Race-condition and Threading safe SQL Database Interface.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.master_lock = Lock() # prevents race condition
|
||||
self.db: dict[bytes, Connections] = {} # used to hold the database connections and commit changes and such
|
||||
self.cursor: dict[bytes, Cursors] = {} # used to execute queries and receive results
|
||||
self.session_lock: dict[bytes, Lock] = {} # like master_lock, but per-session
|
||||
|
||||
def load(self, connection: Connections) -> bytes:
|
||||
"""
|
||||
Store SQL Connection object and return a reference ticket.
|
||||
:param connection: SQLite3 or pymysql Connection object.
|
||||
:returns: Session ID in which the database connection is referenced with.
|
||||
"""
|
||||
self.master_lock.acquire()
|
||||
try:
|
||||
# obtain a unique cryptographically random session_id
|
||||
session_id = None
|
||||
while not session_id or session_id in self.db:
|
||||
session_id = os.urandom(16)
|
||||
self.db[session_id] = connection
|
||||
self.cursor[session_id] = self.db[session_id].cursor()
|
||||
self.session_lock[session_id] = Lock()
|
||||
return session_id
|
||||
finally:
|
||||
self.master_lock.release()
|
||||
|
||||
def safe_execute(self, session_id: bytes, action: Callable) -> Any:
|
||||
"""
|
||||
Execute code on the Database Connection in a race-condition safe way.
|
||||
:param session_id: Database Connection's Session ID.
|
||||
:param action: Function or lambda in which to execute, it's provided `db` and `cursor` arguments.
|
||||
:returns: Whatever `action` returns.
|
||||
"""
|
||||
if session_id not in self.db:
|
||||
raise ValueError(f"Session ID {session_id!r} is invalid.")
|
||||
self.master_lock.acquire()
|
||||
self.session_lock[session_id].acquire()
|
||||
try:
|
||||
failures = 0
|
||||
while True:
|
||||
try:
|
||||
action(
|
||||
db=self.db[session_id],
|
||||
cursor=self.cursor[session_id]
|
||||
)
|
||||
break
|
||||
except sqlite3.OperationalError as e:
|
||||
failures += 1
|
||||
delay = 3 * failures
|
||||
print(f"AtomicSQL.safe_execute failed, {e}, retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
if failures == 10:
|
||||
raise ValueError("AtomicSQL.safe_execute failed too many time's. Aborting.")
|
||||
return self.cursor[session_id]
|
||||
finally:
|
||||
self.session_lock[session_id].release()
|
||||
self.master_lock.release()
|
||||
|
||||
def commit(self, session_id: bytes) -> bool:
|
||||
"""
|
||||
Commit changes to the Database Connection immediately.
|
||||
This isn't necessary to be run every time you make changes, just ensure it's run
|
||||
at least before termination.
|
||||
:param session_id: Database Connection's Session ID.
|
||||
:returns: True if it committed.
|
||||
"""
|
||||
self.safe_execute(
|
||||
session_id,
|
||||
lambda db, cursor: db.commit()
|
||||
)
|
||||
return True # todo ; actually check if db.commit worked
|
|
@ -31,11 +31,11 @@ class Vault(metaclass=ABCMeta):
|
|||
"""Get All Keys from Vault by Service."""
|
||||
|
||||
@abstractmethod
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
"""Add KID:KEY to the Vault."""
|
||||
|
||||
@abstractmethod
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||
"""
|
||||
Add Multiple Content Keys with Key IDs for Service to the Vault.
|
||||
Pre-existing Content Keys are ignored/skipped.
|
||||
|
|
|
@ -57,7 +57,7 @@ class Vaults:
|
|||
for vault in self.vaults:
|
||||
if vault != excluding:
|
||||
try:
|
||||
success += vault.add_key(self.service, kid, key, commit=True)
|
||||
success += vault.add_key(self.service, kid, key)
|
||||
except (PermissionError, NotImplementedError):
|
||||
pass
|
||||
return success
|
||||
|
@ -70,7 +70,7 @@ class Vaults:
|
|||
success = 0
|
||||
for vault in self.vaults:
|
||||
try:
|
||||
success += bool(vault.add_keys(self.service, kid_keys, commit=True))
|
||||
success += bool(vault.add_keys(self.service, kid_keys))
|
||||
except (PermissionError, NotImplementedError):
|
||||
pass
|
||||
return success
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from queue import Empty, Queue
|
||||
from threading import Lock
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
|
@ -7,7 +10,6 @@ import pymysql
|
|||
from pymysql.cursors import DictCursor
|
||||
|
||||
from devine.core.services import Services
|
||||
from devine.core.utils.atomicsql import AtomicSQL
|
||||
from devine.core.vault import Vault
|
||||
|
||||
|
||||
|
@ -21,15 +23,13 @@ class MySQL(Vault):
|
|||
"""
|
||||
super().__init__(name)
|
||||
self.slug = f"{host}:{database}:{username}"
|
||||
self.con = pymysql.connect(
|
||||
self.con_pool = ConnectionPool(dict(
|
||||
host=host,
|
||||
db=database,
|
||||
user=username,
|
||||
cursorclass=DictCursor,
|
||||
**kwargs
|
||||
)
|
||||
self.adb = AtomicSQL()
|
||||
self.ticket = self.adb.load(self.con)
|
||||
), 5)
|
||||
|
||||
self.permissions = self.get_permissions()
|
||||
if not self.has_permission("SELECT"):
|
||||
|
@ -43,37 +43,44 @@ class MySQL(Vault):
|
|||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=%s AND `key_`!=%s",
|
||||
[kid, "0" * 32]
|
||||
(kid, "0" * 32)
|
||||
)
|
||||
).fetchone()
|
||||
if not c:
|
||||
return None
|
||||
|
||||
return c["key_"]
|
||||
cek = cursor.fetchone()
|
||||
if not cek:
|
||||
return None
|
||||
return cek["key_"]
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
if not self.has_table(service):
|
||||
# no table, no keys, simple
|
||||
return None
|
||||
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=%s",
|
||||
["0" * 32]
|
||||
("0" * 32,)
|
||||
)
|
||||
)
|
||||
for row in cursor.fetchall():
|
||||
yield row["kid"], row["key_"]
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
for row in c.fetchall():
|
||||
yield row["kid"], row["key_"]
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
|
@ -82,39 +89,38 @@ class MySQL(Vault):
|
|||
|
||||
if not self.has_table(service):
|
||||
try:
|
||||
self.create_table(service, commit)
|
||||
self.create_table(service)
|
||||
except PermissionError:
|
||||
return False
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
if self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `id` FROM `{service}` WHERE `kid`=%s AND `key_`=%s",
|
||||
[kid, key]
|
||||
(kid, key)
|
||||
)
|
||||
).fetchone():
|
||||
# table already has this exact KID:KEY stored
|
||||
return True
|
||||
|
||||
self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
if cursor.fetchone():
|
||||
# table already has this exact KID:KEY stored
|
||||
return True
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
||||
(kid, key)
|
||||
)
|
||||
)
|
||||
|
||||
if commit:
|
||||
self.commit()
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
return True
|
||||
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||
for kid, key in kid_keys.items():
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
@ -124,7 +130,7 @@ class MySQL(Vault):
|
|||
|
||||
if not self.has_table(service):
|
||||
try:
|
||||
self.create_table(service, commit)
|
||||
self.create_table(service)
|
||||
except PermissionError:
|
||||
return 0
|
||||
|
||||
|
@ -139,40 +145,50 @@ class MySQL(Vault):
|
|||
for kid, key_ in kid_keys.items()
|
||||
}
|
||||
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.executemany(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT IGNORE INTO `{service}` (kid, key_) VALUES (%s, %s)",
|
||||
kid_keys.items()
|
||||
)
|
||||
)
|
||||
|
||||
if commit:
|
||||
self.commit()
|
||||
|
||||
return c.rowcount
|
||||
return cursor.rowcount
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute("SHOW TABLES")
|
||||
)
|
||||
for table in c.fetchall():
|
||||
# each entry has a key named `Tables_in_<db name>`
|
||||
yield Services.get_tag(list(table.values())[0])
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("SHOW TABLES")
|
||||
for table in cursor.fetchall():
|
||||
# each entry has a key named `Tables_in_<db name>`
|
||||
yield Services.get_tag(list(table.values())[0])
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def has_table(self, name: str) -> bool:
|
||||
"""Check if the Vault has a Table with the specified name."""
|
||||
return list(self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
|
||||
[self.con.db, name]
|
||||
)
|
||||
).fetchone().values())[0] == 1
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
def create_table(self, name: str, commit: bool = False):
|
||||
try:
|
||||
cursor.execute(
|
||||
"SELECT count(TABLE_NAME) FROM information_schema.TABLES WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s",
|
||||
(conn.db, name)
|
||||
)
|
||||
return list(cursor.fetchone().values())[0] == 1
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def create_table(self, name: str):
|
||||
"""Create a Table with the specified name if not yet created."""
|
||||
if self.has_table(name):
|
||||
return
|
||||
|
@ -180,9 +196,11 @@ class MySQL(Vault):
|
|||
if not self.has_permission("CREATE"):
|
||||
raise PermissionError(f"MySQL vault {self.slug} has no CREATE permission.")
|
||||
|
||||
self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {name} (
|
||||
|
@ -193,23 +211,30 @@ class MySQL(Vault):
|
|||
);
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if commit:
|
||||
self.commit()
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_permissions(self) -> list:
|
||||
"""Get and parse Grants to a more easily usable list tuple array."""
|
||||
with self.con.cursor() as c:
|
||||
c.execute("SHOW GRANTS")
|
||||
grants = c.fetchall()
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("SHOW GRANTS")
|
||||
grants = cursor.fetchall()
|
||||
grants = [next(iter(x.values())) for x in grants]
|
||||
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
|
||||
grants = [(
|
||||
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
|
||||
location.replace("`", "").split(".")
|
||||
) for perms, location in grants]
|
||||
return grants
|
||||
grants = [tuple(x[6:].split(" TO ")[0].split(" ON ")) for x in list(grants)]
|
||||
grants = [(
|
||||
list(map(str.strip, perms.replace("ALL PRIVILEGES", "*").split(","))),
|
||||
location.replace("`", "").split(".")
|
||||
) for perms, location in grants]
|
||||
return grants
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def has_permission(self, operation: str, database: Optional[str] = None, table: Optional[str] = None) -> bool:
|
||||
"""Check if the current connection has a specific permission."""
|
||||
|
@ -220,6 +245,28 @@ class MySQL(Vault):
|
|||
grants = [x for x in grants if x[1][1] in (table, "*")]
|
||||
return bool(grants)
|
||||
|
||||
def commit(self):
|
||||
"""Commit any changes made that has not been written to db."""
|
||||
self.adb.commit(self.ticket)
|
||||
|
||||
class ConnectionPool:
|
||||
def __init__(self, con: dict, size: int):
|
||||
self._con = con
|
||||
self._size = size
|
||||
self._pool = Queue(self._size)
|
||||
self._lock = Lock()
|
||||
|
||||
def _create_connection(self):
|
||||
return pymysql.connect(**self._con)
|
||||
|
||||
def get(self) -> pymysql.Connection:
|
||||
while True:
|
||||
try:
|
||||
return self._pool.get(block=False)
|
||||
except Empty:
|
||||
with self._lock:
|
||||
if self._pool.qsize() < self._size:
|
||||
return self._create_connection()
|
||||
else:
|
||||
# pool full, wait before retrying
|
||||
time.sleep(0.1)
|
||||
|
||||
def put(self, conn: pymysql.Connection):
|
||||
self._pool.put(conn)
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from sqlite3 import Connection
|
||||
from threading import Lock
|
||||
from typing import Iterator, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from devine.core.services import Services
|
||||
from devine.core.utils.atomicsql import AtomicSQL
|
||||
from devine.core.vault import Vault
|
||||
|
||||
|
||||
|
@ -17,9 +20,7 @@ class SQLite(Vault):
|
|||
super().__init__(name)
|
||||
self.path = Path(path).expanduser()
|
||||
# TODO: Use a DictCursor or such to get fetches as dict?
|
||||
self.con = sqlite3.connect(self.path)
|
||||
self.adb = AtomicSQL()
|
||||
self.ticket = self.adb.load(self.con)
|
||||
self.con_pool = ConnectionPool(self.path, 5)
|
||||
|
||||
def get_key(self, kid: Union[UUID, str], service: str) -> Optional[str]:
|
||||
if not self.has_table(service):
|
||||
|
@ -29,76 +30,82 @@ class SQLite(Vault):
|
|||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
|
||||
[kid, "0" * 32]
|
||||
)
|
||||
).fetchone()
|
||||
if not c:
|
||||
return None
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
return c[1] # `key_`
|
||||
try:
|
||||
cursor.execute(
|
||||
f"SELECT `id`, `key_` FROM `{service}` WHERE `kid`=? AND `key_`!=?",
|
||||
(kid, "0" * 32)
|
||||
)
|
||||
cek = cursor.fetchone()
|
||||
if not cek:
|
||||
return None
|
||||
return cek[1]
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_keys(self, service: str) -> Iterator[tuple[str, str]]:
|
||||
if not self.has_table(service):
|
||||
# no table, no keys, simple
|
||||
return None
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
|
||||
["0" * 32]
|
||||
)
|
||||
)
|
||||
for (kid, key_) in c.fetchall():
|
||||
yield kid, key_
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str, commit: bool = False) -> bool:
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
f"SELECT `kid`, `key_` FROM `{service}` WHERE `key_`!=?",
|
||||
("0" * 32,)
|
||||
)
|
||||
for (kid, key_) in cursor.fetchall():
|
||||
yield kid, key_
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def add_key(self, service: str, kid: Union[UUID, str], key: str) -> bool:
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if not self.has_table(service):
|
||||
self.create_table(service, commit)
|
||||
self.create_table(service)
|
||||
|
||||
if isinstance(kid, UUID):
|
||||
kid = kid.hex
|
||||
|
||||
if self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"SELECT `id` FROM `{service}` WHERE `kid`=? AND `key_`=?",
|
||||
[kid, key]
|
||||
(kid, key)
|
||||
)
|
||||
).fetchone():
|
||||
# table already has this exact KID:KEY stored
|
||||
return True
|
||||
|
||||
self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
if cursor.fetchone():
|
||||
# table already has this exact KID:KEY stored
|
||||
return True
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT INTO `{service}` (kid, key_) VALUES (?, ?)",
|
||||
(kid, key)
|
||||
)
|
||||
)
|
||||
|
||||
if commit:
|
||||
self.commit()
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
return True
|
||||
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str], commit: bool = False) -> int:
|
||||
def add_keys(self, service: str, kid_keys: dict[Union[UUID, str], str]) -> int:
|
||||
for kid, key in kid_keys.items():
|
||||
if not key or key.count("0") == len(key):
|
||||
raise ValueError("You cannot add a NULL Content Key to a Vault.")
|
||||
|
||||
if not self.has_table(service):
|
||||
self.create_table(service, commit)
|
||||
self.create_table(service)
|
||||
|
||||
if not isinstance(kid_keys, dict):
|
||||
raise ValueError(f"The kid_keys provided is not a dictionary, {kid_keys!r}")
|
||||
|
@ -111,47 +118,59 @@ class SQLite(Vault):
|
|||
for kid, key_ in kid_keys.items()
|
||||
}
|
||||
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.executemany(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.executemany(
|
||||
# TODO: SQL injection risk
|
||||
f"INSERT OR IGNORE INTO `{service}` (kid, key_) VALUES (?, ?)",
|
||||
kid_keys.items()
|
||||
)
|
||||
)
|
||||
|
||||
if commit:
|
||||
self.commit()
|
||||
|
||||
return c.rowcount
|
||||
return cursor.rowcount
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def get_services(self) -> Iterator[str]:
|
||||
c = self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
)
|
||||
for (name,) in c.fetchall():
|
||||
if name != "sqlite_sequence":
|
||||
yield Services.get_tag(name)
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
|
||||
for (name,) in cursor.fetchall():
|
||||
if name != "sqlite_sequence":
|
||||
yield Services.get_tag(name)
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def has_table(self, name: str) -> bool:
|
||||
"""Check if the Vault has a Table with the specified name."""
|
||||
return self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
|
||||
[name]
|
||||
)
|
||||
).fetchone()[0] == 1
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
def create_table(self, name: str, commit: bool = False):
|
||||
try:
|
||||
cursor.execute(
|
||||
"SELECT count(name) FROM sqlite_master WHERE type='table' AND name=?",
|
||||
(name,)
|
||||
)
|
||||
return cursor.fetchone()[0] == 1
|
||||
finally:
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
def create_table(self, name: str):
|
||||
"""Create a Table with the specified name if not yet created."""
|
||||
if self.has_table(name):
|
||||
return
|
||||
|
||||
self.adb.safe_execute(
|
||||
self.ticket,
|
||||
lambda db, cursor: cursor.execute(
|
||||
conn = self.con_pool.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
# TODO: SQL injection risk
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {name} (
|
||||
|
@ -163,11 +182,33 @@ class SQLite(Vault):
|
|||
);
|
||||
"""
|
||||
)
|
||||
)
|
||||
finally:
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
self.con_pool.put(conn)
|
||||
|
||||
if commit:
|
||||
self.commit()
|
||||
|
||||
def commit(self):
|
||||
"""Commit any changes made that has not been written to db."""
|
||||
self.adb.commit(self.ticket)
|
||||
class ConnectionPool:
|
||||
def __init__(self, path: Union[str, Path], size: int):
|
||||
self._path = path
|
||||
self._size = size
|
||||
self._pool = Queue(self._size)
|
||||
self._lock = Lock()
|
||||
|
||||
def _create_connection(self):
|
||||
return sqlite3.connect(self._path)
|
||||
|
||||
def get(self) -> Connection:
|
||||
while True:
|
||||
try:
|
||||
return self._pool.get(block=False)
|
||||
except Empty:
|
||||
with self._lock:
|
||||
if self._pool.qsize() < self._size:
|
||||
return self._create_connection()
|
||||
else:
|
||||
# pool full, wait before retrying
|
||||
time.sleep(0.1)
|
||||
|
||||
def put(self, conn: Connection):
|
||||
self._pool.put(conn)
|
||||
|
|
Loading…
Reference in New Issue