Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/cute-chairs-kneel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"effect": minor
---

Add Function.memo
191 changes: 191 additions & 0 deletions packages/effect/src/Function.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1220,3 +1220,194 @@ export const hole: <T>() => T = unsafeCoerce(absurd)
* @since 2.0.0
*/
export const SK = <A, B>(_: A, b: B): B => b

/**
* Empty ArgNode result sentinel.
*
* @internal
*/
const EMPTY = Symbol.for("effect/Function.memo")

/**
* Memo helper class.
*
* @internal
*/
class ArgNode<A> {
/** The cached result of the args that led to this node, or `EMPTY`. */
r: A | typeof EMPTY = EMPTY

/** The primitive arg branch. */
private p: Map<unknown, ArgNode<A>> | null = null

/** The object arg branch. */
private o: WeakMap<object, ArgNode<A>> | null = null

/** Get the next node for an arg. If uninitialized, one will be created, set, and returned. */
at(arg: unknown): ArgNode<A> {
let next: ArgNode<A> | undefined
let isObject: boolean
switch (typeof arg) {
case "object":
// @ts-expect-error Fallthrough intended for primitive null case
case "function":
if (arg !== null) {
if (!this.o) this.o = new WeakMap()
next = this.o.get(arg)
isObject = true
break
}

default:
if (!this.p) this.p = new Map()
next = this.p.get(arg)
isObject = false
}

if (next) return next

const fresh = new ArgNode<A>()

if (isObject) {
this.o!.set(arg as object, fresh)
} else {
this.p!.set(arg, fresh)
}

return fresh
}
}

/**
* A global WeakMap from original functions to memo wrappers, so that we can avoid re-wrapping.
*
* @internal
*/
const origToMemo = new WeakMap<Function, Function>()
/**
* Private branding for memo wrappers.
*
* @internal
*/
const allMemos = new WeakSet<Function>()

export declare namespace memo {
interface Options {
/**
* Optimizes caching by not setting the path for trailing `undefined`
* arguments. This is useful for functions with optional arguments, treating
* provided-`undefined` and absent-`undefined` the same way. The wrapped
* function will nevertheless receive all arguments.
*
* This option should be disabled if:
*
* - The function has variadic arguments and `undefined` is a valid input
* - The function distinguishes `f(1)` vs `f(1, undefined)` via `arguments.length` or `rest.length`
*
* @default true
*/
readonly trimUndefined?: boolean
}
}

/**
* Memoize a function, method, or getter, with any number of arguments. Repeated calls to the returned function with the same arguments will return cached values.
*
* Usage notes:
* - Memoized functions should be totally pure, and should return immutable values.
* - The cache size is unbounded, but internally a `WeakMap` is used when possible. To make the most of this, memoized functions should have object-type args at the start and primitive args at the end.
* - Works as a class method decorator under modern settings (`experimentalDecorators: false`), though you will have to use the curried form (`@Function.memo()`).
*
* @example
* ```ts
* import { memo } from "effect/Function"
*
* const add = memo((x: number, y: number) => {
* console.log("running add");
* return x + y;
* });
*
* add(2, 3); // logs "running add", returns 5
* add(2, 3); // no log, returns cached 5
* add(2, 4); // logs "running add", returns 6
*
* // Expected console output:
* // running add
* // running add
* ```
*
* @since 3.20.0
*/
export const memo: {
(
options?: memo.Options
): <Args extends Array<any>, Return>(fn: (...args: Args) => Return) => (...args: Args) => Return
<Args extends Array<any>, Return>(
fn: (...args: Args) => Return,
options?: memo.Options
): (...args: Args) => Return
} = dual(
(args) => isFunction(args[0]),
<Args extends Array<any>, Return>(
fn: (...args: Args) => Return,
options?: memo.Options
): (...args: Args) => Return => {
// If input is a 'base' function and already memoized, return it
if (origToMemo.has(fn)) return origToMemo.get(fn) as any

// If input is a memo function, don't re-wrap it
if (allMemos.has(fn)) return fn

// Create the `ArgNode` root for this function
const root = new ArgNode<Return>()

const shouldTrim = options?.trimUndefined ?? true

const out = function(this: unknown, ...args: Args): Return {
// Find the defined length
let argsLength = args.length
if (shouldTrim) {
let definedLength = argsLength
while (definedLength > 0 && args[definedLength - 1] === undefined) {
definedLength -= 1
}
if (definedLength < argsLength) {
argsLength = definedLength
}
}

// Drill through `this` and `args` to get the ArgNode that holds the result
let node = root.at(this)
for (let i = 0; i < argsLength; i += 1) {
node = node.at(args[i])
}

if (node.r !== EMPTY) return node.r

const result = fn.apply(this, args)
node.r = result

return result
}

origToMemo.set(fn, out)
allMemos.add(out)

return out
}
)

/**
* See {@link memo}. This is an alias that infers `This` and uses it in the returned function signature.
*
* @since 3.20.0
*/
export const memoThis: {
(options?: memo.Options): <This, Args extends Array<any>, Return>(
fn: (this: This, ...args: Args) => Return
) => (this: This, ...args: Args) => Return
<This, Args extends Array<any>, Return>(
fn: (this: This, ...args: Args) => Return,
options?: memo.Options
): (this: This, ...args: Args) => Return
} = memo
130 changes: 129 additions & 1 deletion packages/effect/test/Function.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, it } from "@effect/vitest"
import { describe, expect, it, vi } from "@effect/vitest"
import { deepStrictEqual, strictEqual, throws } from "@effect/vitest/utils"
import { Function, String } from "effect"

Expand Down Expand Up @@ -195,4 +195,132 @@ describe("Function", () => {
deepStrictEqual(f.apply(null, ["_", "a", "b", "c", "d", "e", "f"] as any), "_abcde")
})
})

describe("memo", () => {
it("memoizes by primitive argument list", () => {
const fn = vi.fn((a: number, b: number) => a + b)
const memoFn = Function.memo(fn)

strictEqual(memoFn(1, 2), 3)
expect(fn).toHaveBeenCalledTimes(1)

// Cache hit (same args)
strictEqual(memoFn(1, 2), 3)
expect(fn).toHaveBeenCalledTimes(1)

// Different args -> cache miss
strictEqual(memoFn(2, 1), 3)
expect(fn).toHaveBeenCalledTimes(2)
})

it("memoizes zero-argument functions and shares cache per fn", () => {
const fn = vi.fn(() => 42)
const memoFnA = Function.memo(fn)
const memoFnB = Function.memo(fn) // wrappers share same per-fn cache

expect(memoFnA()).toBe(42)
expect(fn).toHaveBeenCalledTimes(1)

// Cache hit via same wrapper
expect(memoFnA()).toBe(42)
expect(fn).toHaveBeenCalledTimes(1)

// Cache hit via different wrapper for same original function
expect(memoFnB()).toBe(42)
expect(fn).toHaveBeenCalledTimes(1)
})

it("trims trailing undefined optional args by default", () => {
const fn = vi.fn(function(a?: number, b?: number) {
return { a, b, length: arguments.length }
})
const m = Function.memo(fn)

expect(m(1)).toEqual({ a: 1, b: undefined, length: 1 })
expect(fn).toHaveBeenCalledTimes(1)

expect(m(1, undefined)).toEqual({ a: 1, b: undefined, length: 1 })
expect(fn).toHaveBeenCalledTimes(1)

expect(m(1, 2)).toEqual({ a: 1, b: 2, length: 2 })
expect(fn).toHaveBeenCalledTimes(2)
})

it("preserves trailing undefined when trimUndefined is disabled", () => {
const fn = vi.fn(function(a?: number, b?: number) {
return { a, b, length: arguments.length }
})
const m = Function.memo(fn, { trimUndefined: false })

expect(m(1)).toEqual({ a: 1, b: undefined, length: 1 })
expect(fn).toHaveBeenCalledTimes(1)

expect(m(1, undefined)).toEqual({ a: 1, b: undefined, length: 2 })
expect(fn).toHaveBeenCalledTimes(2)

expect(m(1)).toEqual({ a: 1, b: undefined, length: 1 })
expect(fn).toHaveBeenCalledTimes(2)
})

it("only trims trailing undefined values", () => {
const fn = vi.fn((a?: number, b?: number, c?: number) => [a, b, c])
const m = Function.memo(fn)

expect(m(1, undefined, 3)).toEqual([1, undefined, 3])
expect(fn).toHaveBeenCalledTimes(1)

expect(m(1, 3)).toEqual([1, 3, undefined])
expect(fn).toHaveBeenCalledTimes(2)
})

it("memoizes undefined results distinctly", () => {
const fn = vi.fn((_: number): void => undefined)
const memoFn = Function.memo(fn)

expect(memoFn(123)).toBeUndefined()
expect(fn).toHaveBeenCalledTimes(1)

// Cache hit despite undefined result
expect(memoFn(123)).toBeUndefined()
expect(fn).toHaveBeenCalledTimes(1)
})

it("uses object identity for object args", () => {
const a = { x: 1 }
const b = { x: 1 }

const fn = vi.fn((o: { x: number }) => o.x)
const memoFn = Function.memo(fn)

expect(memoFn(a)).toBe(1)
expect(fn).toHaveBeenCalledTimes(1)

// Cache hit for same object instance
expect(memoFn(a)).toBe(1)
expect(fn).toHaveBeenCalledTimes(1)

// Different object instance with same shape -> cache miss
expect(memoFn(b)).toBe(1)
expect(fn).toHaveBeenCalledTimes(2)
})

it("includes receiver (this) in the cache key", () => {
function plusOffset(this: { offset: number }, x: number) {
return x + this.offset
}
const fnThis = vi.fn(plusOffset)
// Casts only needed to appease TS 'this' typing for test invocations
const memoFn = Function.memo(fnThis)

const a = { offset: 1 }
const b = { offset: 100 }

expect(memoFn.call(a, 1)).toBe(2)
expect(fnThis).toHaveBeenCalledTimes(1)

// Different receiver with same args should not reuse prior result
expect(memoFn.call(b, 1)).toBe(101)
expect(fnThis).toHaveBeenCalledTimes(2)
})
})
})