Skip to content

Commit 72027d5

Browse files
committed
Add not_foreign cast flag and nb_type_put() allow_foreign flag so we can prevent mutual recursion between two frameworks failing to perform a cast. Simplify enum destruction. Clean up some things I noticed while updating the pybind11 PR.
1 parent 4fb9c85 commit 72027d5

File tree

9 files changed

+68
-53
lines changed

9 files changed

+68
-53
lines changed

include/nanobind/nb_attr.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,10 @@ enum cast_flags : uint8_t {
207207
// This implies that objects added to the cleanup list may be
208208
// released immediately after the caster's final output value is
209209
// obtained, i.e., before it is used.
210-
manual = (1 << 3)
210+
manual = (1 << 3),
211+
212+
// Disallow satisfying this cast with a foreign framework's binding
213+
not_foreign = (1 << 4),
211214
};
212215

213216

include/nanobind/nb_class.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,6 @@ enum class enum_flags : uint32_t {
205205

206206
/// Is the underlying enumeration type Flag?
207207
is_flag = (1 << 3),
208-
209-
/// Was the enum successfully registered with nanobind?
210-
is_registered = (1 << 4),
211208
};
212209

213210
struct enum_init_data {

include/nanobind/nb_lib.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@ NB_CORE bool nb_type_get(const std::type_info *t, PyObject *o, uint8_t flags,
289289
/// Cast a C++ type instance into a Python object
290290
NB_CORE PyObject *nb_type_put(const std::type_info *cpp_type, void *value,
291291
rv_policy rvp, cleanup_list *cleanup,
292-
bool *is_new = nullptr) noexcept;
292+
bool *is_new = nullptr,
293+
bool allow_foreign = true) noexcept;
293294

294295
// Special version of nb_type_put for polymorphic classes
295296
NB_CORE PyObject *nb_type_put_p(const std::type_info *cpp_type,

include/nanobind/stl/unique_ptr.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,12 @@ struct type_caster<std::unique_ptr<T, Deleter>> {
9494
// Stash source python object
9595
src = src_;
9696

97-
// Don't accept foreign types; they can't relinquish ownership
98-
if (!src.is_none() && !inst_check(src))
99-
return false;
100-
101-
/* Try casting to a pointer of the underlying type. We pass flags=0 and
102-
cleanup=nullptr to prevent implicit type conversions (they are
103-
problematic since the instance then wouldn't be owned by 'src') */
104-
return caster.from_python(src_, 0, nullptr);
97+
/* Try casting to a pointer of the underlying type. We pass
98+
cleanup=nullptr and !(flags & convert) to prevent implicit type
99+
conversions, which are problematic since the instance then wouldn't
100+
be owned by 'src'. Also disable casting from a foreign type since it
101+
wouldn't be able to relinquish ownership. */
102+
return caster.from_python(src_, (uint8_t) cast_flags::not_foreign, nullptr);
105103
}
106104

107105
template <typename T2>

src/nb_enum.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
7373
type_init_data *t = (type_init_data *) p;
7474
delete (enum_map *) t->enum_tbl.fwd;
7575
delete (enum_map *) t->enum_tbl.rev;
76-
if (t->flags & (uint32_t) enum_flags::is_registered)
77-
nb_type_unregister(t);
7876
free((char*) t->name);
7977
delete t;
8078
});
@@ -88,7 +86,12 @@ PyObject *enum_create(enum_init_data *ed) noexcept {
8886
return tp;
8987
}
9088

91-
t->flags |= (uint32_t) enum_flags::is_registered;
89+
// Unregister the enum type when it begins being finalized
90+
keep_alive(result.ptr(), t, [](void *p) noexcept {
91+
nb_type_unregister((type_init_data *) p);
92+
});
93+
94+
// Delete typeinfo only when the type's dict is cleared
9295
result.attr("__nb_enum__") = tie_lifetimes;
9396

9497
make_immortal(result.ptr());
@@ -180,7 +183,7 @@ bool enum_from_python(const std::type_info *tp,
180183

181184
#if !defined(NB_DISABLE_INTEROP)
182185
auto try_foreign = [=, &has_foreign]() -> bool {
183-
if (has_foreign) {
186+
if (has_foreign && !(flags & (uint8_t) cast_flags::not_foreign)) {
184187
void *ptr = nb_type_get_foreign(internals, tp, o, flags, cleanup);
185188
if (ptr) {
186189
// Copy from the C++ enum object to our output integer.

src/nb_foreign.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ static void *nb_foreign_from_python(pymb_binding *binding,
4747
PyObject *obj),
4848
void *keep_referenced_ctx) noexcept {
4949
cleanup_list cleanup{nullptr};
50+
uint8_t flags = (uint8_t) cast_flags::not_foreign;
51+
if (convert)
52+
flags |= (uint8_t) cast_flags::convert;
5053
auto *td = (type_data *) binding->context;
5154
if (td->align == 0) { // enum
5255
int64_t value;
5356
if (keep_referenced &&
54-
enum_from_python(td->type, pyobj, &value, td->size,
55-
convert ? uint8_t(cast_flags::convert) : 0,
56-
nullptr)) {
57+
enum_from_python(td->type, pyobj, &value, td->size, flags, nullptr)) {
5758
bytes holder{(uint8_t *) &value + NB_BIG_ENDIAN * (8 - td->size),
5859
td->size};
5960
keep_referenced(keep_referenced_ctx, holder.ptr());
@@ -63,8 +64,7 @@ static void *nb_foreign_from_python(pymb_binding *binding,
6364
}
6465

6566
void *result = nullptr;
66-
bool ok = nb_type_get(td->type, pyobj,
67-
convert ? uint8_t(cast_flags::convert) : 0,
67+
bool ok = nb_type_get(td->type, pyobj, flags,
6868
keep_referenced ? &cleanup : nullptr, &result);
6969
if (keep_referenced) {
7070
// Move temporary references from our `cleanup_list` to our caller's
@@ -109,7 +109,8 @@ static PyObject *nb_foreign_to_python(pymb_binding *binding,
109109
// unless a pyobject wrapper already exists.
110110
rvp = rv_policy::none;
111111
}
112-
return nb_type_put(td->type, cobj, rvp, &cleanup, nullptr);
112+
return nb_type_put(td->type, cobj, rvp, &cleanup,
113+
/* is_new */ nullptr, /* allow_foreign */ false);
113114
}
114115

115116
static int nb_foreign_keep_alive(PyObject *nurse,
@@ -153,10 +154,12 @@ static int nb_foreign_translate_exception(void *eptr) noexcept {
153154
std::rethrow_exception(e);
154155
} catch (python_error &e) {
155156
e.restore();
157+
return 1;
156158
} catch (builtin_exception &e) {
157159
if (!set_builtin_exception_status(e))
158160
PyErr_SetString(PyExc_SystemError, "foreign function threw "
159161
"nanobind::next_overload()");
162+
return 1;
160163
} catch (...) { e = std::current_exception(); }
161164
return 0;
162165
}
@@ -521,10 +524,10 @@ void *nb_type_try_foreign(nb_internals *internals_,
521524
#if defined(NB_FREE_THREADED)
522525
auto per_thread_guard = nb_type_lock_c2p_fast(internals_);
523526
nb_type_map_fast &type_c2p_fast = *per_thread_guard;
524-
uint32_t updates_count = per_thread_guard.updates_count();
525527
#else
526528
nb_type_map_fast &type_c2p_fast = internals_->type_c2p_fast;
527529
#endif
530+
uint32_t update_count = type_c2p_fast.update_count;
528531
do {
529532
// We assume nb_type_c2p already ran for this type, so that there's
530533
// no need to handle a cache miss here.
@@ -563,16 +566,16 @@ void *nb_type_try_foreign(nb_internals *internals_,
563566
return result;
564567

565568
#if defined(NB_FREE_THREADED)
566-
// Re-acquire lock to continue iteration. If we missed an
567-
// update while the lock was released, start our lookup over
568-
// in case the update removed the node we're on.
569+
// Re-acquire lock to continue iteration
569570
per_thread_guard = nb_type_lock_c2p_fast(internals_);
570-
if (per_thread_guard.updates_count() != updates_count) {
571+
#endif
572+
// If we missed an update during attempt(), start our lookup
573+
// over in case the update removed the node we're on.
574+
if (type_c2p_fast.update_count != update_count) {
571575
// Concurrent update occurred; retry
572-
updates_count = per_thread_guard.updates_count();
576+
update_count = type_c2p_fast.update_count;
573577
break;
574578
}
575-
#endif
576579
}
577580
current = current->next;
578581
}

src/nb_internals.h

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ struct nb_type_map_fast {
331331
/// then return a reference to the stored value, which the caller may
332332
/// modify.
333333
void*& lookup_or_set(const std::type_info *ti, void *dflt) {
334+
++update_count;
334335
return data.try_emplace((void *) ti, dflt).first.value();
335336
}
336337

@@ -348,11 +349,23 @@ struct nb_type_map_fast {
348349
auto it = data.find((void *) ti);
349350
if (it != data.end()) {
350351
it.value() = value;
352+
++update_count;
351353
return true;
352354
}
353355
return false;
354356
}
355357

358+
/// Number of times the map has been modified. Used in nb_type_try_foreign()
359+
/// to detect cases where attempting to use one foreign binding for a type
360+
/// may have invalidated the iterator needed to advance to the next one.
361+
uint32_t update_count = 0;
362+
363+
#if defined(NB_FREE_THREADED)
364+
/// Mutex used by `nb_type_map_per_thread`, stored here because it fits
365+
/// in padding this way.
366+
Py_Mutex mutex{};
367+
#endif
368+
356369
private:
357370
// Use a generic ptr->ptr map to avoid needing another instantiation of
358371
// robin_map. Keys are const std::type_info*. See TYPE MAPPING above for
@@ -390,35 +403,30 @@ struct nb_type_map_per_thread {
390403
}
391404
~guard() {
392405
if (parent)
393-
PyMutex_Unlock(&parent->mutex);
406+
PyMutex_Unlock(&parent->map.mutex);
394407
}
395408

396409
nb_type_map_fast& operator*() const { return parent->map; }
397410
nb_type_map_fast* operator->() const { return &parent->map; }
398411

399-
uint32_t updates_count() const { return parent->updates; }
400-
void note_updated() { ++parent->updates; }
401-
402412
private:
403413
friend nb_type_map_per_thread;
404414
explicit guard(nb_type_map_per_thread &parent_) : parent(&parent_) {
405-
PyMutex_Lock(&parent->mutex);
415+
PyMutex_Lock(&parent->map.mutex);
406416
}
407417
nb_type_map_per_thread *parent = nullptr;
408418
};
409419
guard lock() { return guard{*this}; }
410420

411-
// Mutex protecting accesses to `updates` and `map`
412-
PyMutex mutex{};
413-
414-
// The number of times `map` has been modified
415-
uint32_t updates = 0;
416-
417421
nb_internals &internals;
422+
423+
private:
424+
// Access to the map is only possible via `guard`, which holds a lock
418425
nb_type_map_fast map;
419426

427+
public:
420428
// In order to access or modify `next`, you must hold the nb_internals mutex
421-
// (this->mutex is not needed for iteration)
429+
// (this->map.mutex is not needed for iteration)
422430
nb_type_map_per_thread *next = nullptr;
423431
};
424432
#endif

src/nb_type.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -472,12 +472,9 @@ void nb_type_update_c2p_fast(const std::type_info *type, void *value) noexcept {
472472
for (nb_type_map_per_thread *cache =
473473
internals_->type_c2p_per_thread_head;
474474
cache; cache = cache->next) {
475-
auto guard = cache->lock();
476-
bool found = nb_type_update_cache(*guard, it_alias->first,
477-
(nb_alias_seq *) it_alias->second,
478-
value);
479-
if (found)
480-
guard.note_updated();
475+
nb_type_update_cache(*cache->lock(), it_alias->first,
476+
(nb_alias_seq *) it_alias->second,
477+
value);
481478
}
482479
// We can't require that we found a match, because the type might
483480
// have been cached only by a thread that has since exited.
@@ -1621,6 +1618,9 @@ void *nb_type_get_foreign(nb_internals *internals_,
16211618
PyObject *src,
16221619
uint8_t flags,
16231620
cleanup_list *cleanup) noexcept {
1621+
if (flags & (uint8_t) cast_flags::not_foreign)
1622+
return nullptr;
1623+
16241624
struct capture {
16251625
PyObject *src;
16261626
uint8_t flags;
@@ -1993,7 +1993,8 @@ PyObject *nb_type_put_foreign(nb_internals *internals_,
19931993
PyObject *nb_type_put(const std::type_info *cpp_type,
19941994
void *value, rv_policy rvp,
19951995
cleanup_list *cleanup,
1996-
bool *is_new) noexcept {
1996+
bool *is_new,
1997+
bool allow_foreign) noexcept {
19971998
// Convert nullptr -> None
19981999
if (!value) {
19992000
Py_INCREF(Py_None);
@@ -2016,7 +2017,7 @@ PyObject *nb_type_put(const std::type_info *cpp_type,
20162017

20172018
#if !defined(NB_DISABLE_INTEROP)
20182019
auto try_foreign = [=, &has_foreign]() -> PyObject* {
2019-
if (has_foreign)
2020+
if (has_foreign && allow_foreign)
20202021
return nb_type_put_foreign(internals_, cpp_type, nullptr, value,
20212022
rvp, cleanup, is_new);
20222023
return nullptr;

tests/test_inter_module.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,12 @@ def repeatedly_attempt_conversions():
347347
transitions = limit
348348
thread.join()
349349

350-
# typical numbers from my machine: with limit=100, the test takes 6sec,
351-
# and num_failed and num_successful are each several 10k's
350+
# typical numbers from my machine: with limit=5000, the test takes a
351+
# decent fraction of a second, and num_failed and num_successful are each
352+
# several 10k's
352353
print(num_failed, num_successful)
353354
assert num_successful > 0
354-
assert num_failed > 0 or not free_threaded
355+
assert num_failed > 0
355356

356357

357358
def test12_multi_and_implicit(clean):

0 commit comments

Comments
 (0)