diff --git a/packages/zenscript/src/builtins/bool.dzs b/packages/zenscript/src/builtins/bool.dzs index 1df8d3f5..155ece65 100644 --- a/packages/zenscript/src/builtins/bool.dzs +++ b/packages/zenscript/src/builtins/bool.dzs @@ -1 +1,3 @@ -zenClass bool +zenClass bool { + operator as() as string +} diff --git a/packages/zenscript/src/builtins/byte.dzs b/packages/zenscript/src/builtins/byte.dzs index b88cbffc..97f515ac 100644 --- a/packages/zenscript/src/builtins/byte.dzs +++ b/packages/zenscript/src/builtins/byte.dzs @@ -1 +1,3 @@ -zenClass byte +zenClass byte { + operator as() as short, int, long, float, double, string +} diff --git a/packages/zenscript/src/builtins/double.dzs b/packages/zenscript/src/builtins/double.dzs index 39ec4f23..8396e6d6 100644 --- a/packages/zenscript/src/builtins/double.dzs +++ b/packages/zenscript/src/builtins/double.dzs @@ -1 +1,3 @@ -zenClass double +zenClass double { + operator as() as byte, short, int, long, float, string +} diff --git a/packages/zenscript/src/builtins/float.dzs b/packages/zenscript/src/builtins/float.dzs index 873011b0..9b1e0eaf 100644 --- a/packages/zenscript/src/builtins/float.dzs +++ b/packages/zenscript/src/builtins/float.dzs @@ -1 +1,3 @@ -zenClass float +zenClass float { + operator as() as byte, short, int, long, double, string +} diff --git a/packages/zenscript/src/builtins/int.dzs b/packages/zenscript/src/builtins/int.dzs index 700903b9..9099b306 100644 --- a/packages/zenscript/src/builtins/int.dzs +++ b/packages/zenscript/src/builtins/int.dzs @@ -1 +1,3 @@ -zenClass int +zenClass int { + operator as() as byte, short, long, float, double, string +} diff --git a/packages/zenscript/src/builtins/long.dzs b/packages/zenscript/src/builtins/long.dzs index 00c175ff..599202e9 100644 --- a/packages/zenscript/src/builtins/long.dzs +++ b/packages/zenscript/src/builtins/long.dzs @@ -1 +1,3 @@ -zenClass long +zenClass long { + operator as() as byte, short, int, float, double, string +} diff --git a/packages/zenscript/src/builtins/short.dzs b/packages/zenscript/src/builtins/short.dzs index de4b3b14..3d41b892 100644 --- a/packages/zenscript/src/builtins/short.dzs +++ b/packages/zenscript/src/builtins/short.dzs @@ -1 +1,3 @@ -zenClass short +zenClass short { + operator as() as byte, int, long, float, double, string +} diff --git a/packages/zenscript/src/builtins/string.dzs b/packages/zenscript/src/builtins/string.dzs index 99675566..42c6ade3 100644 --- a/packages/zenscript/src/builtins/string.dzs +++ b/packages/zenscript/src/builtins/string.dzs @@ -84,4 +84,6 @@ zenClass string { function lastIndexOf(needle as string, fromIndex as int) as int; function split(separator as string, maximum as int = 0) as [string]; + + operator as() as bool, byte, int, short, long, float, double } diff --git a/packages/zenscript/src/lsp/semantic-token-provider.ts b/packages/zenscript/src/lsp/semantic-token-provider.ts index 39858ec8..2de4e9b5 100644 --- a/packages/zenscript/src/lsp/semantic-token-provider.ts +++ b/packages/zenscript/src/lsp/semantic-token-provider.ts @@ -1,10 +1,14 @@ +import type { AstNode, LangiumDocument } from 'langium' import type { SemanticTokenAcceptor } from 'langium/lsp' +import type { CancellationToken } from 'vscode-languageserver' import type { ZenScriptAstType } from '../generated/ast' import type { ZenScriptServices } from '../module' + import type { TypeComputer } from '../typing/type-computer' -import { type AstNode, CstUtils, stream } from 'langium' + +import { AstUtils, CstUtils, interruptAndCheck, stream } from 'langium' import { AbstractSemanticTokenProvider } from 'langium/lsp' -import { SemanticTokenModifiers, SemanticTokenTypes } from 'vscode-languageserver' +import { LSPErrorCodes, ResponseError, SemanticTokenModifiers, SemanticTokenTypes } from 'vscode-languageserver' import { isLocation } from '../generated/ast' import { isStringType } from '../typing/type-description' @@ -21,7 +25,27 @@ export class ZenScriptSemanticTokenProvider extends AbstractSemanticTokenProvide this.typeComputer = services.typing.TypeComputer } - override highlightElement(node: AstNode, acceptor: SemanticTokenAcceptor): void { + protected async computeHighlighting(document: LangiumDocument, acceptor: SemanticTokenAcceptor, cancelToken: CancellationToken): Promise { + const root = document.parseResult.value + const treeIterator = AstUtils.streamAst(root, { range: this.currentRange }).iterator() + let result: IteratorResult + do { + result = treeIterator.next() + if (!result.done) { + const prevState = document.state + await interruptAndCheck(cancelToken) + if (prevState > document.state) { + throw new ResponseError(LSPErrorCodes.ContentModified, 'Document was modified during semantic token computation') + } + const node = result.value + if (this.highlightElement(node, acceptor) === 'prune') { + treeIterator.prune() + } + } + } while (!result.done) + } + + override highlightElement(node: AstNode, acceptor: SemanticTokenAcceptor): void | 'prune' { // @ts-expect-error allowed index type this.rules[node.$type]?.call(this, node, acceptor) } diff --git a/packages/zenscript/src/module.ts b/packages/zenscript/src/module.ts index bbb5cb14..f62de12d 100644 --- a/packages/zenscript/src/module.ts +++ b/packages/zenscript/src/module.ts @@ -13,6 +13,7 @@ import { ZenScriptMemberProvider } from './reference/member-provider' import { ZenScriptNameProvider } from './reference/name-provider' import { ZenScriptScopeComputation } from './reference/scope-computation' import { ZenScriptScopeProvider } from './reference/scope-provider' +import { ZenScriptOverloadResolver } from './typing/overload-resolver' import { ZenScriptTypeComputer } from './typing/type-computer' import { registerValidationChecks, ZenScriptValidator } from './validation/validator' import { ZenScriptBracketManager } from './workspace/bracket-manager' @@ -34,6 +35,7 @@ export interface ZenScriptAddedServices { } typing: { TypeComputer: ZenScriptTypeComputer + OverloadResolver: ZenScriptOverloadResolver } workspace: { PackageManager: ZenScriptPackageManager @@ -84,6 +86,7 @@ export const ZenScriptModule: Module new ZenScriptTypeComputer(services), + OverloadResolver: services => new ZenScriptOverloadResolver(services), }, lsp: { CompletionProvider: services => new ZenScriptCompletionProvider(services), diff --git a/packages/zenscript/src/reference/dynamic-provider.ts b/packages/zenscript/src/reference/dynamic-provider.ts index 613c93f4..ca03e8c5 100644 --- a/packages/zenscript/src/reference/dynamic-provider.ts +++ b/packages/zenscript/src/reference/dynamic-provider.ts @@ -1,6 +1,7 @@ import type { AstNode, AstNodeDescription } from 'langium' import type { ZenScriptAstType } from '../generated/ast' import type { ZenScriptServices } from '../module' +import type { ZenScriptOverloadResolver } from '../typing/overload-resolver' import type { TypeComputer } from '../typing/type-computer' import type { DescriptionIndex } from '../workspace/description-index' import type { MemberProvider } from './member-provider' @@ -19,11 +20,13 @@ export class ZenScriptDynamicProvider implements DynamicProvider { private readonly descriptionIndex: DescriptionIndex private readonly typeComputer: TypeComputer private readonly memberProvider: MemberProvider + private readonly overloadResolver: ZenScriptOverloadResolver constructor(services: ZenScriptServices) { this.descriptionIndex = services.workspace.DescriptionIndex this.typeComputer = services.typing.TypeComputer this.memberProvider = services.references.MemberProvider + this.overloadResolver = services.typing.OverloadResolver } getDynamics(source: AstNode): AstNodeDescription[] { @@ -44,9 +47,12 @@ export class ZenScriptDynamicProvider implements DynamicProvider { // dynamic arguments if (isCallExpression(source.$container) && source.$containerProperty === 'arguments') { const index = source.$containerIndex! - const receiverType = this.typeComputer.inferType(source.$container.receiver) - if (isFunctionType(receiverType)) { - const paramType = receiverType.paramTypes[index] + + // prevent circular ref resolve, manually infer type + + const callType = this.overloadResolver.predictCallType(source.$container) + if (isFunctionType(callType)) { + const paramType = callType.paramTypes[index] if (isClassType(paramType)) { stream(this.memberProvider.getMembers(paramType.declaration)) .map(it => it.node) diff --git a/packages/zenscript/src/reference/name-provider.ts b/packages/zenscript/src/reference/name-provider.ts index 92a480c6..6656491f 100644 --- a/packages/zenscript/src/reference/name-provider.ts +++ b/packages/zenscript/src/reference/name-provider.ts @@ -47,13 +47,13 @@ export class ZenScriptNameProvider extends DefaultNameProvider { Script: source => source.$document ? getName(source.$document) : undefined, ImportDeclaration: source => source.alias || source.path.at(-1)?.$refText, FunctionDeclaration: source => source.name || 'lambda function', - ConstructorDeclaration: _ => 'zenConstructor', + ConstructorDeclaration: source => isClassDeclaration(source.$container) ? source.$container.name : 'constructor', OperatorFunctionDeclaration: source => source.op, } private readonly nameNodeRules: NameNodeRuleMap = { ImportDeclaration: source => GrammarUtils.findNodeForProperty(source.$cstNode, 'alias'), - ConstructorDeclaration: source => GrammarUtils.findNodeForProperty(source.$cstNode, 'zenConstructor'), + ConstructorDeclaration: source => GrammarUtils.findNodeForKeyword(source.$cstNode, 'zenConstructor'), OperatorFunctionDeclaration: source => GrammarUtils.findNodeForProperty(source.$cstNode, 'op'), } } diff --git a/packages/zenscript/src/reference/scope-provider.ts b/packages/zenscript/src/reference/scope-provider.ts index fc7d445a..42e5020d 100644 --- a/packages/zenscript/src/reference/scope-provider.ts +++ b/packages/zenscript/src/reference/scope-provider.ts @@ -1,13 +1,14 @@ import type { AstNode, AstNodeDescription, ReferenceInfo, Scope } from 'langium' -import type { ZenScriptAstType } from '../generated/ast' +import type { CallExpression, ZenScriptAstType } from '../generated/ast' import type { ZenScriptServices } from '../module' +import type { ZenScriptOverloadResolver } from '../typing/overload-resolver' import type { ZenScriptDescriptionIndex } from '../workspace/description-index' import type { PackageManager } from '../workspace/package-manager' import type { DynamicProvider } from './dynamic-provider' import type { MemberProvider } from './member-provider' import { substringBeforeLast } from '@intellizen/shared' import { AstUtils, DefaultScopeProvider, EMPTY_SCOPE, stream } from 'langium' -import { ClassDeclaration, ImportDeclaration, isClassDeclaration, TypeParameter } from '../generated/ast' +import { ClassDeclaration, ImportDeclaration, isCallExpression, isClassDeclaration, isScript, TypeParameter } from '../generated/ast' import { getPathAsString } from '../utils/ast' import { generateStream } from '../utils/stream' @@ -19,6 +20,7 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { private readonly memberProvider: MemberProvider private readonly dynamicProvider: DynamicProvider private readonly descriptionIndex: ZenScriptDescriptionIndex + private readonly overloadResolver: ZenScriptOverloadResolver constructor(services: ZenScriptServices) { super(services) @@ -26,6 +28,7 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { this.memberProvider = services.references.MemberProvider this.dynamicProvider = services.references.DynamicProvider this.descriptionIndex = services.workspace.DescriptionIndex + this.overloadResolver = services.typing.OverloadResolver } override getScope(context: ReferenceInfo): Scope { @@ -46,6 +49,26 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { .reduce((outer, descriptions) => this.createScope(descriptions, outer), outside as Scope) } + private importScope(source: ReferenceInfo, outer?: Scope) { + const script = AstUtils.findRootNode(source.container) + if (!isScript(script)) { + return EMPTY_SCOPE + } + + const imports = stream(script.imports) + .flatMap(it => this.descriptionIndex.createImportedDescription(it)) + + if (source.reference.$refText === '' || !isCallExpression(source.container.$container) || source.container.$containerProperty !== 'receiver') { + return this.createScope(imports, outer) + } + + const overload = this.overloadResolver.findOverloadMethod(imports, source.container.$container, source.reference.$refText) + if (!overload) { + return outer || EMPTY_SCOPE + } + return this.createScope([overload], outer) + } + private dynamicScope(astNode: AstNode, outside?: Scope) { return this.createScope(this.dynamicProvider.getDynamics(astNode), outside) } @@ -97,12 +120,22 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { outer = this.globalScope(outer) outer = this.dynamicScope(source.container, outer) + outer = this.importScope(source, outer) + + const processOverload = source.reference.$refText !== '' && isCallExpression(source.container.$container) && source.container.$containerProperty === 'receiver' + const processor = (desc: AstNodeDescription) => { switch (desc.type) { case TypeParameter: return case ImportDeclaration: { - return this.descriptionIndex.createImportedDescription(desc.node as ImportDeclaration) + return + } + case ClassDeclaration: { + if (processOverload) { + return this.overloadResolver.findOverlaodConstructor(desc.node as ClassDeclaration, source.container.$container as CallExpression) + } + return desc } default: return desc @@ -114,7 +147,16 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { MemberAccess: (source) => { const outer = this.dynamicScope(source.container) const members = this.memberProvider.getMembers(source.container.receiver) - return this.createScope(members, outer) + + if (source.reference.$refText === '' || !isCallExpression(source.container.$container) || source.container.$containerProperty !== 'receiver') { + return this.createScope(members, outer) + } + + const overload = this.overloadResolver.findOverloadMethod(members, source.container.$container, source.reference.$refText) + if (!overload) { + return outer + } + return this.createScope([overload], outer) }, NamedTypeReference: (source) => { @@ -126,7 +168,7 @@ export class ZenScriptScopeProvider extends DefaultScopeProvider { case ClassDeclaration: return desc case ImportDeclaration: { - return this.descriptionIndex.createImportedDescription(desc.node as ImportDeclaration) + return this.descriptionIndex.createImportedDescription(desc.node as ImportDeclaration).at(0) || desc } } } diff --git a/packages/zenscript/src/typing/overload-resolver.ts b/packages/zenscript/src/typing/overload-resolver.ts new file mode 100644 index 00000000..3737b013 --- /dev/null +++ b/packages/zenscript/src/typing/overload-resolver.ts @@ -0,0 +1,440 @@ +import type { AstNodeDescription } from 'langium' +import type { CallExpression, ClassDeclaration, ConstructorDeclaration, ExpandFunctionDeclaration, Expression, FunctionDeclaration } from '../generated/ast' +import type { ZenScriptServices } from '../module' +import type { ZenScriptMemberProvider } from '../reference/member-provider' +import type { ClassType, IntersectionType, Type, UnionType } from '../typing/type-description' +import type { ZenScriptDescriptionIndex } from '../workspace/description-index' +import type { ZenScriptTypeComputer } from './type-computer' +import { AstUtils, stream } from 'langium' +import { isClassDeclaration, isConstructorDeclaration, isExpandFunctionDeclaration, isFunctionDeclaration, isFunctionExpression, isMemberAccess, isOperatorFunctionDeclaration, isReferenceExpression, isScript } from '../generated/ast' +import { FunctionType, isClassType, isCompoundType, isTypeVariable } from '../typing/type-description' +import { getClassChain } from '../utils/ast' + +export enum OverloadMatch { + FullMatch = 0, + OptionalMatch = 1, + ImplicitMatch = 2, + PossibleMatch = 3, + NotMatch = 4, +} + +export type CallableDeclaration = ConstructorDeclaration | FunctionDeclaration | ExpandFunctionDeclaration + +export function isOptionalArgs(method: CallableDeclaration, before: number): boolean { + let index = before + if (index < 0) { + index = method.parameters.length + index + } + if (isExpandFunctionDeclaration(method)) { + index++ + } + + return index < method.parameters.length && method.parameters[index].defaultValue !== undefined +} + +export function isVarargs(method: CallableDeclaration): boolean { + return method.parameters.length > 0 && method.parameters[method.parameters.length - 1].varargs +} + +export class ZenScriptOverloadResolver { + private readonly typeComputer: ZenScriptTypeComputer + private readonly memberProvider: ZenScriptMemberProvider + private readonly descriptionIndex: ZenScriptDescriptionIndex + + constructor(services: ZenScriptServices) { + this.typeComputer = services.typing.TypeComputer + this.memberProvider = services.references.MemberProvider + this.descriptionIndex = services.workspace.DescriptionIndex + } + + findOverlaodConstructor(classDecl: ClassDeclaration, callExpr: CallExpression): AstNodeDescription | undefined { + const constructors = classDecl.members + .filter(it => isConstructorDeclaration(it)) + + if (constructors.length === 0) { + return + } + + const index = this.resolveOverload(constructors, it => it as ConstructorDeclaration, callExpr.arguments) + + if (index.length === 0) { + return this.descriptionIndex.getDescription(constructors[0]) + } + return this.descriptionIndex.getDescription(constructors[index[0]]) + } + + findOverloadMethod(members: Iterable, callExpr: CallExpression, name: string): AstNodeDescription | undefined { + const found = stream(members).filter(it => it.name === name) + + const classDesc = found.find(it => isClassDeclaration(it.node)) + if (classDesc) { + const clazz = classDesc.node as ClassDeclaration + return this.findOverlaodConstructor(clazz, callExpr) + } + + const methods = found.filter(it => isFunctionDeclaration(it.node) || isExpandFunctionDeclaration(it.node)).toArray() + + const index = this.resolveOverload(methods, it => it.node as CallableDeclaration, callExpr.arguments) + + if (index.length === 0) { + return methods.find(it => it.name === name) + } + return methods[index[0]] + } + + predictCallType(callExpr: CallExpression): Type | undefined { + if (isReferenceExpression(callExpr.receiver)) { + if (callExpr.receiver.target.$nodeDescription) { + return this.typeComputer.inferType(callExpr.receiver.target.$nodeDescription.node) + } + const script = AstUtils.findRootNode(callExpr.receiver) + + if (!isScript(script)) { + return + } + + const refText = callExpr.receiver.target.$refText + + const imports = stream(script.imports) + .map(it => this.descriptionIndex.createImportedDescription(it)) + .flatMap(it => it) + + const overload = this.findOverloadMethod(imports, callExpr, refText) + + if (overload) { + return this.typeComputer.inferType(overload.node) + } + return + } + + if (isMemberAccess(callExpr.receiver)) { + if (callExpr.receiver.target.$nodeDescription) { + return this.typeComputer.inferType(callExpr.receiver.target.$nodeDescription.node) + } + const receiverType = this.typeComputer.inferType(callExpr.receiver.receiver) + const candidates = stream(this.memberProvider.getMembers(receiverType)) + + const overload = this.predictOverloadMethod(candidates, callExpr, callExpr.receiver.target.$refText) + if (overload) { + return this.typeComputer.inferType(overload.node) + } + } + } + + predictOverloadMethod(members: Iterable, callExpr: CallExpression, name: string): AstNodeDescription | undefined { + const found = stream(members).filter(it => it.name === name) + + const classDesc = found.find(it => isClassDeclaration(it.node)) + if (classDesc) { + const ctorDecl = (classDesc.node as ClassDeclaration).members.find((it) => { + if (isConstructorDeclaration(it)) { + return true + } + return this.matchSignature(it as CallableDeclaration, callExpr.arguments.length) !== OverloadMatch.NotMatch + }) + + if (!ctorDecl) { + return + } + return this.descriptionIndex.getDescription(ctorDecl) + } + + return found.filter(it => isFunctionDeclaration(it.node) || isExpandFunctionDeclaration(it.node)) + .find((it) => { + const method = it.node as CallableDeclaration + return this.matchSignature(method, callExpr.arguments.length) !== OverloadMatch.NotMatch + }) + } + + resolveOverload(methods: ArrayLike, supplier: (arg: T) => CallableDeclaration, args: Array): number[] { + const possible = new Set() + + for (let i = 0; i < methods.length; i++) { + const currentMatch = this.matchSignature(supplier(methods[i]), args.length) + if (currentMatch !== OverloadMatch.NotMatch) { + possible.add(i) + } + } + + if (possible.size === 0) { + return [] + } + + if (possible.size === 1) { + return [...possible.values()] + } + + const argTypes = args.map((it) => { + if (isFunctionExpression(it)) { + return new FunctionType([], this.typeComputer.classTypeOf('void')) + } + return this.typeComputer.inferType(it) || this.typeComputer.classTypeOf('any') + }) + + let bestMatch = OverloadMatch.NotMatch + let matchIndexes: number[] = [] + for (let i = 0; i < methods.length; i++) { + if (!possible.has(i)) { + continue + } + const currentMatch = this.matchSignature(supplier(methods[i]), argTypes) + if (currentMatch === OverloadMatch.FullMatch) { + return [i] + } + if (currentMatch < bestMatch) { + matchIndexes = [i] + bestMatch = currentMatch + } + else if (currentMatch === bestMatch) { + // duplicate match + matchIndexes.push(i) + } + } + + if (matchIndexes.length > 1) { + this.logAmbiguousOverload(supplier, methods, argTypes, bestMatch, matchIndexes) + } + + if (bestMatch === OverloadMatch.NotMatch) { + return [] + } + + return matchIndexes + } + + private logAmbiguousOverload(supplier: (arg: T) => CallableDeclaration, methods: ArrayLike, argTypes: Type[], bestMatch: OverloadMatch, matchIndexes: number[]) { + let methodName = '' + if (isConstructorDeclaration(supplier(methods[0]))) { + methodName = (supplier(methods[0]).$container as ClassDeclaration).name + } + else { + methodName = (supplier(methods[0]) as FunctionDeclaration).name + } + + const MATCH_NAMES = ['FullMatch', 'OptionalMatch', 'ImplicitMatch', 'PossibleMatch', 'NotMatch'] + + const argTypeStrings = argTypes.map(it => it?.toString() ?? 'undefined').join(', ') + console.warn(`ambiguous overload for ${methodName} with arguments (${argTypeStrings}), match: ${MATCH_NAMES[bestMatch]}`) + for (const index of matchIndexes) { + const params = (supplier(methods[index]) as CallableDeclaration).parameters.map(it => this.typeComputer.inferType(it)) + .map(it => it?.toString() ?? 'undefined').join(', ') + console.warn(`----- ${methodName} (${params})`) + } + } + + typeIsSame(a: Type | undefined, b: Type | undefined): boolean { + if (a === undefined || b === undefined) { + return false + } + + return a.equals(b) + } + + elementType(type: ClassType): Type | undefined { + const substitutionType = type.substitutions.values().next().value + + if (!substitutionType || isTypeVariable(substitutionType)) { + return + } + + return substitutionType + } + + typeIsFunctionLike(type: Type): boolean { + switch (type.$type) { + case 'ClassType': + { + const members = this.memberProvider.getMembers((type as ClassType).declaration) + return members.some(it => isFunctionDeclaration(it.node) && it.node.prefix === 'lambda') + } + case 'FunctionType': + return true + case 'IntersectionType': + return (type as IntersectionType).types.some(it => this.typeIsFunctionLike(it)) + default: + return false + } + + return false + } + + classIsExtends(a: ClassType, b: ClassType): boolean { + if (a.equals(b)) { + return true + } + + if (b.declaration.name === 'any') { + return true + } + + if (b.declaration.name === 'null' || a.declaration.name === 'void') { + return false + } + + if ((a.declaration.name === 'Array' && b.declaration.name === 'Array') || (a.declaration.name === 'List' && b.declaration.name === 'List')) { + return this.typeIsInstanceOf(this.elementType(a), this.elementType(b)) + } + + if (stream(a.declaration.superTypes) + .map(it => this.typeComputer.inferType(it)) + .filter(it => isClassType(it)) + .some(it => this.classIsExtends(it, b))) { + return true + } + + const castOps = getClassChain(a.declaration) + .flatMap(it => it.members) + .filter(it => isOperatorFunctionDeclaration(it)) + .filter(it => it.op === 'as') + + for (const castOp of castOps) { + const castType = this.typeComputer.inferType(castOp.returnTypeRef) + if (!castType) { + continue + } + let canCast = false + if (isCompoundType(castType)) { + canCast = castType.types.some(it => it.equals(b)) + } + else { + canCast = castType.equals(b) + } + if (canCast) { + return true + } + } + + return false + } + + classIsInstanceOf(a: ClassType, b: Type): boolean { + if (a.declaration.name === 'void') { + return false + } + if (a.declaration.name === 'any' || a.declaration.name === 'null') { + return true + } + switch (b.$type) { + case 'ClassType': + return this.classIsExtends(a as ClassType, b as ClassType) + case 'IntersectionType': + return (b as IntersectionType).types.some(it => this.classIsInstanceOf(a, it)) + case 'UnionType': + return (b as UnionType).types.every(it => this.classIsInstanceOf(a, it)) + default: + return false + } + } + + typeIsInstanceOf(a: Type | undefined, b: Type | undefined): boolean { + if (a === undefined || b === undefined) { + return false + } + + switch (a.$type) { + case 'ClassType': + return this.classIsInstanceOf(a as ClassType, b) + case 'FunctionType': + return this.typeIsFunctionLike(b) + case 'IntersectionType': + return (a as IntersectionType).types.every(it => this.typeIsInstanceOf(it, b)) + case 'UnionType': + return (a as UnionType).types.some(it => this.typeIsInstanceOf(it, b)) + case 'CompoundType': + return false + case 'TypeVariable': + return false + default: + return false + } + } + + matchSignature(method: CallableDeclaration, args: Type[] | number): OverloadMatch { + const paramters = method.parameters + if (isExpandFunctionDeclaration(method)) { + // remove first parameter + paramters.shift() + } + const checkType = Array.isArray(args) + + let match = OverloadMatch.FullMatch + + const argumentLength = checkType ? args.length : args + + let checkLength = Math.min(argumentLength, paramters.length) + + const paramterTypes = checkType ? paramters.map(p => this.typeComputer.inferType(p)) : [] + + if (argumentLength > paramters.length) { + if (!isVarargs(method)) { + return OverloadMatch.NotMatch + } + + if (!checkType) { + return OverloadMatch.PossibleMatch + } + + const varargsType = paramterTypes[paramterTypes.length - 1] + match = OverloadMatch.ImplicitMatch + for (let i = paramters.length - 1; i < argumentLength; i++) { + if (this.typeIsSame(args[i], varargsType)) { + continue + } + match = OverloadMatch.ImplicitMatch + if (!this.typeIsInstanceOf(args[i], varargsType)) { + return OverloadMatch.NotMatch + } + } + } + else if (argumentLength < paramters.length) { + if (!isOptionalArgs(method, argumentLength)) { + return OverloadMatch.NotMatch + } + + if (!checkType) { + return OverloadMatch.PossibleMatch + } + + match = OverloadMatch.OptionalMatch + } + else if (isVarargs(method)) { + if (!checkType) { + return OverloadMatch.PossibleMatch + } + const varargsType = paramterTypes[paramterTypes.length - 1] + if (!varargsType) { + return OverloadMatch.NotMatch + } + + const lastArgType = args.at(-1)! + const varargsTypeArray = this.typeComputer.arrayTypeOf(varargsType) + + if (this.typeIsSame(lastArgType, varargsType) || this.typeIsSame(lastArgType, varargsTypeArray)) { + return OverloadMatch.FullMatch + } + + if (!this.typeIsInstanceOf(lastArgType, varargsType) && !this.typeIsInstanceOf(lastArgType, varargsTypeArray)) { + return OverloadMatch.NotMatch + } + + match = OverloadMatch.ImplicitMatch + checkLength-- + } + + if (!checkType) { + return OverloadMatch.PossibleMatch + } + + for (let i = 0; i < checkLength; i++) { + if (this.typeIsSame(args[i], paramterTypes[i])) { + continue + } + match = OverloadMatch.ImplicitMatch + if (!this.typeIsInstanceOf(args[i], paramterTypes[i])) { + return OverloadMatch.NotMatch + } + } + + return match + } +} diff --git a/packages/zenscript/src/typing/type-computer.ts b/packages/zenscript/src/typing/type-computer.ts index ce79c9df..c41aa845 100644 --- a/packages/zenscript/src/typing/type-computer.ts +++ b/packages/zenscript/src/typing/type-computer.ts @@ -6,11 +6,13 @@ import type { BracketManager } from '../workspace/bracket-manager' import type { PackageManager } from '../workspace/package-manager' import type { BuiltinTypes, Type, TypeParameterSubstitutions } from './type-description' import { type AstNode, stream } from 'langium' -import { isAssignment, isCallExpression, isClassDeclaration, isExpression, isFunctionDeclaration, isFunctionExpression, isOperatorFunctionDeclaration, isTypeParameter, isVariableDeclaration } from '../generated/ast' +import { isAssignment, isCallExpression, isClassDeclaration, isConstructorDeclaration, isExpression, isFunctionDeclaration, isFunctionExpression, isMemberAccess, isOperatorFunctionDeclaration, isReferenceExpression, isTypeParameter, isVariableDeclaration } from '../generated/ast' import { ClassType, CompoundType, FunctionType, IntersectionType, isAnyType, isClassType, isFunctionType, TypeVariable, UnionType } from './type-description' export interface TypeComputer { inferType: (node: AstNode | undefined) => Type | undefined + classTypeOf: (className: BuiltinTypes | string, substitutions?: TypeParameterSubstitutions) => ClassType + arrayTypeOf: (elementType: Type) => ClassType } type SourceMap = ZenScriptAstType & ZenScriptSyntheticAstType @@ -32,7 +34,7 @@ export class ZenScriptTypeComputer implements TypeComputer { return this.rules[node?.$type]?.call(this, node) } - private classTypeOf(className: BuiltinTypes | string, substitutions: TypeParameterSubstitutions = new Map()): ClassType { + public classTypeOf(className: BuiltinTypes | string, substitutions: TypeParameterSubstitutions = new Map()): ClassType { const classDecl = this.classDeclOf(className) if (!classDecl) { throw new Error(`Class "${className}" is not defined.`) @@ -40,6 +42,13 @@ export class ZenScriptTypeComputer implements TypeComputer { return new ClassType(classDecl, substitutions) } + public arrayTypeOf(elementType: Type): ClassType { + const arrayType = this.classTypeOf('Array') + const T = arrayType.declaration.typeParameters[0] + arrayType.substitutions.set(T, elementType) + return arrayType + } + private classDeclOf(className: BuiltinTypes | string): ClassDeclaration | undefined { return stream(this.packageManager.retrieve(className)).find(isClassDeclaration) } @@ -47,10 +56,7 @@ export class ZenScriptTypeComputer implements TypeComputer { private readonly rules: RuleMap = { // region TypeReference ArrayTypeReference: (source) => { - const arrayType = this.classTypeOf('Array') - const T = arrayType.declaration.typeParameters[0] - arrayType.substitutions.set(T, this.inferType(source.value) ?? this.classTypeOf('any')) - return arrayType + return this.arrayTypeOf(this.inferType(source.value) ?? this.classTypeOf('any')) }, ListTypeReference: (source) => { @@ -213,23 +219,55 @@ export class ZenScriptTypeComputer implements TypeComputer { }, PrefixExpression: (source) => { + const exprType = this.inferType(source.expr) + if (isAnyType(exprType)) { + return exprType + } + + const operatorDecl = this.memberProvider().getMembers(exprType) + .map(it => it.node) + .filter(it => isOperatorFunctionDeclaration(it)) + .filter(it => it.op === source.op) + .at(0) + + if (operatorDecl) { + return this.inferType(operatorDecl.returnTypeRef) + } + switch (source.op) { case '-': - return this.classTypeOf('int') + return exprType ?? this.classTypeOf('int') case '!': - return this.classTypeOf('bool') + return exprType ?? this.classTypeOf('bool') } }, InfixExpression: (source) => { // TODO: operator overloading + const leftType = this.inferType(source.left) + const rightType = this.inferType(source.right) + + if (isAnyType(leftType) || isAnyType(rightType)) { + return this.classTypeOf('any') + } + + const operatorDecl = this.memberProvider().getMembers(leftType) + .map(it => it.node) + .filter(it => isOperatorFunctionDeclaration(it)) + .filter(it => it.op === source.op) + .at(0) + + if (operatorDecl) { + return this.inferType(operatorDecl.returnTypeRef) + } + switch (source.op) { case '+': case '-': case '*': case '/': case '%': - return this.classTypeOf('int') + return leftType ?? this.classTypeOf('int') case '<': case '>': case '<=': @@ -286,6 +324,15 @@ export class ZenScriptTypeComputer implements TypeComputer { }, FunctionExpression: (source) => { + // dynamic arguments + if (isCallExpression(source.$container) && source.$containerProperty === 'arguments' && source.$containerIndex !== undefined) { + const callType = this.inferType(source.$container.receiver) + if (!isFunctionType(callType)) { + return + } + + return callType.paramTypes.at(source.$containerIndex) + } const paramTypes = source.parameters.map(param => this.inferType(param) ?? this.classTypeOf('any')) const returnType = this.inferType(source.returnTypeRef) ?? this.classTypeOf('any') return new FunctionType(paramTypes, returnType) @@ -306,7 +353,7 @@ export class ZenScriptTypeComputer implements TypeComputer { }, MemberAccess: (source) => { - const targetContainer = source.target.ref?.$container + const targetContainer = source.target?.ref?.$container if (isOperatorFunctionDeclaration(targetContainer) && targetContainer.op === '.') { return this.inferType(targetContainer.returnTypeRef) } @@ -338,6 +385,17 @@ export class ZenScriptTypeComputer implements TypeComputer { }, CallExpression: (source) => { + if (isReferenceExpression(source.receiver) || isMemberAccess(source.receiver)) { + const receiverRef = source.receiver.target.ref + if (!receiverRef) { + return + } + + if (isConstructorDeclaration(receiverRef) && isClassDeclaration(receiverRef.$container)) { + return new ClassType(receiverRef.$container, new Map()) + } + } + const receiverType = this.inferType(source.receiver) if (isFunctionType(receiverType)) { return receiverType.returnType diff --git a/packages/zenscript/src/typing/type-description.ts b/packages/zenscript/src/typing/type-description.ts index 4ca7a8ad..e69cf7eb 100644 --- a/packages/zenscript/src/typing/type-description.ts +++ b/packages/zenscript/src/typing/type-description.ts @@ -24,6 +24,10 @@ export abstract class Type { abstract substituteTypeParameters(substitutions: TypeParameterSubstitutions): Type abstract toString(): string + + equals(other: Type): boolean { + return this.$type === other.$type + } } export abstract class NamedType extends Type { @@ -33,6 +37,10 @@ export abstract class NamedType extends Type { super($type) this.declaration = declaration } + + override equals(other: Type): boolean { + return super.equals(other) && this.declaration === (other as NamedType).declaration + } } export class ClassType extends NamedType { @@ -56,6 +64,23 @@ export class ClassType extends NamedType { } return result } + + override equals(other: Type): boolean { + if (!super.equals(other)) { + return false + } + if (!isClassType(other) || this.substitutions.size !== other.substitutions.size) { + return false + } + + for (const [key, value] of this.substitutions) { + if (!other.substitutions.has(key) || !value.equals(other.substitutions.get(key)!)) { + return false + } + } + + return true + } } export class TypeVariable extends NamedType { @@ -70,6 +95,10 @@ export class TypeVariable extends NamedType { override toString(): string { return this.declaration.name } + + override equals(other: Type): boolean { + return super.equals(other) && isTypeVariable(other) + } } export class FunctionType extends Type { @@ -96,6 +125,21 @@ export class FunctionType extends Type { result += this.returnType.toString() return result } + + override equals(other: Type): boolean { + if (!super.equals(other) || !isFunctionType(other)) { + return false + } + if (this.paramTypes.length !== other.paramTypes.length) { + return false + } + for (let i = 0; i < this.paramTypes.length; i++) { + if (!this.paramTypes[i].equals(other.paramTypes[i])) { + return false + } + } + return this.returnType.equals(other.returnType) + } } export class UnionType extends Type { @@ -112,6 +156,21 @@ export class UnionType extends Type { override toString(): string { return this.types.map(it => it.toString()).join(' | ') } + + override equals(other: Type): boolean { + if (!super.equals(other) || !isUnionType(other)) { + return false + } + if (this.types.length !== other.types.length) { + return false + } + for (let i = 0; i < this.types.length; i++) { + if (!this.types[i].equals(other.types[i])) { + return false + } + } + return true + } } export class IntersectionType extends Type { @@ -128,6 +187,21 @@ export class IntersectionType extends Type { override toString(): string { return this.types.map(it => it.toString()).join(' & ') } + + override equals(other: Type): boolean { + if (!super.equals(other) || !isIntersectionType(other)) { + return false + } + if (this.types.length !== other.types.length) { + return false + } + for (let i = 0; i < this.types.length; i++) { + if (!this.types[i].equals(other.types[i])) { + return false + } + } + return true + } } export class CompoundType extends Type { @@ -144,6 +218,21 @@ export class CompoundType extends Type { override toString(): string { return this.types.map(it => it.toString()).join(', ') } + + override equals(other: Type): boolean { + if (!super.equals(other) || !isCompoundType(other)) { + return false + } + if (this.types.length !== other.types.length) { + return false + } + for (let i = 0; i < this.types.length; i++) { + if (!this.types[i].equals(other.types[i])) { + return false + } + } + return true + } } // endregion @@ -219,4 +308,8 @@ export function isIntersectionType(type: unknown): type is IntersectionType { export function isTypeVariable(type: unknown): type is TypeVariable { return type instanceof TypeVariable } + +export function isCompoundType(type: unknown): type is CompoundType { + return type instanceof CompoundType +} // endregion diff --git a/packages/zenscript/src/workspace/description-index.ts b/packages/zenscript/src/workspace/description-index.ts index 230125f7..f08304d2 100644 --- a/packages/zenscript/src/workspace/description-index.ts +++ b/packages/zenscript/src/workspace/description-index.ts @@ -1,15 +1,17 @@ import type { HierarchyNode } from '@intellizen/shared' -import type { AstNode, AstNodeDescription, AstNodeDescriptionProvider, NameProvider } from 'langium' -import type { ClassDeclaration, ImportDeclaration } from '../generated/ast' import type { ZenScriptServices } from '../module' +import type { ZenScriptPackageManager } from './package-manager' +import { type AstNode, type AstNodeDescription, type AstNodeDescriptionProvider, type NameProvider, stream } from 'langium' +import { type ClassDeclaration, type ImportDeclaration, isClassDeclaration, isFunctionDeclaration } from '../generated/ast' import { createSyntheticAstNodeDescription } from '../reference/synthetic' +import { isStatic } from '../utils/ast' export interface DescriptionIndex { getDescription: (astNode: AstNode) => AstNodeDescription getPackageDescription: (pkgNode: HierarchyNode) => AstNodeDescription getThisDescription: (classDecl: ClassDeclaration) => AstNodeDescription createDynamicDescription: (astNode: AstNode, name: string) => AstNodeDescription - createImportedDescription: (importDecl: ImportDeclaration) => AstNodeDescription + createImportedDescription: (importDecl: ImportDeclaration) => AstNodeDescription[] } export class ZenScriptDescriptionIndex implements DescriptionIndex { @@ -20,12 +22,15 @@ export class ZenScriptDescriptionIndex implements DescriptionIndex { readonly pkgDescriptions: WeakMap, AstNodeDescription> readonly thisDescriptions: WeakMap + private readonly packageManager: ZenScriptPackageManager + constructor(services: ZenScriptServices) { this.descriptions = services.workspace.AstNodeDescriptionProvider this.nameProvider = services.references.NameProvider this.astDescriptions = new WeakMap() this.pkgDescriptions = new WeakMap() this.thisDescriptions = new WeakMap() + this.packageManager = services.workspace.PackageManager } getDescription(astNode: AstNode): AstNodeDescription { @@ -56,8 +61,49 @@ export class ZenScriptDescriptionIndex implements DescriptionIndex { return this.descriptions.createDescription(astNode, name) } - createImportedDescription(importDecl: ImportDeclaration): AstNodeDescription { - const ref = importDecl.path.at(-1)?.ref ?? importDecl - return this.descriptions.createDescription(ref, this.nameProvider.getName(importDecl)) + createImportedDescription(importDecl: ImportDeclaration): AstNodeDescription[] { + const ref = importDecl.path.at(-1)?.ref + if (!ref) { + return [this.getDescription(importDecl)] + } + const alias = importDecl.alias + + // handle import overloading + if (isFunctionDeclaration(ref)) { + // Find function with same name in the same package + const parentRef = importDecl.path.at(-2)?.ref + if (!parentRef) { + return [] + } + + if (isClassDeclaration(parentRef)) { + const result = stream(parentRef.members) + .filter(it => isFunctionDeclaration(it)) + .filter(it => it.name === ref.name) + .filter(it => isStatic(it)) + .map(it => this.getDescription(it)) + .map(it => alias ? this.aliasDescription(it, alias) : it) + .toArray() + return result + } + } + + const targetDesc = importDecl.path.at(-1)?.$nodeDescription || this.descriptions.createDescription(importDecl, this.nameProvider.getName(importDecl)) + return [alias ? this.aliasDescription(targetDesc, alias) : targetDesc] + } + + private aliasDescription(desc: AstNodeDescription, alias: string): AstNodeDescription { + const { node, selectionSegment, type, documentUri, path } = desc + return { + node, + name: alias, + get nameSegment() { + return desc.nameSegment + }, + selectionSegment, + type, + documentUri, + path, + } } } diff --git a/packages/zenscript/src/zenscript.langium b/packages/zenscript/src/zenscript.langium index e0b7f8ab..e276de08 100644 --- a/packages/zenscript/src/zenscript.langium +++ b/packages/zenscript/src/zenscript.langium @@ -28,7 +28,7 @@ interface ClassDeclaration extends Declaration { members: ClassMemberDeclaration[]; } -type NamedElement = Script | ClassDeclaration | FunctionDeclaration | ExpandFunctionDeclaration | FieldDeclaration | ValueParameter| VariableDeclaration | LoopParameter | MapEntry | ImportDeclaration; +type NamedElement = Script | ClassDeclaration | FunctionDeclaration | ExpandFunctionDeclaration | ConstructorDeclaration | FieldDeclaration | ValueParameter| VariableDeclaration | LoopParameter | MapEntry | ImportDeclaration; type ClassMemberDeclaration = FunctionDeclaration | FieldDeclaration | ConstructorDeclaration | OperatorFunctionDeclaration;