diff --git a/.changeset/cute-chairs-kneel.md b/.changeset/cute-chairs-kneel.md new file mode 100644 index 00000000000..e00e03eced1 --- /dev/null +++ b/.changeset/cute-chairs-kneel.md @@ -0,0 +1,5 @@ +--- +"effect": minor +--- + +Add Function.memo diff --git a/packages/effect/src/Function.ts b/packages/effect/src/Function.ts index 69d53fd2fdd..03393376ecc 100644 --- a/packages/effect/src/Function.ts +++ b/packages/effect/src/Function.ts @@ -1220,3 +1220,194 @@ export const hole: () => T = unsafeCoerce(absurd) * @since 2.0.0 */ export const SK = (_: A, b: B): B => b + +/** + * Empty ArgNode result sentinel. + * + * @internal + */ +const EMPTY = Symbol.for("effect/Function.memo") + +/** + * Memo helper class. + * + * @internal + */ +class ArgNode { + /** The cached result of the args that led to this node, or `EMPTY`. */ + r: A | typeof EMPTY = EMPTY + + /** The primitive arg branch. */ + private p: Map> | null = null + + /** The object arg branch. */ + private o: WeakMap> | null = null + + /** Get the next node for an arg. If uninitialized, one will be created, set, and returned. */ + at(arg: unknown): ArgNode { + let next: ArgNode | undefined + let isObject: boolean + switch (typeof arg) { + case "object": + // @ts-expect-error Fallthrough intended for primitive null case + case "function": + if (arg !== null) { + if (!this.o) this.o = new WeakMap() + next = this.o.get(arg) + isObject = true + break + } + + default: + if (!this.p) this.p = new Map() + next = this.p.get(arg) + isObject = false + } + + if (next) return next + + const fresh = new ArgNode() + + if (isObject) { + this.o!.set(arg as object, fresh) + } else { + this.p!.set(arg, fresh) + } + + return fresh + } +} + +/** + * A global WeakMap from original functions to memo wrappers, so that we can avoid re-wrapping. + * + * @internal + */ +const origToMemo = new WeakMap() +/** + * Private branding for memo wrappers. + * + * @internal + */ +const allMemos = new WeakSet() + +export declare namespace memo { + interface Options { + /** + * Optimizes caching by not setting the path for trailing `undefined` + * arguments. This is useful for functions with optional arguments, treating + * provided-`undefined` and absent-`undefined` the same way. The wrapped + * function will nevertheless receive all arguments. + * + * This option should be disabled if: + * + * - The function has variadic arguments and `undefined` is a valid input + * - The function distinguishes `f(1)` vs `f(1, undefined)` via `arguments.length` or `rest.length` + * + * @default true + */ + readonly trimUndefined?: boolean + } +} + +/** + * Memoize a function, method, or getter, with any number of arguments. Repeated calls to the returned function with the same arguments will return cached values. + * + * Usage notes: + * - Memoized functions should be totally pure, and should return immutable values. + * - The cache size is unbounded, but internally a `WeakMap` is used when possible. To make the most of this, memoized functions should have object-type args at the start and primitive args at the end. + * - Works as a class method decorator under modern settings (`experimentalDecorators: false`), though you will have to use the curried form (`@Function.memo()`). + * + * @example + * ```ts + * import { memo } from "effect/Function" + * + * const add = memo((x: number, y: number) => { + * console.log("running add"); + * return x + y; + * }); + * + * add(2, 3); // logs "running add", returns 5 + * add(2, 3); // no log, returns cached 5 + * add(2, 4); // logs "running add", returns 6 + * + * // Expected console output: + * // running add + * // running add + * ``` + * + * @since 3.20.0 + */ +export const memo: { + ( + options?: memo.Options + ): , Return>(fn: (...args: Args) => Return) => (...args: Args) => Return + , Return>( + fn: (...args: Args) => Return, + options?: memo.Options + ): (...args: Args) => Return +} = dual( + (args) => isFunction(args[0]), + , Return>( + fn: (...args: Args) => Return, + options?: memo.Options + ): (...args: Args) => Return => { + // If input is a 'base' function and already memoized, return it + if (origToMemo.has(fn)) return origToMemo.get(fn) as any + + // If input is a memo function, don't re-wrap it + if (allMemos.has(fn)) return fn + + // Create the `ArgNode` root for this function + const root = new ArgNode() + + const shouldTrim = options?.trimUndefined ?? true + + const out = function(this: unknown, ...args: Args): Return { + // Find the defined length + let argsLength = args.length + if (shouldTrim) { + let definedLength = argsLength + while (definedLength > 0 && args[definedLength - 1] === undefined) { + definedLength -= 1 + } + if (definedLength < argsLength) { + argsLength = definedLength + } + } + + // Drill through `this` and `args` to get the ArgNode that holds the result + let node = root.at(this) + for (let i = 0; i < argsLength; i += 1) { + node = node.at(args[i]) + } + + if (node.r !== EMPTY) return node.r + + const result = fn.apply(this, args) + node.r = result + + return result + } + + origToMemo.set(fn, out) + allMemos.add(out) + + return out + } +) + +/** + * See {@link memo}. This is an alias that infers `This` and uses it in the returned function signature. + * + * @since 3.20.0 + */ +export const memoThis: { + (options?: memo.Options): , Return>( + fn: (this: This, ...args: Args) => Return + ) => (this: This, ...args: Args) => Return + , Return>( + fn: (this: This, ...args: Args) => Return, + options?: memo.Options + ): (this: This, ...args: Args) => Return +} = memo diff --git a/packages/effect/test/Function.test.ts b/packages/effect/test/Function.test.ts index de67d89dc5a..8e08a1ca0f2 100644 --- a/packages/effect/test/Function.test.ts +++ b/packages/effect/test/Function.test.ts @@ -1,4 +1,4 @@ -import { describe, it } from "@effect/vitest" +import { describe, expect, it, vi } from "@effect/vitest" import { deepStrictEqual, strictEqual, throws } from "@effect/vitest/utils" import { Function, String } from "effect" @@ -195,4 +195,132 @@ describe("Function", () => { deepStrictEqual(f.apply(null, ["_", "a", "b", "c", "d", "e", "f"] as any), "_abcde") }) }) + + describe("memo", () => { + it("memoizes by primitive argument list", () => { + const fn = vi.fn((a: number, b: number) => a + b) + const memoFn = Function.memo(fn) + + strictEqual(memoFn(1, 2), 3) + expect(fn).toHaveBeenCalledTimes(1) + + // Cache hit (same args) + strictEqual(memoFn(1, 2), 3) + expect(fn).toHaveBeenCalledTimes(1) + + // Different args -> cache miss + strictEqual(memoFn(2, 1), 3) + expect(fn).toHaveBeenCalledTimes(2) + }) + + it("memoizes zero-argument functions and shares cache per fn", () => { + const fn = vi.fn(() => 42) + const memoFnA = Function.memo(fn) + const memoFnB = Function.memo(fn) // wrappers share same per-fn cache + + expect(memoFnA()).toBe(42) + expect(fn).toHaveBeenCalledTimes(1) + + // Cache hit via same wrapper + expect(memoFnA()).toBe(42) + expect(fn).toHaveBeenCalledTimes(1) + + // Cache hit via different wrapper for same original function + expect(memoFnB()).toBe(42) + expect(fn).toHaveBeenCalledTimes(1) + }) + + it("trims trailing undefined optional args by default", () => { + const fn = vi.fn(function(a?: number, b?: number) { + return { a, b, length: arguments.length } + }) + const m = Function.memo(fn) + + expect(m(1)).toEqual({ a: 1, b: undefined, length: 1 }) + expect(fn).toHaveBeenCalledTimes(1) + + expect(m(1, undefined)).toEqual({ a: 1, b: undefined, length: 1 }) + expect(fn).toHaveBeenCalledTimes(1) + + expect(m(1, 2)).toEqual({ a: 1, b: 2, length: 2 }) + expect(fn).toHaveBeenCalledTimes(2) + }) + + it("preserves trailing undefined when trimUndefined is disabled", () => { + const fn = vi.fn(function(a?: number, b?: number) { + return { a, b, length: arguments.length } + }) + const m = Function.memo(fn, { trimUndefined: false }) + + expect(m(1)).toEqual({ a: 1, b: undefined, length: 1 }) + expect(fn).toHaveBeenCalledTimes(1) + + expect(m(1, undefined)).toEqual({ a: 1, b: undefined, length: 2 }) + expect(fn).toHaveBeenCalledTimes(2) + + expect(m(1)).toEqual({ a: 1, b: undefined, length: 1 }) + expect(fn).toHaveBeenCalledTimes(2) + }) + + it("only trims trailing undefined values", () => { + const fn = vi.fn((a?: number, b?: number, c?: number) => [a, b, c]) + const m = Function.memo(fn) + + expect(m(1, undefined, 3)).toEqual([1, undefined, 3]) + expect(fn).toHaveBeenCalledTimes(1) + + expect(m(1, 3)).toEqual([1, 3, undefined]) + expect(fn).toHaveBeenCalledTimes(2) + }) + + it("memoizes undefined results distinctly", () => { + const fn = vi.fn((_: number): void => undefined) + const memoFn = Function.memo(fn) + + expect(memoFn(123)).toBeUndefined() + expect(fn).toHaveBeenCalledTimes(1) + + // Cache hit despite undefined result + expect(memoFn(123)).toBeUndefined() + expect(fn).toHaveBeenCalledTimes(1) + }) + + it("uses object identity for object args", () => { + const a = { x: 1 } + const b = { x: 1 } + + const fn = vi.fn((o: { x: number }) => o.x) + const memoFn = Function.memo(fn) + + expect(memoFn(a)).toBe(1) + expect(fn).toHaveBeenCalledTimes(1) + + // Cache hit for same object instance + expect(memoFn(a)).toBe(1) + expect(fn).toHaveBeenCalledTimes(1) + + // Different object instance with same shape -> cache miss + expect(memoFn(b)).toBe(1) + expect(fn).toHaveBeenCalledTimes(2) + }) + + it("includes receiver (this) in the cache key", () => { + function plusOffset(this: { offset: number }, x: number) { + return x + this.offset + } + const fnThis = vi.fn(plusOffset) + // Casts only needed to appease TS 'this' typing for test invocations + const memoFn = Function.memo(fnThis) + + const a = { offset: 1 } + const b = { offset: 100 } + + expect(memoFn.call(a, 1)).toBe(2) + expect(fnThis).toHaveBeenCalledTimes(1) + + // Different receiver with same args should not reuse prior result + expect(memoFn.call(b, 1)).toBe(101) + expect(fnThis).toHaveBeenCalledTimes(2) + }) + }) })