Skip to content

Commit 25a464e

Browse files
authored
Fix sqlite3 Cursor initialization check (RustPython#6198)
Add proper __init__ validation for sqlite3.Cursor to ensure base class __init__ is called before using cursor methods. This fixes the test_cursor_constructor_call_check test case. Changes: - Modified Cursor to initialize with inner=None in py_new - Added explicit __init__ method that sets up CursorInner - Updated close() method to check for uninitialized state - Changed error message to match CPython: 'Base Cursor.__init__ not called.' This ensures CPython compatibility where attempting to use a Cursor instance without calling the base __init__ raises ProgrammingError.
1 parent 13329f0 commit 25a464e

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

Lib/test/test_sqlite3/test_regression.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,6 @@ def __del__(self):
195195
con.isolation_level = value
196196
self.assertEqual(con.isolation_level, "DEFERRED")
197197

198-
# TODO: RUSTPYTHON
199-
@unittest.expectedFailure
200198
def test_cursor_constructor_call_check(self):
201199
"""
202200
Verifies that cursor methods check whether base class __init__ was

stdlib/src/sqlite.rs

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,8 @@ mod _sqlite {
14591459
#[pytraverse(skip)]
14601460
rowcount: i64,
14611461
statement: Option<PyRef<Statement>>,
1462+
#[pytraverse(skip)]
1463+
closed: bool,
14621464
}
14631465

14641466
#[derive(FromArgs)]
@@ -1484,20 +1486,54 @@ mod _sqlite {
14841486
lastrowid: -1,
14851487
rowcount: -1,
14861488
statement: None,
1489+
closed: false,
14871490
})),
14881491
}
14891492
}
14901493

1494+
fn new_uninitialized(connection: PyRef<Connection>, _vm: &VirtualMachine) -> Self {
1495+
Self {
1496+
connection,
1497+
arraysize: Radium::new(1),
1498+
row_factory: PyAtomicRef::from(None),
1499+
inner: PyMutex::from(None),
1500+
}
1501+
}
1502+
1503+
#[pymethod]
1504+
fn __init__(&self, _connection: PyRef<Connection>, _vm: &VirtualMachine) -> PyResult<()> {
1505+
let mut guard = self.inner.lock();
1506+
if guard.is_some() {
1507+
// Already initialized (e.g., from a call to super().__init__)
1508+
return Ok(());
1509+
}
1510+
*guard = Some(CursorInner {
1511+
description: None,
1512+
row_cast_map: vec![],
1513+
lastrowid: -1,
1514+
rowcount: -1,
1515+
statement: None,
1516+
closed: false,
1517+
});
1518+
Ok(())
1519+
}
1520+
14911521
fn inner(&self, vm: &VirtualMachine) -> PyResult<PyMappedMutexGuard<'_, CursorInner>> {
14921522
let guard = self.inner.lock();
14931523
if guard.is_some() {
1494-
Ok(PyMutexGuard::map(guard, |x| unsafe {
1495-
x.as_mut().unwrap_unchecked()
1496-
}))
1524+
let inner_guard =
1525+
PyMutexGuard::map(guard, |x| unsafe { x.as_mut().unwrap_unchecked() });
1526+
if inner_guard.closed {
1527+
return Err(new_programming_error(
1528+
vm,
1529+
"Cannot operate on a closed cursor.".to_owned(),
1530+
));
1531+
}
1532+
Ok(inner_guard)
14971533
} else {
14981534
Err(new_programming_error(
14991535
vm,
1500-
"Cannot operate on a closed cursor.".to_owned(),
1536+
"Base Cursor.__init__ not called.".to_owned(),
15011537
))
15021538
}
15031539
}
@@ -1717,12 +1753,23 @@ mod _sqlite {
17171753
}
17181754

17191755
#[pymethod]
1720-
fn close(&self) {
1721-
if let Some(inner) = self.inner.lock().take()
1722-
&& let Some(stmt) = inner.statement
1723-
{
1724-
stmt.lock().reset();
1756+
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
1757+
// Check if __init__ was called
1758+
let mut guard = self.inner.lock();
1759+
if guard.is_none() {
1760+
return Err(new_programming_error(
1761+
vm,
1762+
"Base Cursor.__init__ not called.".to_owned(),
1763+
));
17251764
}
1765+
1766+
if let Some(inner) = guard.as_mut() {
1767+
if let Some(stmt) = &inner.statement {
1768+
stmt.lock().reset();
1769+
}
1770+
inner.closed = true;
1771+
}
1772+
Ok(())
17261773
}
17271774

17281775
#[pymethod]
@@ -1809,7 +1856,7 @@ mod _sqlite {
18091856
type Args = (PyRef<Connection>,);
18101857

18111858
fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
1812-
Self::new(args.0, None, vm)
1859+
Self::new_uninitialized(args.0, vm)
18131860
.into_ref_with_type(vm, cls)
18141861
.map(Into::into)
18151862
}

0 commit comments

Comments
 (0)