Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions crates/ty_python_semantic/src/types/cyclic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ use std::cmp::Eq;
use std::hash::Hash;
use std::marker::PhantomData;

use rustc_hash::FxHashMap;
use rustc_hash::{FxHashMap, FxHashSet};

use crate::FxIndexSet;
use crate::types::Type;

/// Maximum recursion depth for cycle detection.
Expand Down Expand Up @@ -64,7 +63,7 @@ pub struct CycleDetector<Tag, T, R> {
/// If the type we're visiting is present in `seen`, it indicates that we've hit a cycle (due
/// to a recursive type); we need to immediately short circuit the whole operation and return
/// the fallback value. That's why we pop items off the end of `seen` after we've visited them.
seen: RefCell<FxIndexSet<T>>,
seen: RefCell<FxHashSet<T>>,

/// Unlike `seen`, this field is a pure performance optimisation (and an essential one). If the
/// type we're trying to normalize is present in `cache`, it doesn't necessarily mean we've hit
Expand All @@ -86,7 +85,7 @@ pub struct CycleDetector<Tag, T, R> {
impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
pub fn new(fallback: R) -> Self {
CycleDetector {
seen: RefCell::new(FxIndexSet::default()),
seen: RefCell::new(FxHashSet::default()),
cache: RefCell::new(FxHashMap::default()),
depth: Cell::new(0),
fallback,
Expand All @@ -99,24 +98,23 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
return val.clone();
}

// We hit a cycle
if !self.seen.borrow_mut().insert(item.clone()) {
return self.fallback.clone();
}

// Check depth limit to prevent stack overflow from recursive generic types
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
self.seen.borrow_mut().pop();
return self.fallback.clone();
}
self.depth.set(current_depth + 1);

// We hit a cycle
if !self.seen.borrow_mut().insert(item.clone()) {
return self.fallback.clone();
}

self.depth.set(current_depth + 1);
let ret = func();

self.depth.set(current_depth);
self.seen.borrow_mut().pop();
self.seen.borrow_mut().remove(&item);
self.cache.borrow_mut().insert(item, ret.clone());

ret
Expand All @@ -127,24 +125,24 @@ impl<Tag, T: Hash + Eq + Clone, R: Clone> CycleDetector<Tag, T, R> {
return Some(val.clone());
}

// We hit a cycle
if !self.seen.borrow_mut().insert(item.clone()) {
return Some(self.fallback.clone());
}

// Check depth limit to prevent stack overflow from recursive generic protocols
// with growing specializations (e.g., C[set[T]] -> C[set[set[T]]] -> ...)
let current_depth = self.depth.get();
if current_depth >= MAX_RECURSION_DEPTH {
self.seen.borrow_mut().pop();
return Some(self.fallback.clone());
}

// We hit a cycle
if !self.seen.borrow_mut().insert(item.clone()) {
return Some(self.fallback.clone());
}

self.depth.set(current_depth + 1);

let ret = func()?;

self.depth.set(current_depth);
self.seen.borrow_mut().pop();
self.seen.borrow_mut().remove(&item);
self.cache.borrow_mut().insert(item, ret.clone());

Some(ret)
Expand Down