Skip to content

Commit 08b0bda

Browse files
superbobrywjakob
authored andcommitted
Added set::contains and generalized dict::contains (pybind#1884)
Dynamically resolving __contains__ on each call is wasteful since set has a public PySet_Contains function.
1 parent 5b0ea77 commit 08b0bda

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

include/pybind11/pytypes.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,8 +1224,9 @@ class dict : public object {
12241224
detail::dict_iterator begin() const { return {*this, 0}; }
12251225
detail::dict_iterator end() const { return {}; }
12261226
void clear() const { PyDict_Clear(ptr()); }
1227-
bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; }
1228-
bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; }
1227+
template <typename T> bool contains(T &&key) const {
1228+
return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr()) == 1;
1229+
}
12291230

12301231
private:
12311232
/// Call the `dict` Python type -- always returns a new reference
@@ -1276,6 +1277,9 @@ class set : public object {
12761277
return PySet_Add(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 0;
12771278
}
12781279
void clear() const { PySet_Clear(m_ptr); }
1280+
template <typename T> bool contains(T &&val) const {
1281+
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
1282+
}
12791283
};
12801284

12811285
class function : public object {

tests/test_pytypes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ TEST_SUBMODULE(pytypes, m) {
3737
for (auto item : set)
3838
py::print("key:", item);
3939
});
40+
m.def("set_contains", [](py::set set, py::object key) {
41+
return set.contains(key);
42+
});
43+
m.def("set_contains", [](py::set set, const char* key) {
44+
return set.contains(key);
45+
});
4046

4147
// test_dict
4248
m.def("get_dict", []() { return py::dict("key"_a="value"); });
@@ -49,6 +55,12 @@ TEST_SUBMODULE(pytypes, m) {
4955
auto d2 = py::dict("z"_a=3, **d1);
5056
return d2;
5157
});
58+
m.def("dict_contains", [](py::dict dict, py::object val) {
59+
return dict.contains(val);
60+
});
61+
m.def("dict_contains", [](py::dict dict, const char* val) {
62+
return dict.contains(val);
63+
});
5264

5365
// test_str
5466
m.def("str_from_string", []() { return py::str(std::string("baz")); });

tests/test_pytypes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def test_set(capture, doc):
3737
key: key4
3838
"""
3939

40+
assert not m.set_contains(set([]), 42)
41+
assert m.set_contains({42}, 42)
42+
assert m.set_contains({"foo"}, "foo")
43+
4044
assert doc(m.get_list) == "get_list() -> list"
4145
assert doc(m.print_list) == "print_list(arg0: list) -> None"
4246

@@ -53,6 +57,10 @@ def test_dict(capture, doc):
5357
key: key2, value=value2
5458
"""
5559

60+
assert not m.dict_contains({}, 42)
61+
assert m.dict_contains({42: None}, 42)
62+
assert m.dict_contains({"foo": None}, "foo")
63+
5664
assert doc(m.get_dict) == "get_dict() -> dict"
5765
assert doc(m.print_dict) == "print_dict(arg0: dict) -> None"
5866

0 commit comments

Comments
 (0)