Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
completeRelationshipType,
allLabelCompletions,
allReltypeCompletions,
completeNodeLabel,
} from './schemaBasedCompletions';
import { backtickIfNeeded, uniq } from './autocompletionHelpers';

Expand Down Expand Up @@ -702,7 +703,7 @@ export function completionCoreCompletion(
}

if (topExprParent === CypherParser.RULE_nodePattern) {
return allLabelCompletions(dbSchema);
return completeNodeLabel(dbSchema, parsingResult, symbolsInfo);
}

if (topExprParent === CypherParser.RULE_relationshipPattern) {
Expand Down
125 changes: 102 additions & 23 deletions packages/language-support/src/autocompletion/schemaBasedCompletions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import {
NodePatternContext,
PatternElementContext,
QuantifierContext,
RelationshipPatternContext,
} from '../generated-parser/CypherCmdParser';
import { ParserRuleContext } from 'antlr4';
import { backtickIfNeeded } from './autocompletionHelpers';
import { _internalFeatureFlags } from '../featureFlags';

Expand Down Expand Up @@ -52,45 +52,45 @@ export const allReltypeCompletions = (dbSchema: DbSchema) =>
reltypesToCompletions(dbSchema.relationshipTypes);

function intersectChildren(
relsFromLabels: Map<string, Set<string>>,
connectedLabels: Map<string, Set<string>>,
children: LabelOrCondition[],
): Set<string> {
let intersection: Set<string> = undefined;
children.forEach((c) => {
intersection = intersection
? (intersection = intersection.intersection(
walkLabelTree(relsFromLabels, c),
walkLabelTree(connectedLabels, c),
))
: walkLabelTree(relsFromLabels, c);
: walkLabelTree(connectedLabels, c);
});
return intersection ?? new Set();
}

function uniteChildren(
relsFromLabels: Map<string, Set<string>>,
connectedLabels: Map<string, Set<string>>,
children: LabelOrCondition[],
): Set<string> {
let union: Set<string> = new Set();
children.forEach(
(c) => (union = union.union(walkLabelTree(relsFromLabels, c))),
(c) => (union = union.union(walkLabelTree(connectedLabels, c))),
);
return union;
}

function walkLabelTree(
relsFromLabels: Map<string, Set<string>>,
connectedLabels: Map<string, Set<string>>,
labelTree: LabelOrCondition,
): Set<string> {
if (isLabelLeaf(labelTree)) {
return relsFromLabels.get(labelTree.value);
return connectedLabels.get(labelTree.value);
} else if (labelTree.andOr == 'and') {
return intersectChildren(relsFromLabels, labelTree.children);
return intersectChildren(connectedLabels, labelTree.children);
} else {
return uniteChildren(relsFromLabels, labelTree.children);
return uniteChildren(connectedLabels, labelTree.children);
}
}

function getRelsFromLabelsSet(dbSchema: DbSchema): Map<string, Set<string>> {
function getRelsFromNodesSet(dbSchema: DbSchema): Map<string, Set<string>> {
if (dbSchema.graphSchema) {
const relsFromLabelsSet: Map<string, Set<string>> = new Map();
dbSchema.graphSchema.forEach((rel) => {
Expand All @@ -112,35 +112,114 @@ function getRelsFromLabelsSet(dbSchema: DbSchema): Map<string, Set<string>> {
return undefined;
}

export function completeRelationshipType(
function getNodesFromRelsSet(dbSchema: DbSchema): Map<string, Set<string>> {
if (dbSchema.graphSchema) {
const nodesFromRelsSet: Map<string, Set<string>> = new Map();
dbSchema.graphSchema.forEach((rel) => {
if (!nodesFromRelsSet.has(rel.relType)) {
nodesFromRelsSet.set(rel.relType, new Set());
}
const currentRelEntry = nodesFromRelsSet.get(rel.relType);
currentRelEntry.add(rel.to);
currentRelEntry.add(rel.from);
});
return nodesFromRelsSet;
}
return undefined;
}

export function completeNodeLabel(
dbSchema: DbSchema,
parsingResult: ParsedStatement,
symbolsInfo: SymbolsInfo,
): CompletionItem[] {
if (!_internalFeatureFlags.schemaBasedPatternCompletions) {
return allReltypeCompletions(dbSchema);
if (
!_internalFeatureFlags.schemaBasedPatternCompletions ||
dbSchema.graphSchema === undefined
) {
return allLabelCompletions(dbSchema);
}

if (dbSchema.graphSchema === undefined) {
return allReltypeCompletions(dbSchema);
}

// limitation: not checking PathPatternNonEmptyContext
// limitation: not handling parenthesized paths
const callContext = findParent(
parsingResult.stopNode.parentCtx,
(x) => x instanceof PatternElementContext,
);

if (callContext instanceof PatternElementContext) {
const lastValidElement = callContext.children.toReversed().find((child) => {
if (child instanceof ParserRuleContext) {
if (child instanceof RelationshipPatternContext) {
if (child.exception === null) {
return true;
}
}
});

// limitation: bailing out on quantifiers
if (lastValidElement instanceof QuantifierContext) {
return allLabelCompletions(dbSchema);
}

if (lastValidElement instanceof RelationshipPatternContext) {
// limitation: not checking anonymous variables
const variable = lastValidElement.variable();
if (variable === null) {
return allLabelCompletions(dbSchema);
}

const foundVariable = symbolsInfo?.symbolTables
?.flat()
.find((entry) => entry.references.includes(variable.start.start));

if (
foundVariable === undefined ||
('children' in foundVariable.labels &&
foundVariable.labels.children.length == 0)
) {
return allLabelCompletions(dbSchema);
}

// limitation: not direction-aware (ignores <- vs ->)
// limitation: not checking node label repetition
const nodesFromRelsSet = getNodesFromRelsSet(dbSchema);
const rels = walkLabelTree(nodesFromRelsSet, foundVariable.labels);

return labelsToCompletions(Array.from(rels));
}
}

return allLabelCompletions(dbSchema);
}

export function completeRelationshipType(
dbSchema: DbSchema,
parsingResult: ParsedStatement,
symbolsInfo: SymbolsInfo,
): CompletionItem[] {
if (
!_internalFeatureFlags.schemaBasedPatternCompletions ||
dbSchema.graphSchema === undefined
) {
return allReltypeCompletions(dbSchema);
}

// limitation: not checking PathPatternNonEmptyContext
// limitation: not handling parenthesized paths
const patternContext = findParent(
parsingResult.stopNode.parentCtx,
(x) => x instanceof PatternElementContext,
);

if (patternContext instanceof PatternElementContext) {
const lastValidElement = patternContext.children
.toReversed()
.find((child) => {
if (child instanceof NodePatternContext) {
if (child.exception === null) {
return true;
}
}
});

// limitation: bailing out on quantifiers
if (lastValidElement instanceof QuantifierContext) {
return allReltypeCompletions(dbSchema);
Expand All @@ -166,8 +245,8 @@ export function completeRelationshipType(
}

// limitation: not direction-aware (ignores <- vs ->)
// limitation: not checking relationship variable reuse
const relsFromLabelsSet = getRelsFromLabelsSet(dbSchema);
// limitation: not checking relationship type repetition
const relsFromLabelsSet = getRelsFromNodesSet(dbSchema);
const rels = walkLabelTree(relsFromLabelsSet, foundVariable.labels);

return reltypesToCompletions(Array.from(rels));
Expand Down
Loading