@@ -35,13 +35,15 @@ namespace accessor_policies {
35
35
struct sequence_item ;
36
36
struct list_item ;
37
37
struct tuple_item ;
38
+ struct dict_item ;
38
39
} // namespace accessor_policies
39
40
using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
40
41
using str_attr_accessor = accessor<accessor_policies::str_attr>;
41
42
using item_accessor = accessor<accessor_policies::generic_item>;
42
43
using sequence_accessor = accessor<accessor_policies::sequence_item>;
43
44
using list_accessor = accessor<accessor_policies::list_item>;
44
45
using tuple_accessor = accessor<accessor_policies::tuple_item>;
46
+ using dict_accessor = accessor<accessor_policies::dict_item>;
45
47
46
48
// / Tag and check to identify a class which implements the Python object API
47
49
class pyobject_tag { };
@@ -613,6 +615,31 @@ struct tuple_item {
613
615
}
614
616
}
615
617
};
618
+
619
+ struct dict_item {
620
+ using key_type = object;
621
+
622
+ static object get (handle obj, handle key) {
623
+ #if PY_MAJOR_VERSION >= 3
624
+ if (PyObject *result = PyDict_GetItemWithError (obj.ptr (), key.ptr ())) {
625
+ return reinterpret_borrow<object>(result);
626
+ } else {
627
+ if (!PyErr_Occurred ())
628
+ if (PyObject* key_repr = PyObject_Repr (key.ptr ()))
629
+ PyErr_SetObject (PyExc_KeyError, key_repr);
630
+ throw error_already_set ();
631
+ }
632
+ #else
633
+ return generic_item::get (obj, key);
634
+ #endif
635
+ }
636
+
637
+ static void set (handle obj, handle key, handle val) {
638
+ if (PyDict_SetItem (obj.ptr (), key.ptr (), val.ptr ()) != 0 ) {
639
+ throw error_already_set ();
640
+ }
641
+ }
642
+ };
616
643
PYBIND11_NAMESPACE_END (accessor_policies)
617
644
618
645
// / STL iterator template used for tuple, list, sequence and dict
@@ -1285,6 +1312,8 @@ class dict : public object {
1285
1312
1286
1313
size_t size () const { return (size_t ) PyDict_Size (m_ptr); }
1287
1314
bool empty () const { return size () == 0 ; }
1315
+ detail::dict_accessor operator [](const char *key) const { return {*this , pybind11::str (key)}; }
1316
+ detail::dict_accessor operator [](handle h) const { return {*this , reinterpret_borrow<object>(h)}; }
1288
1317
detail::dict_iterator begin () const { return {*this , 0 }; }
1289
1318
detail::dict_iterator end () const { return {}; }
1290
1319
void clear () const { PyDict_Clear (ptr ()); }
@@ -1293,19 +1322,30 @@ class dict : public object {
1293
1322
}
1294
1323
1295
1324
object get (handle key, handle default_ = none()) const {
1296
- if (PyObject *result = PyDict_GetItem (m_ptr, key.ptr ())) {
1325
+ #if PY_MAJOR_VERSION >= 3
1326
+ if (PyObject *result = PyDict_GetItemWithError (m_ptr, key.ptr ())) {
1297
1327
return reinterpret_borrow<object>(result);
1298
1328
} else {
1299
- return reinterpret_borrow<object>(default_);
1329
+ if (PyErr_Occurred ())
1330
+ throw error_already_set ();
1331
+ else
1332
+ return reinterpret_borrow<object>(default_);
1300
1333
}
1334
+ #else
1335
+ try {
1336
+ return object::operator [](key);
1337
+ } catch (const error_already_set& e) {
1338
+ if (e.type ().ptr () == PyExc_KeyError) {
1339
+ return reinterpret_borrow<object>(default_);
1340
+ } else {
1341
+ throw ;
1342
+ }
1343
+ }
1344
+ #endif
1301
1345
}
1302
1346
1303
1347
object get (const char *key, handle default_ = none()) const {
1304
- if (PyObject *result = PyDict_GetItemString (m_ptr, key)) {
1305
- return reinterpret_borrow<object>(result);
1306
- } else {
1307
- return reinterpret_borrow<object>(default_);
1308
- }
1348
+ return get (pybind11::str (key), default_);
1309
1349
}
1310
1350
1311
1351
private:
0 commit comments