Skip to content

Commit fc71849

Browse files
feat(app): expose a cursor, not a connection in db util
1 parent a19aa3b commit fc71849

File tree

9 files changed

+106
-190
lines changed

9 files changed

+106
-190
lines changed

invokeai/app/services/board_image_records/board_image_records_sqlite.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def add_image_to_board(
2121
board_id: str,
2222
image_name: str,
2323
) -> None:
24-
with self._db.conn() as conn:
25-
conn.execute(
24+
with self._db.transaction() as cursor:
25+
cursor.execute(
2626
"""--sql
2727
INSERT INTO board_images (board_id, image_name)
2828
VALUES (?, ?)
@@ -35,8 +35,8 @@ def remove_image_from_board(
3535
self,
3636
image_name: str,
3737
) -> None:
38-
with self._db.conn() as conn:
39-
conn.execute(
38+
with self._db.transaction() as cursor:
39+
cursor.execute(
4040
"""--sql
4141
DELETE FROM board_images
4242
WHERE image_name = ?;
@@ -50,8 +50,7 @@ def get_images_for_board(
5050
offset: int = 0,
5151
limit: int = 10,
5252
) -> OffsetPaginatedResults[ImageRecord]:
53-
with self._db.conn() as conn:
54-
cursor = conn.cursor()
53+
with self._db.transaction() as cursor:
5554
cursor.execute(
5655
"""--sql
5756
SELECT images.*
@@ -80,8 +79,7 @@ def get_all_board_image_names_for_board(
8079
categories: list[ImageCategory] | None,
8180
is_intermediate: bool | None,
8281
) -> list[str]:
83-
with self._db.conn() as conn:
84-
cursor = conn.cursor()
82+
with self._db.transaction() as cursor:
8583
params: list[str | bool] = []
8684

8785
# Base query is a join between images and board_images
@@ -137,8 +135,7 @@ def get_board_for_image(
137135
self,
138136
image_name: str,
139137
) -> Optional[str]:
140-
with self._db.conn() as conn:
141-
cursor = conn.cursor()
138+
with self._db.transaction() as cursor:
142139
cursor.execute(
143140
"""--sql
144141
SELECT board_id
@@ -153,8 +150,7 @@ def get_board_for_image(
153150
return cast(str, result[0])
154151

155152
def get_image_count_for_board(self, board_id: str) -> int:
156-
with self._db.conn() as conn:
157-
cursor = conn.cursor()
153+
with self._db.transaction() as cursor:
158154
cursor.execute(
159155
"""--sql
160156
SELECT COUNT(*)

invokeai/app/services/board_records/board_records_sqlite.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ def __init__(self, db: SqliteDatabase) -> None:
2323
self._db = db
2424

2525
def delete(self, board_id: str) -> None:
26-
with self._db.conn() as conn:
26+
with self._db.transaction() as cursor:
2727
try:
28-
cursor = conn.cursor()
2928
cursor.execute(
3029
"""--sql
3130
DELETE FROM boards
@@ -40,10 +39,9 @@ def save(
4039
self,
4140
board_name: str,
4241
) -> BoardRecord:
43-
with self._db.conn() as conn:
42+
with self._db.transaction() as cursor:
4443
try:
4544
board_id = uuid_string()
46-
cursor = conn.cursor()
4745
cursor.execute(
4846
"""--sql
4947
INSERT OR IGNORE INTO boards (board_id, board_name)
@@ -59,9 +57,8 @@ def get(
5957
self,
6058
board_id: str,
6159
) -> BoardRecord:
62-
with self._db.conn() as conn:
60+
with self._db.transaction() as cursor:
6361
try:
64-
cursor = conn.cursor()
6562
cursor.execute(
6663
"""--sql
6764
SELECT *
@@ -83,9 +80,8 @@ def update(
8380
board_id: str,
8481
changes: BoardChanges,
8582
) -> BoardRecord:
86-
with self._db.conn() as conn:
83+
with self._db.transaction() as cursor:
8784
try:
88-
cursor = conn.cursor()
8985
# Change the name of a board
9086
if changes.board_name is not None:
9187
cursor.execute(
@@ -131,9 +127,7 @@ def get_many(
131127
limit: int = 10,
132128
include_archived: bool = False,
133129
) -> OffsetPaginatedResults[BoardRecord]:
134-
with self._db.conn() as conn:
135-
cursor = conn.cursor()
136-
130+
with self._db.transaction() as cursor:
137131
# Build base query
138132
base_query = """
139133
SELECT *
@@ -179,8 +173,7 @@ def get_many(
179173
def get_all(
180174
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
181175
) -> list[BoardRecord]:
182-
with self._db.conn() as conn:
183-
cursor = conn.cursor()
176+
with self._db.transaction() as cursor:
184177
if order_by == BoardRecordOrderBy.Name:
185178
base_query = """
186179
SELECT *

invokeai/app/services/image_records/image_records_sqlite.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ def __init__(self, db: SqliteDatabase) -> None:
2727
self._db = db
2828

2929
def get(self, image_name: str) -> ImageRecord:
30-
with self._db.conn() as conn:
30+
with self._db.transaction() as cursor:
3131
try:
32-
cursor = conn.cursor()
3332
cursor.execute(
3433
f"""--sql
3534
SELECT {IMAGE_DTO_COLS} FROM images
@@ -48,9 +47,8 @@ def get(self, image_name: str) -> ImageRecord:
4847
return deserialize_image_record(dict(result))
4948

5049
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
51-
with self._db.conn() as conn:
50+
with self._db.transaction() as cursor:
5251
try:
53-
cursor = conn.cursor()
5452
cursor.execute(
5553
"""--sql
5654
SELECT metadata FROM images
@@ -76,9 +74,8 @@ def update(
7674
image_name: str,
7775
changes: ImageRecordChanges,
7876
) -> None:
79-
with self._db.conn() as conn:
77+
with self._db.transaction() as cursor:
8078
try:
81-
cursor = conn.cursor()
8279
# Change the category of the image
8380
if changes.image_category is not None:
8481
cursor.execute(
@@ -138,9 +135,7 @@ def get_many(
138135
board_id: Optional[str] = None,
139136
search_term: Optional[str] = None,
140137
) -> OffsetPaginatedResults[ImageRecord]:
141-
with self._db.conn() as conn:
142-
cursor = conn.cursor()
143-
138+
with self._db.transaction() as cursor:
144139
# Manually build two queries - one for the count, one for the records
145140
count_query = """--sql
146141
SELECT COUNT(*)
@@ -227,20 +222,20 @@ def get_many(
227222
# Build the list of images, deserializing each row
228223
cursor.execute(images_query, images_params)
229224
result = cast(list[sqlite3.Row], cursor.fetchall())
230-
images = [deserialize_image_record(dict(r)) for r in result]
231225

232-
# Set up and execute the count query, without pagination
233-
count_query += query_conditions + ";"
234-
count_params = query_params.copy()
235-
cursor.execute(count_query, count_params)
236-
count = cast(int, cursor.fetchone()[0])
226+
images = [deserialize_image_record(dict(r)) for r in result]
227+
228+
# Set up and execute the count query, without pagination
229+
count_query += query_conditions + ";"
230+
count_params = query_params.copy()
231+
cursor.execute(count_query, count_params)
232+
count = cast(int, cursor.fetchone()[0])
237233

238234
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
239235

240236
def delete(self, image_name: str) -> None:
241-
with self._db.conn() as conn:
237+
with self._db.transaction() as cursor:
242238
try:
243-
cursor = conn.cursor()
244239
cursor.execute(
245240
"""--sql
246241
DELETE FROM images
@@ -252,10 +247,8 @@ def delete(self, image_name: str) -> None:
252247
raise ImageRecordDeleteException from e
253248

254249
def delete_many(self, image_names: list[str]) -> None:
255-
with self._db.conn() as conn:
250+
with self._db.transaction() as cursor:
256251
try:
257-
cursor = conn.cursor()
258-
259252
placeholders = ",".join("?" for _ in image_names)
260253

261254
# Construct the SQLite query with the placeholders
@@ -268,8 +261,7 @@ def delete_many(self, image_names: list[str]) -> None:
268261
raise ImageRecordDeleteException from e
269262

270263
def get_intermediates_count(self) -> int:
271-
with self._db.conn() as conn:
272-
cursor = conn.cursor()
264+
with self._db.transaction() as cursor:
273265
cursor.execute(
274266
"""--sql
275267
SELECT COUNT(*) FROM images
@@ -280,9 +272,8 @@ def get_intermediates_count(self) -> int:
280272
return count
281273

282274
def delete_intermediates(self) -> list[str]:
283-
with self._db.conn() as conn:
275+
with self._db.transaction() as cursor:
284276
try:
285-
cursor = conn.cursor()
286277
cursor.execute(
287278
"""--sql
288279
SELECT image_name FROM images
@@ -315,9 +306,8 @@ def save(
315306
node_id: Optional[str] = None,
316307
metadata: Optional[str] = None,
317308
) -> datetime:
318-
with self._db.conn() as conn:
309+
with self._db.transaction() as cursor:
319310
try:
320-
cursor = conn.cursor()
321311
cursor.execute(
322312
"""--sql
323313
INSERT OR IGNORE INTO images (
@@ -366,8 +356,7 @@ def save(
366356
return created_at
367357

368358
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
369-
with self._db.conn() as conn:
370-
cursor = conn.cursor()
359+
with self._db.transaction() as cursor:
371360
cursor.execute(
372361
"""--sql
373362
SELECT images.*
@@ -398,9 +387,7 @@ def get_image_names(
398387
board_id: Optional[str] = None,
399388
search_term: Optional[str] = None,
400389
) -> ImageNamesResult:
401-
with self._db.conn() as conn:
402-
cursor = conn.cursor()
403-
390+
with self._db.transaction() as cursor:
404391
# Build query conditions (reused for both starred count and image names queries)
405392
query_conditions = ""
406393
query_params: list[Union[int, str, bool]] = []

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,8 @@ def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
8888
8989
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
9090
"""
91-
with self._db.conn() as conn:
91+
with self._db.transaction() as cursor:
9292
try:
93-
cursor = conn.cursor()
9493
cursor.execute(
9594
"""--sql
9695
INSERT INTO models (
@@ -127,8 +126,7 @@ def del_model(self, key: str) -> None:
127126
128127
Can raise an UnknownModelException
129128
"""
130-
with self._db.conn() as conn:
131-
cursor = conn.cursor()
129+
with self._db.transaction() as cursor:
132130
cursor.execute(
133131
"""--sql
134132
DELETE FROM models
@@ -140,7 +138,7 @@ def del_model(self, key: str) -> None:
140138
raise UnknownModelException("model not found")
141139

142140
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
143-
with self._db.conn() as conn:
141+
with self._db.transaction() as cursor:
144142
record = self.get_model(key)
145143

146144
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
@@ -149,7 +147,6 @@ def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
149147

150148
json_serialized = record.model_dump_json()
151149

152-
cursor = conn.cursor()
153150
cursor.execute(
154151
"""--sql
155152
UPDATE models
@@ -172,8 +169,7 @@ def get_model(self, key: str) -> AnyModelConfig:
172169
173170
Exceptions: UnknownModelException
174171
"""
175-
with self._db.conn() as conn:
176-
cursor = conn.cursor()
172+
with self._db.transaction() as cursor:
177173
cursor.execute(
178174
"""--sql
179175
SELECT config, strftime('%s',updated_at) FROM models
@@ -188,8 +184,7 @@ def get_model(self, key: str) -> AnyModelConfig:
188184
return model
189185

190186
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
191-
with self._db.conn() as conn:
192-
cursor = conn.cursor()
187+
with self._db.transaction() as cursor:
193188
cursor.execute(
194189
"""--sql
195190
SELECT config, strftime('%s',updated_at) FROM models
@@ -209,8 +204,7 @@ def exists(self, key: str) -> bool:
209204
210205
:param key: Unique key for the model to be deleted
211206
"""
212-
with self._db.conn() as conn:
213-
cursor = conn.cursor()
207+
with self._db.transaction() as cursor:
214208
cursor.execute(
215209
"""--sql
216210
select count(*) FROM models
@@ -241,7 +235,7 @@ def search_by_attr(
241235
If none of the optional filters are passed, will return all
242236
models in the database.
243237
"""
244-
with self._db.conn() as conn:
238+
with self._db.transaction() as cursor:
245239
assert isinstance(order_by, ModelRecordOrderBy)
246240
ordering = {
247241
ModelRecordOrderBy.Default: "type, base, name, format",
@@ -267,7 +261,6 @@ def search_by_attr(
267261
bindings.append(model_format)
268262
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
269263

270-
cursor = conn.cursor()
271264
cursor.execute(
272265
f"""--sql
273266
SELECT config, strftime('%s',updated_at)
@@ -299,8 +292,7 @@ def search_by_attr(
299292

300293
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
301294
"""Return models with the indicated path."""
302-
with self._db.conn() as conn:
303-
cursor = conn.cursor()
295+
with self._db.transaction() as cursor:
304296
cursor.execute(
305297
"""--sql
306298
SELECT config, strftime('%s',updated_at) FROM models
@@ -313,8 +305,7 @@ def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
313305

314306
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
315307
"""Return models with the indicated hash."""
316-
with self._db.conn() as conn:
317-
cursor = conn.cursor()
308+
with self._db.transaction() as cursor:
318309
cursor.execute(
319310
"""--sql
320311
SELECT config, strftime('%s',updated_at) FROM models
@@ -329,7 +320,7 @@ def list_models(
329320
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
330321
) -> PaginatedResults[ModelSummary]:
331322
"""Return a paginated summary listing of each model in the database."""
332-
with self._db.conn() as conn:
323+
with self._db.transaction() as cursor:
333324
assert isinstance(order_by, ModelRecordOrderBy)
334325
ordering = {
335326
ModelRecordOrderBy.Default: "type, base, name, format",
@@ -339,8 +330,6 @@ def list_models(
339330
ModelRecordOrderBy.Format: "format",
340331
}
341332

342-
cursor = conn.cursor()
343-
344333
# Lock so that the database isn't updated while we're doing the two queries.
345334
# query1: get the total number of model configs
346335
cursor.execute(

0 commit comments

Comments
 (0)