diff --git a/nullability/pointer_nullability_diagnosis.cc b/nullability/pointer_nullability_diagnosis.cc index da58f1b11..c17dcdfc2 100644 --- a/nullability/pointer_nullability_diagnosis.cc +++ b/nullability/pointer_nullability_diagnosis.cc @@ -177,6 +177,30 @@ static const Expr* absl_nullable matchesNonConstCallNullCheck( DynTypedNode::create(*ParentFunction), Ctx)); } +// Determines if the pointer expression appears inside a lambda via a lambda +// capture. Handles cases like: +// - `*p` and `p->x`, where `p` is a captured pointer variable +// - `*o.p`, where `o` is a captured object with a pointer member `p` +// Currently does not handle the cases: +// - `*this->p` and `*p`, where `p` is a pointer member of the enclosing class +// and `this` is captured in the lambda. +static bool isCapturedVariableOrMemberAccess(const Expr* absl_nonnull E) { + E = E->IgnoreParenImpCasts(); + + // Cases like `*p` and `p->x`, where `p` is captured in the lambda. + if (const DeclRefExpr* Variable = dyn_cast(E)) + return Variable->refersToEnclosingVariableOrCapture(); + + // Cases like `*o.p`, where `o` is an object captured in the lambda. + if (const MemberExpr* MemberAccess = dyn_cast(E)) { + const Expr* Base = MemberAccess->getBase(); + if (const DeclRefExpr* BaseVariable = dyn_cast(Base)) + return BaseVariable->refersToEnclosingVariableOrCapture(); + } + + return false; +} + // Diagnoses whether `E` violates the expectation that it is nonnull. static SmallVector diagnoseNonnullExpected( const Expr* absl_nonnull E, const Environment& Env, ASTContext& Ctx, @@ -207,8 +231,7 @@ static SmallVector diagnoseNonnullExpected( Range = CharSourceRange::getTokenRange(E->getSourceRange()); if (const Expr* NullCheck = - matchesNonConstCallNullCheck(*E, Ctx, Env.getCurrentFunc()); - NullCheck != nullptr) { + matchesNonConstCallNullCheck(*E, Ctx, Env.getCurrentFunc())) return {{ .Code = PointerNullabilityDiagnostic::ErrorCode:: ExpectedNonnullWithCheckOnNonConstCall, @@ -225,12 +248,27 @@ static SmallVector diagnoseNonnullExpected( "that. Or, mark the method as const (if possible, and if it has " "zero params).", }}; - } + + Range = getRangeModuloMacros(Range, Ctx); + + if (isCapturedVariableOrMemberAccess(E)) + return { + {.Code = PointerNullabilityDiagnostic::ErrorCode::ExpectedNonnull, + .Ctx = DiagCtx, + .Range = Range, + .Callee = Callee, + .ParamName = ParamName, + .NoteRange = Range, + .NoteMessage = + "This pointer is captured and dereferenced in a lambda. If it is " + "null-checked outside the lambda, consider capturing the pointee " + "by value or reference (possibly with an init-capture). Otherwise " + "do a null check inside the lambda body to ensure null safety."}}; return {{ .Code = PointerNullabilityDiagnostic::ErrorCode::ExpectedNonnull, .Ctx = DiagCtx, - .Range = getRangeModuloMacros(Range, Ctx), + .Range = Range, .Callee = Callee, .ParamName = ParamName, }};