44# SPDX-License-Identifier: BSD-3-Clause
55
66import abc
7+ import contextlib
78import functools
89import json
910import os
@@ -126,15 +127,24 @@ def _db_connect(self, *args, **kwargs):
126127 with getprofiler ().time_region ('sqlite connect' ):
127128 return sqlite3 .connect (* args , ** kwargs )
128129
129- def _db_lock (self ):
130- return self .__db_lock .write_lock ()
130+ @contextlib .contextmanager
131+ def _db_read (self , * args , ** kwargs ):
132+ with self .__db_lock .read_lock ():
133+ with self ._db_connect (* args , ** kwargs ) as conn :
134+ yield conn
135+
136+ @contextlib .contextmanager
137+ def _db_write (self , * args , ** kwargs ):
138+ with self .__db_lock .write_lock ():
139+ with self ._db_connect (* args , ** kwargs ) as conn :
140+ yield conn
131141
132142 def _db_create (self ):
133143 clsname = type (self ).__name__
134144 getlogger ().debug (
135145 f'{ clsname } : creating results database in { self .__db_file } ...'
136146 )
137- with self ._db_connect (self .__db_file ) as conn :
147+ with self ._db_write (self .__db_file ) as conn :
138148 conn .execute ('CREATE TABLE IF NOT EXISTS sessions('
139149 'uuid TEXT PRIMARY KEY, '
140150 'session_start_unix REAL, '
@@ -159,13 +169,13 @@ def _db_create(self):
159169 os .chmod (self .__db_file , self .__db_file_mode )
160170
161171 def _db_schema_check (self ):
162- with self ._db_connect (self .__db_file ) as conn :
172+ with self ._db_read (self .__db_file ) as conn :
163173 results = conn .execute (
164174 'SELECT schema_version FROM metadata' ).fetchall ()
165175
166176 if not results :
167177 # DB is new, insert the schema version
168- with self ._db_connect (self .__db_file ) as conn :
178+ with self ._db_write (self .__db_file ) as conn :
169179 conn .execute ('INSERT INTO metadata VALUES(:schema_version)' ,
170180 {'schema_version' : self .SCHEMA_VERSION })
171181 else :
@@ -218,9 +228,8 @@ def _db_store_report(self, conn, report, report_file_path):
218228
219229 @time_function
220230 def store (self , report , report_file = None ):
221- with self ._db_lock ():
222- with self ._db_connect (self ._db_file ()) as conn :
223- return self ._db_store_report (conn , report , report_file )
231+ with self ._db_write (self ._db_file ()) as conn :
232+ return self ._db_store_report (conn , report , report_file )
224233
225234 @time_function
226235 def _decode_sessions (self , results , sess_filter ):
@@ -269,7 +278,7 @@ def _mass_json_decode(json_objs):
269278 def _fetch_testcases_raw (self , condition ):
270279 # Retrieve relevant session info and index it in Python
271280 getprofiler ().enter_region ('sqlite session query' )
272- with self ._db_connect (self ._db_file ()) as conn :
281+ with self ._db_read (self ._db_file ()) as conn :
273282 query = ('SELECT uuid, json_blob FROM sessions WHERE uuid IN '
274283 '(SELECT DISTINCT session_uuid FROM testcases '
275284 f'WHERE { condition } )' )
@@ -284,7 +293,7 @@ def _fetch_testcases_raw(self, condition):
284293
285294 # Extract the test case data by extracting their UUIDs
286295 getprofiler ().enter_region ('sqlite testcase query' )
287- with self ._db_connect (self ._db_file ()) as conn :
296+ with self ._db_read (self ._db_file ()) as conn :
288297 query = f'SELECT uuid FROM testcases WHERE { condition } '
289298 getlogger ().debug (query )
290299 conn .create_function ('REGEXP' , 2 , self ._db_matches )
@@ -321,7 +330,7 @@ def _fetch_testcases_from_session(self, selector,
321330 f'session_start_unix < { ts_end } )' )
322331
323332 getprofiler ().enter_region ('sqlite session query' )
324- with self ._db_connect (self ._db_file ()) as conn :
333+ with self ._db_read (self ._db_file ()) as conn :
325334 getlogger ().debug (query )
326335 results = conn .execute (query ).fetchall ()
327336
@@ -375,7 +384,7 @@ def fetch_sessions(self, selector: QuerySelector):
375384 query += f' WHERE uuid == "{ selector .uuid } "'
376385
377386 getprofiler ().enter_region ('sqlite session query' )
378- with self ._db_connect (self ._db_file ()) as conn :
387+ with self ._db_read (self ._db_file ()) as conn :
379388 getlogger ().debug (query )
380389 results = conn .execute (query ).fetchall ()
381390
@@ -423,9 +432,8 @@ def remove_sessions(self, selector: QuerySelector):
423432 uuids = [sess ['session_info' ]['uuid' ]
424433 for sess in self .fetch_sessions (selector )]
425434
426- with self ._db_lock ():
427- with self ._db_connect (self ._db_file ()) as conn :
428- if sqlite3 .sqlite_version_info >= (3 , 35 , 0 ):
429- return self ._do_remove2 (conn , uuids )
430- else :
431- return self ._do_remove (conn , uuids )
435+ with self ._db_write (self ._db_file ()) as conn :
436+ if sqlite3 .sqlite_version_info >= (3 , 35 , 0 ):
437+ return self ._do_remove2 (conn , uuids )
438+ else :
439+ return self ._do_remove (conn , uuids )
0 commit comments