Skip to content

Commit 5f0aead

Browse files
committed
dict_accessor for optimized operator [] of dict
1 parent e0aa141 commit 5f0aead

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

include/pybind11/pytypes.h

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ namespace accessor_policies {
3535
struct sequence_item;
3636
struct list_item;
3737
struct tuple_item;
38+
struct dict_item;
3839
} // namespace accessor_policies
3940
using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
4041
using str_attr_accessor = accessor<accessor_policies::str_attr>;
4142
using item_accessor = accessor<accessor_policies::generic_item>;
4243
using sequence_accessor = accessor<accessor_policies::sequence_item>;
4344
using list_accessor = accessor<accessor_policies::list_item>;
4445
using tuple_accessor = accessor<accessor_policies::tuple_item>;
46+
using dict_accessor = accessor<accessor_policies::dict_item>;
4547

4648
/// Tag and check to identify a class which implements the Python object API
4749
class pyobject_tag { };
@@ -613,6 +615,31 @@ struct tuple_item {
613615
}
614616
}
615617
};
618+
619+
struct dict_item {
620+
using key_type = object;
621+
622+
static object get(handle obj, handle key) {
623+
#if PY_MAJOR_VERSION >= 3
624+
if (PyObject *result = PyDict_GetItemWithError(obj.ptr(), key.ptr())) {
625+
return reinterpret_borrow<object>(result);
626+
} else {
627+
if (!PyErr_Occurred())
628+
if (PyObject* key_repr = PyObject_Repr(key.ptr()))
629+
PyErr_SetObject(PyExc_KeyError, key_repr);
630+
throw error_already_set();
631+
}
632+
#else
633+
return generic_item::get(obj, key);
634+
#endif
635+
}
636+
637+
static void set(handle obj, handle key, handle val) {
638+
if (PyDict_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) {
639+
throw error_already_set();
640+
}
641+
}
642+
};
616643
PYBIND11_NAMESPACE_END(accessor_policies)
617644

618645
/// STL iterator template used for tuple, list, sequence and dict
@@ -1285,6 +1312,8 @@ class dict : public object {
12851312

12861313
size_t size() const { return (size_t) PyDict_Size(m_ptr); }
12871314
bool empty() const { return size() == 0; }
1315+
detail::dict_accessor operator[](const char *key) const { return {*this, pybind11::str(key)}; }
1316+
detail::dict_accessor operator[](handle h) const { return {*this, reinterpret_borrow<object>(h)}; }
12881317
detail::dict_iterator begin() const { return {*this, 0}; }
12891318
detail::dict_iterator end() const { return {}; }
12901319
void clear() const { PyDict_Clear(ptr()); }
@@ -1293,19 +1322,30 @@ class dict : public object {
12931322
}
12941323

12951324
object get(handle key, handle default_ = none()) const {
1296-
if (PyObject *result = PyDict_GetItem(m_ptr, key.ptr())) {
1325+
#if PY_MAJOR_VERSION >= 3
1326+
if (PyObject *result = PyDict_GetItemWithError(m_ptr, key.ptr())) {
12971327
return reinterpret_borrow<object>(result);
12981328
} else {
1299-
return reinterpret_borrow<object>(default_);
1329+
if (PyErr_Occurred())
1330+
throw error_already_set();
1331+
else
1332+
return reinterpret_borrow<object>(default_);
13001333
}
1334+
#else
1335+
try {
1336+
return object::operator[](key);
1337+
} catch (const error_already_set& e) {
1338+
if (e.type().ptr() == PyExc_KeyError) {
1339+
return reinterpret_borrow<object>(default_);
1340+
} else {
1341+
throw;
1342+
}
1343+
}
1344+
#endif
13011345
}
13021346

13031347
object get(const char *key, handle default_ = none()) const {
1304-
if (PyObject *result = PyDict_GetItemString(m_ptr, key)) {
1305-
return reinterpret_borrow<object>(result);
1306-
} else {
1307-
return reinterpret_borrow<object>(default_);
1308-
}
1348+
return get(pybind11::str(key), default_);
13091349
}
13101350

13111351
private:

tests/test_pytypes.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,22 +68,22 @@ TEST_SUBMODULE(pytypes, m) {
6868
auto d2 = py::dict("z"_a=3, **d1);
6969
return d2;
7070
});
71-
m.def("dict_contains", [](py::dict dict, py::object val) {
71+
m.def("dict_contains", [](py::dict dict, const char* val) {
7272
return dict.contains(val);
7373
});
74-
m.def("dict_contains", [](py::dict dict, const char* val) {
74+
m.def("dict_contains", [](py::dict dict, py::object val) {
7575
return dict.contains(val);
7676
});
77-
m.def("dict_get", [](py::dict dict, py::object key, py::object default_) {
77+
m.def("dict_get", [](py::dict dict, const char* key, py::object default_) {
7878
return dict.get(key, default_);
7979
});
80-
m.def("dict_get", [](py::dict dict, const char* key, py::object default_) {
80+
m.def("dict_get", [](py::dict dict, py::object key, py::object default_) {
8181
return dict.get(key, default_);
8282
});
83-
m.def("dict_get", [](py::dict dict, py::object key) {
83+
m.def("dict_get", [](py::dict dict, const char* key) {
8484
return dict.get(key);
8585
});
86-
m.def("dict_get", [](py::dict dict, const char* key) {
86+
m.def("dict_get", [](py::dict dict, py::object key) {
8787
return dict.get(key);
8888
});
8989

0 commit comments

Comments
 (0)