Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ namespace accessor_policies {
struct sequence_item;
struct list_item;
struct tuple_item;
struct dict_item;
} // namespace accessor_policies
using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
using str_attr_accessor = accessor<accessor_policies::str_attr>;
using item_accessor = accessor<accessor_policies::generic_item>;
using sequence_accessor = accessor<accessor_policies::sequence_item>;
using list_accessor = accessor<accessor_policies::list_item>;
using tuple_accessor = accessor<accessor_policies::tuple_item>;
using dict_accessor = accessor<accessor_policies::dict_item>;

/// Tag and check to identify a class which implements the Python object API
class pyobject_tag { };
Expand Down Expand Up @@ -613,6 +615,31 @@ struct tuple_item {
}
}
};

struct dict_item {
using key_type = object;

static object get(handle obj, handle key) {
#if PY_MAJOR_VERSION >= 3
if (PyObject *result = PyDict_GetItemWithError(obj.ptr(), key.ptr())) {
return reinterpret_borrow<object>(result);
} else {
if (!PyErr_Occurred())
if (PyObject* key_repr = PyObject_Repr(key.ptr()))
PyErr_SetObject(PyExc_KeyError, key_repr);
throw error_already_set();
}
#else
return generic_item::get(obj, key);
#endif
}

static void set(handle obj, handle key, handle val) {
if (PyDict_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) {
throw error_already_set();
}
}
};
PYBIND11_NAMESPACE_END(accessor_policies)

/// STL iterator template used for tuple, list, sequence and dict
Expand Down Expand Up @@ -1285,13 +1312,42 @@ class dict : public object {

size_t size() const { return (size_t) PyDict_Size(m_ptr); }
bool empty() const { return size() == 0; }
detail::dict_accessor operator[](const char *key) const { return {*this, pybind11::str(key)}; }
detail::dict_accessor operator[](handle h) const { return {*this, reinterpret_borrow<object>(h)}; }
detail::dict_iterator begin() const { return {*this, 0}; }
detail::dict_iterator end() const { return {}; }
void clear() const { PyDict_Clear(ptr()); }
template <typename T> bool contains(T &&key) const {
return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr()) == 1;
}

object get(handle key, handle default_ = none()) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other reviewers: should this be moved out of the class definition? I think it's on the edge of being too long, but good enough and maybe nicer to keep it just as-is than to split it up into two parts.

#if PY_MAJOR_VERSION >= 3
if (PyObject *result = PyDict_GetItemWithError(m_ptr, key.ptr())) {
return reinterpret_borrow<object>(result);
} else {
if (PyErr_Occurred())
throw error_already_set();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throw or return default?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else
return reinterpret_borrow<object>(default_);
}
#else
try {
return object::operator[](key);
} catch (const error_already_set& e) {
if (e.type().ptr() == PyExc_KeyError) {
return reinterpret_borrow<object>(default_);
} else {
throw;
}
}
#endif
}

object get(const char *key, handle default_ = none()) const {
return get(pybind11::str(key), default_);
}

private:
/// Call the `dict` Python type -- always returns a new reference
static PyObject *raw_dict(PyObject *op) {
Expand Down
16 changes: 14 additions & 2 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,24 @@ TEST_SUBMODULE(pytypes, m) {
auto d2 = py::dict("z"_a=3, **d1);
return d2;
});
m.def("dict_contains", [](py::dict dict, py::object val) {
m.def("dict_contains", [](py::dict dict, const char* val) {
return dict.contains(val);
});
m.def("dict_contains", [](py::dict dict, const char* val) {
m.def("dict_contains", [](py::dict dict, py::object val) {
return dict.contains(val);
});
m.def("dict_get", [](py::dict dict, const char* key, py::object default_) {
return dict.get(key, default_);
});
m.def("dict_get", [](py::dict dict, py::object key, py::object default_) {
return dict.get(key, default_);
});
m.def("dict_get", [](py::dict dict, const char* key) {
return dict.get(key);
});
m.def("dict_get", [](py::dict dict, py::object key) {
return dict.get(key);
});

// test_str
m.def("str_from_string", []() { return py::str(std::string("baz")); });
Expand Down
10 changes: 10 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ def test_dict(capture, doc):
assert m.dict_contains({42: None}, 42)
assert m.dict_contains({"foo": None}, "foo")

d2 = {True: 42, 3: "abc", "abc": 3}
assert m.dict_get(d2, True) == 42
assert m.dict_get(d2, False) is None
assert m.dict_get(d2, 3, "def") == "abc"
assert m.dict_get(d2, 5, "def") == "def"
assert m.dict_get(d2, "abc") == 3
assert m.dict_get(d2, "def") is None
assert m.dict_get(d2, "abc", 5) == 3
assert m.dict_get(d2, "def", 5) == 5

assert doc(m.get_dict) == "get_dict() -> dict"
assert doc(m.print_dict) == "print_dict(arg0: dict) -> None"

Expand Down