Skip to content
Merged
20 changes: 20 additions & 0 deletions Lib/test/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,26 @@ def __hash__(self):
with self.assertRaises(KeyError):
d.get(key2)

def test_clear_at_lookup(self):
class X:
def __hash__(self):
return 1
def __eq__(self, other):
nonlocal d
d.clear()

d = {}
for _ in range(10):
d[X()] = None

self.assertEqual(len(d), 1)

d = {}
for _ in range(10):
d.setdefault(X(), None)

self.assertEqual(len(d), 1)


class CAPITest(unittest.TestCase):

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed crash in :class:`dict` if :meth:`dict.clear` is called at the lookup
stage. Patch by Mikhail Efimov and Inada Naoki.
82 changes: 32 additions & 50 deletions Objects/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,14 @@ static inline int
insert_combined_dict(PyInterpreterState *interp, PyDictObject *mp,
Py_hash_t hash, PyObject *key, PyObject *value)
{
// gh-140551: If dict was cleared in _Py_dict_lookup,
// we have to resize one more time to force general key kind.
if (DK_IS_UNICODE(mp->ma_keys) && !PyUnicode_CheckExact(key)) {
if (insertion_resize(mp, 0) < 0)
return -1;
assert(mp->ma_keys->dk_kind == DICT_KEYS_GENERAL);
}

if (mp->ma_keys->dk_usable <= 0) {
/* Need to resize. */
if (insertion_resize(mp, 1) < 0) {
Expand Down Expand Up @@ -1871,38 +1879,31 @@ insertdict(PyInterpreterState *interp, PyDictObject *mp,
PyObject *key, Py_hash_t hash, PyObject *value)
{
PyObject *old_value;
Py_ssize_t ix;

ASSERT_DICT_LOCKED(mp);

if (DK_IS_UNICODE(mp->ma_keys) && !PyUnicode_CheckExact(key)) {
if (insertion_resize(mp, 0) < 0)
goto Fail;
assert(mp->ma_keys->dk_kind == DICT_KEYS_GENERAL);
}

if (_PyDict_HasSplitTable(mp)) {
Py_ssize_t ix = insert_split_key(mp->ma_keys, key, hash);
if (_PyDict_HasSplitTable(mp) && PyUnicode_CheckExact(key)) {
ix = insert_split_key(mp->ma_keys, key, hash);
if (ix != DKIX_EMPTY) {
insert_split_value(interp, mp, key, value, ix);
Py_DECREF(key);
Py_DECREF(value);
return 0;
}

/* No space in shared keys. Resize and continue below. */
if (insertion_resize(mp, 1) < 0) {
// No space in shared keys. Go to insert_combined_dict() below.
}
else {
ix = _Py_dict_lookup(mp, key, hash, &old_value);
if (ix == DKIX_ERROR)
goto Fail;
}
}

Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &old_value);
if (ix == DKIX_ERROR)
goto Fail;

if (ix == DKIX_EMPTY) {
assert(!_PyDict_HasSplitTable(mp));
/* Insert into new slot. */
assert(old_value == NULL);
// insert_combined_dict() will convert from non DICT_KEYS_GENERAL table
// into DICT_KEYS_GENERAL table if key is not Unicode.
// We don't convert it before _Py_dict_lookup because non-Unicode key
// may change generic table into Unicode table.
if (insert_combined_dict(interp, mp, hash, key, value) < 0) {
goto Fail;
}
Expand Down Expand Up @@ -4374,6 +4375,7 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu
PyDictObject *mp = (PyDictObject *)d;
PyObject *value;
Py_hash_t hash;
Py_ssize_t ix;
PyInterpreterState *interp = _PyInterpreterState_GET();

ASSERT_DICT_LOCKED(d);
Expand Down Expand Up @@ -4409,17 +4411,8 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu
return 0;
}

if (!PyUnicode_CheckExact(key) && DK_IS_UNICODE(mp->ma_keys)) {
if (insertion_resize(mp, 0) < 0) {
if (result) {
*result = NULL;
}
return -1;
}
}

if (_PyDict_HasSplitTable(mp)) {
Py_ssize_t ix = insert_split_key(mp->ma_keys, key, hash);
if (_PyDict_HasSplitTable(mp) && PyUnicode_CheckExact(key)) {
ix = insert_split_key(mp->ma_keys, key, hash);
if (ix != DKIX_EMPTY) {
PyObject *value = mp->ma_values->values[ix];
int already_present = value != NULL;
Expand All @@ -4432,27 +4425,22 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu
}
return already_present;
}

/* No space in shared keys. Resize and continue below. */
if (insertion_resize(mp, 1) < 0) {
goto error;
}
// No space in shared keys. Go to insert_combined_dict() below.
}

assert(!_PyDict_HasSplitTable(mp));

Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &value);
if (ix == DKIX_ERROR) {
if (result) {
*result = NULL;
else {
ix = _Py_dict_lookup(mp, key, hash, &value);
if (ix == DKIX_ERROR) {
if (result) {
*result = NULL;
}
return -1;
}
return -1;
}

if (ix == DKIX_EMPTY) {
assert(!_PyDict_HasSplitTable(mp));
value = default_value;

// See comment to this function in insertdict.
if (insert_combined_dict(interp, mp, hash, Py_NewRef(key), Py_NewRef(value)) < 0) {
Py_DECREF(key);
Py_DECREF(value);
Expand All @@ -4477,12 +4465,6 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu
*result = incref_result ? Py_NewRef(value) : value;
}
return 1;

error:
if (result) {
*result = NULL;
}
return -1;
}

int
Expand Down
Loading