Skip to content

Commit a19aa3b

Browse files
feat(app): db abstraction to prevent threading conflicts
- Add a context manager to the SqliteDatabase class which abstracts away creating a transaction, committing it on success and rolling back on error. - Use it everywhere. The context manager should be exited before returning results. No business logic changes should be present.
1 parent ef4d5d7 commit a19aa3b

File tree

11 files changed

+1336
-1409
lines changed

11 files changed

+1336
-1409
lines changed

invokeai/app/services/board_image_records/board_image_records_sqlite.py

Lines changed: 100 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -14,73 +14,63 @@
1414
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
1515
def __init__(self, db: SqliteDatabase) -> None:
1616
super().__init__()
17-
self._conn = db.conn
17+
self._db = db
1818

1919
def add_image_to_board(
2020
self,
2121
board_id: str,
2222
image_name: str,
2323
) -> None:
24-
try:
25-
cursor = self._conn.cursor()
26-
cursor.execute(
24+
with self._db.conn() as conn:
25+
conn.execute(
2726
"""--sql
2827
INSERT INTO board_images (board_id, image_name)
2928
VALUES (?, ?)
3029
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
3130
""",
3231
(board_id, image_name, board_id),
3332
)
34-
self._conn.commit()
35-
except sqlite3.Error as e:
36-
self._conn.rollback()
37-
raise e
3833

3934
def remove_image_from_board(
4035
self,
4136
image_name: str,
4237
) -> None:
43-
try:
44-
cursor = self._conn.cursor()
45-
cursor.execute(
38+
with self._db.conn() as conn:
39+
conn.execute(
4640
"""--sql
4741
DELETE FROM board_images
4842
WHERE image_name = ?;
4943
""",
5044
(image_name,),
5145
)
52-
self._conn.commit()
53-
except sqlite3.Error as e:
54-
self._conn.rollback()
55-
raise e
5646

5747
def get_images_for_board(
5848
self,
5949
board_id: str,
6050
offset: int = 0,
6151
limit: int = 10,
6252
) -> OffsetPaginatedResults[ImageRecord]:
63-
# TODO: this isn't paginated yet?
64-
cursor = self._conn.cursor()
65-
cursor.execute(
66-
"""--sql
67-
SELECT images.*
68-
FROM board_images
69-
INNER JOIN images ON board_images.image_name = images.image_name
70-
WHERE board_images.board_id = ?
71-
ORDER BY board_images.updated_at DESC;
72-
""",
73-
(board_id,),
74-
)
75-
result = cast(list[sqlite3.Row], cursor.fetchall())
76-
images = [deserialize_image_record(dict(r)) for r in result]
77-
78-
cursor.execute(
79-
"""--sql
80-
SELECT COUNT(*) FROM images WHERE 1=1;
81-
"""
82-
)
83-
count = cast(int, cursor.fetchone()[0])
53+
with self._db.conn() as conn:
54+
cursor = conn.cursor()
55+
cursor.execute(
56+
"""--sql
57+
SELECT images.*
58+
FROM board_images
59+
INNER JOIN images ON board_images.image_name = images.image_name
60+
WHERE board_images.board_id = ?
61+
ORDER BY board_images.updated_at DESC;
62+
""",
63+
(board_id,),
64+
)
65+
result = cast(list[sqlite3.Row], cursor.fetchall())
66+
images = [deserialize_image_record(dict(r)) for r in result]
67+
68+
cursor.execute(
69+
"""--sql
70+
SELECT COUNT(*) FROM images WHERE 1=1;
71+
"""
72+
)
73+
count = cast(int, cursor.fetchone()[0])
8474

8575
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
8676

@@ -90,88 +80,90 @@ def get_all_board_image_names_for_board(
9080
categories: list[ImageCategory] | None,
9181
is_intermediate: bool | None,
9282
) -> list[str]:
93-
params: list[str | bool] = []
94-
95-
# Base query is a join between images and board_images
96-
stmt = """
97-
SELECT images.image_name
98-
FROM images
99-
LEFT JOIN board_images ON board_images.image_name = images.image_name
100-
WHERE 1=1
101-
"""
102-
103-
# Handle board_id filter
104-
if board_id == "none":
105-
stmt += """--sql
106-
AND board_images.board_id IS NULL
107-
"""
108-
else:
109-
stmt += """--sql
110-
AND board_images.board_id = ?
111-
"""
112-
params.append(board_id)
113-
114-
# Add the category filter
115-
if categories is not None:
116-
# Convert the enum values to unique list of strings
117-
category_strings = [c.value for c in set(categories)]
118-
# Create the correct length of placeholders
119-
placeholders = ",".join("?" * len(category_strings))
120-
stmt += f"""--sql
121-
AND images.image_category IN ( {placeholders} )
122-
"""
123-
124-
# Unpack the included categories into the query params
125-
for c in category_strings:
126-
params.append(c)
127-
128-
# Add the is_intermediate filter
129-
if is_intermediate is not None:
130-
stmt += """--sql
131-
AND images.is_intermediate = ?
132-
"""
133-
params.append(is_intermediate)
134-
135-
# Put a ring on it
136-
stmt += ";"
137-
138-
# Execute the query
139-
cursor = self._conn.cursor()
140-
cursor.execute(stmt, params)
141-
142-
result = cast(list[sqlite3.Row], cursor.fetchall())
83+
with self._db.conn() as conn:
84+
cursor = conn.cursor()
85+
params: list[str | bool] = []
86+
87+
# Base query is a join between images and board_images
88+
stmt = """
89+
SELECT images.image_name
90+
FROM images
91+
LEFT JOIN board_images ON board_images.image_name = images.image_name
92+
WHERE 1=1
93+
"""
94+
95+
# Handle board_id filter
96+
if board_id == "none":
97+
stmt += """--sql
98+
AND board_images.board_id IS NULL
99+
"""
100+
else:
101+
stmt += """--sql
102+
AND board_images.board_id = ?
103+
"""
104+
params.append(board_id)
105+
106+
# Add the category filter
107+
if categories is not None:
108+
# Convert the enum values to unique list of strings
109+
category_strings = [c.value for c in set(categories)]
110+
# Create the correct length of placeholders
111+
placeholders = ",".join("?" * len(category_strings))
112+
stmt += f"""--sql
113+
AND images.image_category IN ( {placeholders} )
114+
"""
115+
116+
# Unpack the included categories into the query params
117+
for c in category_strings:
118+
params.append(c)
119+
120+
# Add the is_intermediate filter
121+
if is_intermediate is not None:
122+
stmt += """--sql
123+
AND images.is_intermediate = ?
124+
"""
125+
params.append(is_intermediate)
126+
127+
# Put a ring on it
128+
stmt += ";"
129+
130+
cursor.execute(stmt, params)
131+
132+
result = cast(list[sqlite3.Row], cursor.fetchall())
143133
image_names = [r[0] for r in result]
144134
return image_names
145135

146136
def get_board_for_image(
147137
self,
148138
image_name: str,
149139
) -> Optional[str]:
150-
cursor = self._conn.cursor()
151-
cursor.execute(
152-
"""--sql
153-
SELECT board_id
154-
FROM board_images
155-
WHERE image_name = ?;
156-
""",
157-
(image_name,),
158-
)
159-
result = cursor.fetchone()
140+
with self._db.conn() as conn:
141+
cursor = conn.cursor()
142+
cursor.execute(
143+
"""--sql
144+
SELECT board_id
145+
FROM board_images
146+
WHERE image_name = ?;
147+
""",
148+
(image_name,),
149+
)
150+
result = cursor.fetchone()
160151
if result is None:
161152
return None
162153
return cast(str, result[0])
163154

164155
def get_image_count_for_board(self, board_id: str) -> int:
165-
cursor = self._conn.cursor()
166-
cursor.execute(
167-
"""--sql
168-
SELECT COUNT(*)
169-
FROM board_images
170-
INNER JOIN images ON board_images.image_name = images.image_name
171-
WHERE images.is_intermediate = FALSE
172-
AND board_images.board_id = ?;
173-
""",
174-
(board_id,),
175-
)
176-
count = cast(int, cursor.fetchone()[0])
156+
with self._db.conn() as conn:
157+
cursor = conn.cursor()
158+
cursor.execute(
159+
"""--sql
160+
SELECT COUNT(*)
161+
FROM board_images
162+
INNER JOIN images ON board_images.image_name = images.image_name
163+
WHERE images.is_intermediate = FALSE
164+
AND board_images.board_id = ?;
165+
""",
166+
(board_id,),
167+
)
168+
count = cast(int, cursor.fetchone()[0])
177169
return count

0 commit comments

Comments
 (0)