Skip to content

Commit bd0a74a

Browse files
authored
Merge pull request #18 from qaspen-python/remove_rwlocks_at
Remove RwLock at connection_pool, connection and transaction
2 parents b19a848 + e9b8ab9 commit bd0a74a

File tree

3 files changed

+98
-78
lines changed

3 files changed

+98
-78
lines changed

src/driver/connection.rs

Lines changed: 72 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,75 @@ use super::{
1515
transaction_options::{IsolationLevel, ReadVariant},
1616
};
1717

18-
#[pyclass]
18+
#[allow(clippy::module_name_repetitions)]
19+
pub struct RustConnection {
20+
pub db_client: Arc<Object>,
21+
}
22+
23+
impl RustConnection {
24+
#[must_use]
25+
pub fn new(db_client: Arc<Object>) -> Self {
26+
RustConnection { db_client }
27+
}
28+
/// Execute statement with or witout parameters.
29+
///
30+
/// # Errors
31+
///
32+
/// May return Err Result if
33+
/// 1) Cannot convert incoming parameters
34+
/// 2) Cannot prepare statement
35+
/// 3) Cannot execute query
36+
pub async fn inner_execute(
37+
&self,
38+
querystring: String,
39+
params: Vec<PythonDTO>,
40+
) -> RustPSQLDriverPyResult<PSQLDriverPyQueryResult> {
41+
let db_client = &self.db_client;
42+
let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(params.len());
43+
for param in &params {
44+
vec_parameters.push(param);
45+
}
46+
let statement: tokio_postgres::Statement = db_client.prepare_cached(&querystring).await?;
47+
48+
let result = db_client
49+
.query(&statement, &vec_parameters.into_boxed_slice())
50+
.await?;
51+
52+
Ok(PSQLDriverPyQueryResult::new(result))
53+
}
54+
55+
/// Return new instance of transaction.
56+
#[must_use]
57+
pub fn inner_transaction(
58+
&self,
59+
isolation_level: Option<IsolationLevel>,
60+
read_variant: Option<ReadVariant>,
61+
deferrable: Option<bool>,
62+
) -> Transaction {
63+
let inner_transaction = RustTransaction::new(
64+
self.db_client.clone(),
65+
Arc::new(tokio::sync::RwLock::new(false)),
66+
Arc::new(tokio::sync::RwLock::new(false)),
67+
Arc::new(tokio::sync::RwLock::new(HashSet::new())),
68+
isolation_level,
69+
read_variant,
70+
deferrable,
71+
);
72+
73+
Transaction::new(Arc::new(inner_transaction), Default::default())
74+
}
75+
}
76+
77+
#[pyclass()]
1978
pub struct Connection {
20-
pub db_client: Arc<tokio::sync::RwLock<Object>>,
79+
pub inner_connection: Arc<RustConnection>,
80+
}
81+
82+
impl Connection {
83+
#[must_use]
84+
pub fn new(inner_connection: Arc<RustConnection>) -> Self {
85+
Connection { inner_connection }
86+
}
2187
}
2288

2389
#[pymethods]
@@ -36,27 +102,14 @@ impl Connection {
36102
querystring: String,
37103
parameters: Option<&'a PyAny>,
38104
) -> RustPSQLDriverPyResult<&PyAny> {
39-
let db_client_arc = self.db_client.clone();
105+
let connection_arc = self.inner_connection.clone();
40106

41107
let mut params: Vec<PythonDTO> = vec![];
42108
if let Some(parameters) = parameters {
43109
params = convert_parameters(parameters)?;
44110
}
45-
46111
rustengine_future(py, async move {
47-
let mut vec_parameters: Vec<&(dyn ToSql + Sync)> = Vec::with_capacity(params.len());
48-
for param in &params {
49-
vec_parameters.push(param);
50-
}
51-
let db_client_guard = db_client_arc.read().await;
52-
let statement: tokio_postgres::Statement =
53-
db_client_guard.prepare_cached(&querystring).await?;
54-
55-
let result = db_client_guard
56-
.query(&statement, &vec_parameters.into_boxed_slice())
57-
.await?;
58-
59-
Ok(PSQLDriverPyQueryResult::new(result))
112+
connection_arc.inner_execute(querystring, params).await
60113
})
61114
}
62115

@@ -68,16 +121,7 @@ impl Connection {
68121
read_variant: Option<ReadVariant>,
69122
deferrable: Option<bool>,
70123
) -> Transaction {
71-
let inner_transaction = RustTransaction::new(
72-
self.db_client.clone(),
73-
Arc::new(tokio::sync::RwLock::new(false)),
74-
Arc::new(tokio::sync::RwLock::new(false)),
75-
Arc::new(tokio::sync::RwLock::new(HashSet::new())),
76-
isolation_level,
77-
read_variant,
78-
deferrable,
79-
);
80-
81-
Transaction::new(Arc::new(inner_transaction), Default::default())
124+
self.inner_connection
125+
.inner_transaction(isolation_level, read_variant, deferrable)
82126
}
83127
}

src/driver/connection_pool.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ use crate::{
1010
value_converter::{convert_parameters, PythonDTO},
1111
};
1212

13-
use super::{common_options::ConnRecyclingMethod, connection::Connection};
13+
use super::{
14+
common_options::ConnRecyclingMethod,
15+
connection::{Connection, RustConnection},
16+
};
1417

1518
/// `PSQLPool` is for internal use only.
1619
///
@@ -70,9 +73,9 @@ impl RustPSQLPool {
7073
.get()
7174
.await?;
7275

73-
Ok(Connection {
74-
db_client: Arc::new(tokio::sync::RwLock::new(db_pool_manager)),
75-
})
76+
Ok(Connection::new(Arc::new(RustConnection::new(
77+
Arc::new(db_pool_manager).clone(),
78+
))))
7679
}
7780
/// Execute querystring with parameters.
7881
///

src/driver/transaction.rs

Lines changed: 19 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use tokio_postgres::{types::ToSql, Row};
2323
/// It is not exposed to python.
2424
#[allow(clippy::module_name_repetitions)]
2525
pub struct RustTransaction {
26-
pub db_client: Arc<tokio::sync::RwLock<Object>>,
26+
pub db_client: Arc<Object>,
2727
is_started: Arc<tokio::sync::RwLock<bool>>,
2828
is_done: Arc<tokio::sync::RwLock<bool>>,
2929
rollback_savepoint: Arc<tokio::sync::RwLock<HashSet<String>>>,
@@ -36,7 +36,7 @@ pub struct RustTransaction {
3636
impl RustTransaction {
3737
#[allow(clippy::too_many_arguments)]
3838
pub fn new(
39-
db_client: Arc<tokio::sync::RwLock<Object>>,
39+
db_client: Arc<Object>,
4040
is_started: Arc<tokio::sync::RwLock<bool>>,
4141
is_done: Arc<tokio::sync::RwLock<bool>>,
4242
rollback_savepoint: Arc<tokio::sync::RwLock<HashSet<String>>>,
@@ -76,11 +76,8 @@ impl RustTransaction {
7676
where
7777
T: ValueOrReferenceTo<Vec<PythonDTO>>,
7878
{
79-
let db_client_arc = self.db_client.clone();
8079
let is_started_arc = self.is_started.clone();
8180
let is_done_arc = self.is_done.clone();
82-
83-
let db_client_guard = db_client_arc.read().await;
8481
let is_started_guard = is_started_arc.read().await;
8582
let is_done_guard = is_done_arc.read().await;
8683

@@ -101,9 +98,10 @@ impl RustTransaction {
10198
vec_parameters.push(param);
10299
}
103100

104-
let statement = db_client_guard.prepare_cached(&querystring).await?;
101+
let statement = self.db_client.prepare_cached(&querystring).await?;
105102

106-
let result = db_client_guard
103+
let result = self
104+
.db_client
107105
.query(&statement, &vec_parameters.into_boxed_slice())
108106
.await?;
109107

@@ -133,11 +131,8 @@ impl RustTransaction {
133131
where
134132
T: ValueOrReferenceTo<Vec<PythonDTO>>,
135133
{
136-
let db_client_arc = self.db_client.clone();
137134
let is_started_arc = self.is_started.clone();
138135
let is_done_arc = self.is_done.clone();
139-
140-
let db_client_guard = db_client_arc.read().await;
141136
let is_started_guard = is_started_arc.read().await;
142137
let is_done_guard = is_done_arc.read().await;
143138

@@ -158,9 +153,10 @@ impl RustTransaction {
158153
vec_parameters.push(param);
159154
}
160155

161-
let statement = db_client_guard.prepare_cached(&querystring).await?;
156+
let statement = self.db_client.prepare_cached(&querystring).await?;
162157

163-
let result = db_client_guard
158+
let result = self
159+
.db_client
164160
.query(&statement, &vec_parameters.into_boxed_slice())
165161
.await?;
166162

@@ -185,11 +181,8 @@ impl RustTransaction {
185181
querystring: String,
186182
parameters: Vec<Vec<PythonDTO>>,
187183
) -> RustPSQLDriverPyResult<()> {
188-
let db_client_arc = self.db_client.clone();
189184
let is_started_arc = self.is_started.clone();
190185
let is_done_arc = self.is_done.clone();
191-
192-
let db_client_guard = db_client_arc.read().await;
193186
let is_started_guard = is_started_arc.read().await;
194187
let is_done_guard = is_done_arc.read().await;
195188

@@ -209,8 +202,8 @@ impl RustTransaction {
209202
));
210203
}
211204
for single_parameters in parameters {
212-
let statement = db_client_guard.prepare_cached(&querystring).await?;
213-
db_client_guard
205+
let statement = self.db_client.prepare_cached(&querystring).await?;
206+
self.db_client
214207
.query(
215208
&statement,
216209
&single_parameters
@@ -243,11 +236,8 @@ impl RustTransaction {
243236
querystring: String,
244237
parameters: Vec<PythonDTO>,
245238
) -> RustPSQLDriverPyResult<PSQLDriverSinglePyQueryResult> {
246-
let db_client_arc = self.db_client.clone();
247239
let is_started_arc = self.is_started.clone();
248240
let is_done_arc = self.is_done.clone();
249-
250-
let db_client_guard = db_client_arc.read().await;
251241
let is_started_guard = is_started_arc.read().await;
252242
let is_done_guard = is_done_arc.read().await;
253243

@@ -267,9 +257,10 @@ impl RustTransaction {
267257
vec_parameters.push(param);
268258
}
269259

270-
let statement = db_client_guard.prepare_cached(&querystring).await?;
260+
let statement = self.db_client.prepare_cached(&querystring).await?;
271261

272-
let result = db_client_guard
262+
let result = self
263+
.db_client
273264
.query_one(&statement, &vec_parameters.into_boxed_slice())
274265
.await?;
275266

@@ -324,10 +315,7 @@ impl RustTransaction {
324315
None => "",
325316
});
326317

327-
let db_client_arc = self.db_client.clone();
328-
let db_client_guard = db_client_arc.read().await;
329-
330-
db_client_guard.batch_execute(&querystring).await?;
318+
self.db_client.batch_execute(&querystring).await?;
331319

332320
Ok(())
333321
}
@@ -384,7 +372,6 @@ impl RustTransaction {
384372
/// 2) Transaction is done
385373
/// 3) Cannot execute `COMMIT` command
386374
pub async fn inner_commit(&self) -> RustPSQLDriverPyResult<()> {
387-
let db_client_arc = self.db_client.clone();
388375
let is_started_arc = self.is_started.clone();
389376
let is_done_arc = self.is_done.clone();
390377

@@ -407,9 +394,7 @@ impl RustTransaction {
407394
"Transaction is already committed or rolled back".into(),
408395
));
409396
}
410-
411-
let db_client_guard = db_client_arc.read().await;
412-
db_client_guard.batch_execute("COMMIT;").await?;
397+
self.db_client.batch_execute("COMMIT;").await?;
413398
let mut is_done_write_guard = is_done_arc.write().await;
414399
*is_done_write_guard = true;
415400

@@ -428,7 +413,6 @@ impl RustTransaction {
428413
/// 3) Specified savepoint name is exists
429414
/// 4) Can not execute SAVEPOINT command
430415
pub async fn inner_savepoint(&self, savepoint_name: String) -> RustPSQLDriverPyResult<()> {
431-
let db_client_arc = self.db_client.clone();
432416
let is_started_arc = self.is_started.clone();
433417
let is_done_arc = self.is_done.clone();
434418

@@ -462,9 +446,7 @@ impl RustTransaction {
462446
"SAVEPOINT name {savepoint_name} is already taken by this transaction",
463447
)));
464448
}
465-
466-
let db_client_guard = db_client_arc.read().await;
467-
db_client_guard
449+
self.db_client
468450
.batch_execute(format!("SAVEPOINT {savepoint_name}").as_str())
469451
.await?;
470452
let mut rollback_savepoint_guard = self.rollback_savepoint.write().await;
@@ -504,10 +486,7 @@ impl RustTransaction {
504486
"Transaction is already committed or rolled back".into(),
505487
));
506488
};
507-
508-
let db_client_arc = self.db_client.clone();
509-
let db_client_guard = db_client_arc.read().await;
510-
db_client_guard.batch_execute("ROLLBACK").await?;
489+
self.db_client.batch_execute("ROLLBACK").await?;
511490
let mut is_done_write_guard = is_done_arc.write().await;
512491
*is_done_write_guard = true;
513492
Ok(())
@@ -556,10 +535,7 @@ impl RustTransaction {
556535
"Don't have rollback with this name".into(),
557536
));
558537
}
559-
560-
let db_client_arc = self.db_client.clone();
561-
let db_client_guard = db_client_arc.read().await;
562-
db_client_guard
538+
self.db_client
563539
.batch_execute(format!("ROLLBACK TO SAVEPOINT {rollback_name}").as_str())
564540
.await?;
565541

@@ -610,10 +586,7 @@ impl RustTransaction {
610586
"Don't have rollback with this name".into(),
611587
));
612588
}
613-
614-
let db_client_arc = self.db_client.clone();
615-
let db_client_guard = db_client_arc.read().await;
616-
db_client_guard
589+
self.db_client
617590
.batch_execute(format!("RELEASE SAVEPOINT {rollback_name}").as_str())
618591
.await?;
619592

0 commit comments

Comments
 (0)