Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/pybind11/detail/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
#define PYBIND11_BOOL_ATTR "__bool__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
#define PYBIND11_BUILTINS_MODULE "builtins"
#define PYBIND11_DICT_GET_ITEM_WITH_ERROR PyDict_GetItemWithError
// Providing a separate declaration to make Clang's -Wmissing-prototypes happy.
// See comment for PYBIND11_MODULE below for why this is marked "maybe unused".
#define PYBIND11_PLUGIN_IMPL(name) \
Expand Down Expand Up @@ -213,6 +214,7 @@
#define PYBIND11_BOOL_ATTR "__nonzero__"
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
#define PYBIND11_BUILTINS_MODULE "__builtin__"
#define PYBIND11_DICT_GET_ITEM_WITH_ERROR _PyDict_GetItemWithError
// Providing a separate PyInit decl to make Clang's -Wmissing-prototypes happy.
// See comment for PYBIND11_MODULE below for why this is marked "maybe unused".
#define PYBIND11_PLUGIN_IMPL(name) \
Expand Down
46 changes: 46 additions & 0 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ namespace accessor_policies {
struct sequence_item;
struct list_item;
struct tuple_item;
struct dict_item;
} // namespace accessor_policies
using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
using str_attr_accessor = accessor<accessor_policies::str_attr>;
using item_accessor = accessor<accessor_policies::generic_item>;
using sequence_accessor = accessor<accessor_policies::sequence_item>;
using list_accessor = accessor<accessor_policies::list_item>;
using tuple_accessor = accessor<accessor_policies::tuple_item>;
using dict_accessor = accessor<accessor_policies::dict_item>;

/// Tag and check to identify a class which implements the Python object API
class pyobject_tag { };
Expand Down Expand Up @@ -613,6 +615,30 @@ struct tuple_item {
}
}
};

struct dict_item {
using key_type = object;

static object get(handle obj, handle key) {
if (PyObject *result = PYBIND11_DICT_GET_ITEM_WITH_ERROR(obj.ptr(), key.ptr())) {
return reinterpret_borrow<object>(result);
} else {
// NULL with an exception means exception occurred when calling
// "__hash__" or "__eq__" on the key
// NULL without an exception means the key wasn’t present
if (!PyErr_Occurred())
// Synthesize a KeyError with the key
PyErr_SetObject(PyExc_KeyError, key.inc_ref().ptr());
throw error_already_set();
}
}

static void set(handle obj, handle key, handle val) {
if (PyDict_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) {
throw error_already_set();
}
}
};
PYBIND11_NAMESPACE_END(accessor_policies)

/// STL iterator template used for tuple, list, sequence and dict
Expand Down Expand Up @@ -1285,13 +1311,33 @@ class dict : public object {

size_t size() const { return (size_t) PyDict_Size(m_ptr); }
bool empty() const { return size() == 0; }
detail::dict_accessor operator[](const char *key) const { return {*this, pybind11::str(key)}; }
detail::dict_accessor operator[](handle h) const { return {*this, reinterpret_borrow<object>(h)}; }
detail::dict_iterator begin() const { return {*this, 0}; }
detail::dict_iterator end() const { return {}; }
void clear() const { PyDict_Clear(ptr()); }
template <typename T> bool contains(T &&key) const {
return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr()) == 1;
}

object get(handle key, handle default_ = none()) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other reviewers: should this be moved out of the class definition? I think it's on the edge of being too long, but good enough and maybe nicer to keep it just as-is than to split it up into two parts.

if (PyObject *result = PYBIND11_DICT_GET_ITEM_WITH_ERROR(m_ptr, key.ptr())) {
return reinterpret_borrow<object>(result);
} else {
// NULL with an exception means exception occurred when calling
// "__hash__" or "__eq__" on the key
if (PyErr_Occurred())
throw error_already_set();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throw or return default?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// NULL without an exception means the key wasn’t present
else
return reinterpret_borrow<object>(default_);
}
}

object get(const char *key, handle default_ = none()) const {
return get(pybind11::str(key), default_);
}

private:
/// Call the `dict` Python type -- always returns a new reference
static PyObject *raw_dict(PyObject *op) {
Expand Down
16 changes: 14 additions & 2 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,24 @@ TEST_SUBMODULE(pytypes, m) {
auto d2 = py::dict("z"_a=3, **d1);
return d2;
});
m.def("dict_contains", [](py::dict dict, py::object val) {
m.def("dict_contains", [](py::dict dict, const char* val) {
return dict.contains(val);
});
m.def("dict_contains", [](py::dict dict, const char* val) {
m.def("dict_contains", [](py::dict dict, py::object val) {
return dict.contains(val);
});
m.def("dict_get", [](py::dict dict, const char* key, py::object default_) {
return dict.get(key, default_);
});
m.def("dict_get", [](py::dict dict, py::object key, py::object default_) {
return dict.get(key, default_);
});
m.def("dict_get", [](py::dict dict, const char* key) {
return dict.get(key);
});
m.def("dict_get", [](py::dict dict, py::object key) {
return dict.get(key);
});

// test_str
m.def("str_from_string", []() { return py::str(std::string("baz")); });
Expand Down
10 changes: 10 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ def test_dict(capture, doc):
assert m.dict_contains({42: None}, 42)
assert m.dict_contains({"foo": None}, "foo")

d2 = {True: 42, 3: "abc", "abc": 3}
assert m.dict_get(d2, True) == 42
assert m.dict_get(d2, False) is None
assert m.dict_get(d2, 3, "def") == "abc"
assert m.dict_get(d2, 5, "def") == "def"
assert m.dict_get(d2, "abc") == 3
assert m.dict_get(d2, "def") is None
assert m.dict_get(d2, "abc", 5) == 3
assert m.dict_get(d2, "def", 5) == 5

assert doc(m.get_dict) == "get_dict() -> dict"
assert doc(m.print_dict) == "print_dict(arg0: dict) -> None"

Expand Down