164 lines
5.2 KiB
Python
164 lines
5.2 KiB
Python
import sqlite3
|
|
import weakref
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Set
|
|
import threading
|
|
from overrides import override
|
|
from typing_extensions import Annotated
|
|
|
|
|
|
class Connection:
    """A pooled sqlite3 connection that returns itself to its pool on close().

    Wraps one ``sqlite3.Connection``. ``close()`` hands the connection back to
    the owning pool (via ``pool.return_to_pool``) without closing the database;
    ``close_actual()`` really closes the underlying sqlite3 connection.
    """

    # Names of the sqlite3.connect() parameters that may be given positionally
    # after the database argument, in signature order. Positional extras are
    # mapped onto these names so they can be merged with this class's defaults.
    _CONNECT_POSITIONAL_PARAMS = (
        "timeout",
        "detect_types",
        "isolation_level",
        "check_same_thread",
        "factory",
        "cached_statements",
        "uri",
    )

    _pool: "Pool"
    _db_file: str
    _conn: sqlite3.Connection

    def __init__(
        self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any
    ):
        """Open a new sqlite3 connection to ``db_file`` owned by ``pool``.

        Args:
            pool: The pool this connection is returned to by ``close()``.
            db_file: Database path or URI passed to ``sqlite3.connect``.
            is_uri: Whether ``db_file`` is interpreted as a URI (default for
                the ``uri`` option).
            *args: Extra positional arguments for ``sqlite3.connect`` after
                the database (``timeout``, ``detect_types``, ...).
            **kwargs: Extra keyword arguments for ``sqlite3.connect``. These
                override the defaults ``timeout=1000``,
                ``check_same_thread=False`` and ``uri=is_uri``.

        Raises:
            TypeError: If more positional extras are given than
                ``sqlite3.connect`` accepts, or an argument is supplied both
                positionally and by keyword.
        """
        self._pool = pool
        self._db_file = db_file

        names = self._CONNECT_POSITIONAL_PARAMS
        if len(args) > len(names):
            raise TypeError(
                f"sqlite3.connect() accepts at most {len(names)} positional "
                f"arguments after the database, got {len(args)}"
            )
        positional = dict(zip(names, args))
        duplicated = positional.keys() & kwargs.keys()
        if duplicated:
            raise TypeError(
                "sqlite3.connect() got multiple values for argument(s): "
                + ", ".join(sorted(duplicated))
            )

        # Merge defaults, positional extras and keywords into one mapping so a
        # caller can override any default without a "multiple values" error.
        # timeout is in seconds; check_same_thread=False lets the connection be
        # used/closed from threads other than its creator (pools close every
        # connection from whichever thread calls their close()).
        options = {"timeout": 1000, "check_same_thread": False, "uri": is_uri}
        options.update(positional)
        options.update(kwargs)
        self._conn = sqlite3.connect(db_file, **options)  # type: ignore
        self._conn.isolation_level = None  # Handle commits explicitly

    def execute(self, sql: str, parameters=...) -> sqlite3.Cursor:  # type: ignore
        """Execute ``sql`` on the underlying connection and return the cursor.

        ``...`` is the "no parameters" sentinel, so an explicit ``None`` or
        empty sequence is still forwarded to sqlite3 unchanged.
        """
        if parameters is ...:
            return self._conn.execute(sql)
        return self._conn.execute(sql, parameters)

    def commit(self) -> None:
        """Commit the current transaction on the underlying connection."""
        self._conn.commit()

    def rollback(self) -> None:
        """Roll back the current transaction on the underlying connection."""
        self._conn.rollback()

    def cursor(self) -> sqlite3.Cursor:
        """Return a new cursor on the underlying connection."""
        return self._conn.cursor()

    def close(self) -> None:
        """Return this connection to its pool; the database stays open."""
        self._pool.return_to_pool(self)

    def close_actual(self) -> None:
        """Actually closes the connection to the db"""
        self._conn.close()
|
|
|
|
|
|
class Pool(ABC):
    """Interface shared by every pool of connections to one sqlite database.

    Concrete subclasses choose how connections are created, shared between
    threads, handed back and finally closed.
    """

    @abstractmethod
    def __init__(self, db_file: str, is_uri: bool) -> None:
        """Set up a pool for ``db_file``; ``is_uri`` marks it as a URI."""

    @abstractmethod
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return a connection from the pool."""

    @abstractmethod
    def close(self) -> None:
        """Close all connections in the pool."""

    @abstractmethod
    def return_to_pool(self, conn: Connection) -> None:
        """Return a connection to the pool."""
|
|
|
|
|
|
class LockPool(Pool):
    """A pool that has a single connection per thread but uses a lock to ensure
    that only one thread can use it at a time.

    connect() acquires one pool-wide re-entrant lock and return_to_pool()
    releases it, so connect()/return_to_pool() calls must be balanced.

    This is used because sqlite does not support multithreaded access with
    connection timeouts when using the shared cache mode. We use the shared
    cache mode to allow multiple threads to share a database.
    """

    # Weak references to every connection created, so close() can reach them
    # without this set itself keeping the connections alive.
    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    # Held from connect() until the matching return_to_pool(); re-entrant so a
    # thread may nest connect() calls.
    _lock: threading.RLock
    # Per-thread slot; the calling thread's Connection lives in ``.conn``.
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.RLock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Acquire the pool lock and return the calling thread's connection.

        The lock remains held after this returns; release it with
        return_to_pool(). The thread's connection is created on first use,
        forwarding ``*args``/``**kwargs`` to Connection; later calls from the
        same thread reuse it and ignore the arguments.

        Raises:
            Whatever Connection() raises when opening fails. The lock is
            released before re-raising so the pool is not left locked.
        """
        self._lock.acquire()
        try:
            conn = getattr(self._connection, "conn", None)
            if conn is None:
                conn = Connection(self, self._db_file, self._is_uri, *args, **kwargs)
                self._connection.conn = conn
                self._connections.add(weakref.ref(conn))
        except BaseException:
            self._lock.release()
            raise
        return conn  # type: ignore # cast doesn't work here for some reason

    @override
    def return_to_pool(self, conn: Connection) -> None:
        """Release the pool lock taken by the matching connect() call.

        ``conn`` is not used; the connection stays cached for this thread.
        A release the calling thread is not entitled to (RuntimeError from the
        lock) is ignored, so unbalanced returns are tolerated.
        """
        try:
            self._lock.release()
        except RuntimeError:
            pass

    @override
    def close(self) -> None:
        """Close every connection this pool created and reset per-thread state.

        The pool lock is held while closing, so no other thread can be between
        connect() and return_to_pool(), or adding to ``_connections``, while
        connections are closed. Afterwards one release of the lock is
        attempted for the calling thread (ignored if it does not hold it) —
        presumably to free a lock left held by an unreturned connect().
        """
        with self._lock:
            for ref in self._connections:
                # Dereference once: the referent may be collected between calls.
                conn = ref()
                if conn is not None:
                    conn.close_actual()
            self._connections.clear()
            self._connection = threading.local()
        try:
            self._lock.release()
        except RuntimeError:
            pass
|
|
|
|
|
|
class PerThreadPool(Pool):
    """Maintains a connection per thread. For now this does not maintain a cap
    on the number of connections, but it could be extended to do so and block
    on connect() if the cap is reached.
    """

    # Weak references to every connection created, so close() can reach them
    # without this set itself keeping the connections alive. Entries are only
    # removed by close(), so references to collected connections linger here
    # until then.
    _connections: Set[Annotated[weakref.ReferenceType, Connection]]
    # Guards ``_connections`` (and the reset in close()).
    _lock: threading.Lock
    # Per-thread slot; the calling thread's Connection lives in ``.conn``.
    _connection: threading.local
    _db_file: str
    _is_uri: bool

    def __init__(self, db_file: str, is_uri: bool = False):
        self._connections = set()
        self._connection = threading.local()
        self._lock = threading.Lock()
        self._db_file = db_file
        self._is_uri = is_uri

    @override
    def connect(self, *args: Any, **kwargs: Any) -> Connection:
        """Return the calling thread's connection, creating it on first use.

        ``*args``/``**kwargs`` are forwarded to Connection only when a new
        connection is created; later calls from the same thread ignore them.
        """
        conn = getattr(self._connection, "conn", None)
        if conn is not None:
            return conn  # type: ignore # cast doesn't work here for some reason
        conn = Connection(self, self._db_file, self._is_uri, *args, **kwargs)
        self._connection.conn = conn
        with self._lock:
            self._connections.add(weakref.ref(conn))
        return conn

    @override
    def close(self) -> None:
        """Close every connection this pool created and reset per-thread state."""
        with self._lock:
            for ref in self._connections:
                # Dereference once: the referent may be collected between calls.
                conn = ref()
                if conn is not None:
                    conn.close_actual()
            self._connections.clear()
            self._connection = threading.local()

    @override
    def return_to_pool(self, conn: Connection) -> None:
        pass  # Each thread gets its own connection, so we don't need to return it to the pool
|