diff --git a/packages/react/src/alert-dialog/popup/AlertDialogPopup.tsx b/packages/react/src/alert-dialog/popup/AlertDialogPopup.tsx index 79a967b496b..22a9fc48f37 100644 --- a/packages/react/src/alert-dialog/popup/AlertDialogPopup.tsx +++ b/packages/react/src/alert-dialog/popup/AlertDialogPopup.tsx @@ -104,7 +104,7 @@ export const AlertDialogPopup = React.forwardRef(function AlertDialogPopup( }, elementProps, ], - ref: [forwardedRef, store.context.popupRef, store.getElementSetter('popupElement')], + ref: [forwardedRef, store.context.popupRef, store.useStateSetter('popupElement')], stateAttributesMapping, }); diff --git a/packages/react/src/alert-dialog/viewport/AlertDialogViewport.tsx b/packages/react/src/alert-dialog/viewport/AlertDialogViewport.tsx index 74ba2216313..a56e4c8f699 100644 --- a/packages/react/src/alert-dialog/viewport/AlertDialogViewport.tsx +++ b/packages/react/src/alert-dialog/viewport/AlertDialogViewport.tsx @@ -59,7 +59,7 @@ export const AlertDialogViewport = React.forwardRef(function AlertDialogViewport return useRenderElement('div', componentProps, { enabled: shouldRender, state, - ref: [forwardedRef, store.getElementSetter('viewportElement')], + ref: [forwardedRef, store.useStateSetter('viewportElement')], stateAttributesMapping, props: [ { diff --git a/packages/react/src/dialog/popup/DialogPopup.tsx b/packages/react/src/dialog/popup/DialogPopup.tsx index cbc1d1475e3..8a3e459aff7 100644 --- a/packages/react/src/dialog/popup/DialogPopup.tsx +++ b/packages/react/src/dialog/popup/DialogPopup.tsx @@ -107,7 +107,7 @@ export const DialogPopup = React.forwardRef(function DialogPopup( }, elementProps, ], - ref: [forwardedRef, store.context.popupRef, store.getElementSetter('popupElement')], + ref: [forwardedRef, store.context.popupRef, store.useStateSetter('popupElement')], stateAttributesMapping, }); diff --git a/packages/react/src/dialog/viewport/DialogViewport.tsx b/packages/react/src/dialog/viewport/DialogViewport.tsx index 797e27456ab..ca9e32293aa 100644 --- a/packages/react/src/dialog/viewport/DialogViewport.tsx +++ b/packages/react/src/dialog/viewport/DialogViewport.tsx @@ -59,7 +59,7 @@ export const DialogViewport = React.forwardRef(function DialogViewport( return useRenderElement('div', componentProps, { enabled: shouldRender, state, - ref: [forwardedRef, store.getElementSetter('viewportElement')], + ref: [forwardedRef, store.useStateSetter('viewportElement')], stateAttributesMapping, props: [ { diff --git a/packages/react/src/tooltip/popup/TooltipPopup.tsx b/packages/react/src/tooltip/popup/TooltipPopup.tsx index c0eb7bab58b..91d63c5c950 100644 --- a/packages/react/src/tooltip/popup/TooltipPopup.tsx +++ b/packages/react/src/tooltip/popup/TooltipPopup.tsx @@ -95,7 +95,7 @@ export const TooltipPopup = React.forwardRef(function TooltipPopup( const element = useRenderElement('div', componentProps, { state, - ref: [forwardedRef, store.context.popupRef, store.getElementSetter('popupElement')], + ref: [forwardedRef, store.context.popupRef, store.useStateSetter('popupElement')], props: [ popupProps, transitionStatus === 'starting' ? DISABLED_TRANSITIONS_STYLE : EMPTY_OBJECT, diff --git a/packages/react/src/tooltip/positioner/TooltipPositioner.tsx b/packages/react/src/tooltip/positioner/TooltipPositioner.tsx index 17f4175a0e2..abae67637df 100644 --- a/packages/react/src/tooltip/positioner/TooltipPositioner.tsx +++ b/packages/react/src/tooltip/positioner/TooltipPositioner.tsx @@ -123,7 +123,7 @@ export const TooltipPositioner = React.forwardRef(function TooltipPositioner( const element = useRenderElement('div', componentProps, { state, props: [positioner.props, elementProps], - ref: [forwardedRef, store.getElementSetter('positionerElement')], + ref: [forwardedRef, store.useStateSetter('positionerElement')], stateAttributesMapping: popupStateMapping, }); diff --git a/packages/react/test/index.ts b/packages/react/test/index.ts index a4798274ab4..d9f66a27903 100644 --- a/packages/react/test/index.ts +++ b/packages/react/test/index.ts @@ -1,4 +1,4 @@ +export * from '@base-ui-components/utils/testUtils'; export { createRenderer } from './createRenderer'; export { describeConformance } from './describeConformance'; export { popupConformanceTests } from './popupConformanceTests'; -export * from './utils'; diff --git a/packages/utils/src/store/ReactStore.spec.ts b/packages/utils/src/store/ReactStore.spec.ts new file mode 100644 index 00000000000..3b11d01a416 --- /dev/null +++ b/packages/utils/src/store/ReactStore.spec.ts @@ -0,0 +1,96 @@ +import { expectType } from '../testUtils'; +import { createSelector } from './createSelector'; +import { ReactStore } from './ReactStore'; + +interface TestState { + count: number | undefined; + text: string; +} + +const selectors = { + count: createSelector((state: TestState) => state.count), + text: createSelector((state: TestState) => state.text), + textLongerThan(state: TestState, length: number) { + return state.text.length > length; + }, + textLengthBetween(state: TestState, minLength: number, maxLength: number) { + return state.text.length >= minLength && state.text.length <= maxLength; + }, +}; + +const store = new ReactStore, typeof selectors>( + { count: 0, text: '' }, + undefined, + selectors, +); + +const count = store.select('count'); +expectType(count); + +const text = store.select('text'); +expectType(text); + +const isTextLongerThan5 = store.select('textLongerThan', 5); +expectType(isTextLongerThan5); + +const isTextLengthBetween3And10 = store.select('textLengthBetween', 3, 10); +expectType(isTextLengthBetween3And10); + +const countReactive = store.useState('count'); +expectType(countReactive); + +const textReactive = store.useState('text'); +expectType(textReactive); + +const isTextLongerThan7Reactive = store.useState('textLongerThan', 7); +expectType(isTextLongerThan7Reactive); + +const isTextLengthBetween2And8Reactive = store.useState('textLengthBetween', 2, 8); +expectType(isTextLengthBetween2And8Reactive); + +// incorrect calls: + +// @ts-expect-error +store.select(); +// @ts-expect-error +store.select('count', 1); +// @ts-expect-error +store.select('textLongerThan'); +// @ts-expect-error +store.select('textLengthBetween', 1); +// @ts-expect-error +store.select('textLongerThan', 2, 3); + +// @ts-expect-error +store.useState(); +// @ts-expect-error +store.useState('count', 1); +// @ts-expect-error +store.useState('textLongerThan'); +// @ts-expect-error +store.useState('textLengthBetween', 1); +// @ts-expect-error +store.useState('textLongerThan', 2, 3); + +const unsubscribeFromCount = store.observe('count', (newValue, oldValue) => { + expectType(newValue); + expectType(oldValue); +}); +expectType<() => void, typeof unsubscribeFromCount>(unsubscribeFromCount); + +const unsubscribeFromSelector = store.observe( + (state) => state.text.length, + (newValue, oldValue) => { + expectType(newValue); + expectType(oldValue); + }, +); +expectType<() => void, typeof unsubscribeFromSelector>(unsubscribeFromSelector); + +// @ts-expect-error listener must match selector return type +store.observe( + (state) => state.text.length, + (newValue: string) => { + expectType(newValue); + }, +); diff --git a/packages/utils/src/store/ReactStore.test.tsx b/packages/utils/src/store/ReactStore.test.tsx index cf8e4811dbc..fc09dce6f73 100644 --- a/packages/utils/src/store/ReactStore.test.tsx +++ b/packages/utils/src/store/ReactStore.test.tsx @@ -1,12 +1,13 @@ import * as React from 'react'; import { expect } from 'chai'; -import { act, createRenderer } from '@mui/internal-test-utils'; +import { act, createRenderer, screen } from '@mui/internal-test-utils'; import { ReactStore } from './ReactStore'; import { useRefWithInit } from '../useRefWithInit'; +import { createSelector } from './createSelector'; type TestState = { value: number; label: string }; -function useStableStore(initial: State) { +function useStableStore(initial: State) { return useRefWithInit(() => new ReactStore(initial)).current; } @@ -189,7 +190,7 @@ describe('ReactStore', () => { expect(store.state.node).to.equal(undefined); }); - it('getElementSetter returns a stable callback that updates the store state', () => { + it('useStateSetter returns a stable callback that updates the store state', () => { type ElementState = { element: HTMLDivElement | null }; let store!: ReactStore; let forceUpdate!: React.Dispatch>; @@ -199,7 +200,7 @@ describe('ReactStore', () => { function Test() { store = useStableStore({ element: null }); - const setter = store.getElementSetter('element'); + const setter = store.useStateSetter('element'); lastSetter = setter; const [, setTick] = React.useState(0); forceUpdate = setTick; @@ -227,4 +228,286 @@ describe('ReactStore', () => { expect(store.state.element).to.equal(null); }); + + it('supports nested stores as state values', async () => { + type ParentState = { count: number }; + type ChildState = { count: number; parent?: ReactStore }; + + const parentSelectors = { count: (state: ParentState) => state.count }; + const childSelectors = { + count: (state: ChildState) => state.parent?.state.count ?? state.count, + parent: (state: ChildState) => state.parent, + }; + + const localCountSelector = createSelector((state: ChildState) => state.count); + + const parentStore = new ReactStore, typeof parentSelectors>( + { count: 0 }, + undefined, + parentSelectors, + ); + + const childStore = new ReactStore, typeof childSelectors>( + { count: 10 }, + undefined, + childSelectors, + ); + + let unsubscribeParentHandler: () => void; + const onParentUpdated = ( + newParent: ReactStore | undefined, + _: ReactStore | undefined, + store: ReactStore, + ) => { + if (!newParent) { + unsubscribeParentHandler?.(); + return; + } + + unsubscribeParentHandler = newParent.subscribe(() => { + store.notifyAll(); + }); + }; + + const onCountUpdated = ( + newCount: number, + _: number, + store: ReactStore, + ) => { + store.state.parent?.set('count', newCount); + }; + + childStore.observe('parent', onParentUpdated); + childStore.observe(localCountSelector, onCountUpdated); + + function Test() { + const count = childStore.useState('count'); + return {count}; + } + + render(); + const output = screen.getByTestId('output'); + + await act(async () => { + childStore.set('count', 5); + }); + expect(childStore.state.count).to.equal(5); + expect(output.textContent).to.equal('5'); + + await act(async () => { + childStore.set('parent', parentStore); + }); + expect(childStore.state.count).to.equal(5); + expect(childStore.select('count')).to.equal(0); + expect(output.textContent).to.equal('0'); + + await act(async () => { + childStore.set('count', 20); + }); + expect(childStore.state.count).to.equal(20); + expect(parentStore.state.count).to.equal(20); + expect(childStore.select('count')).to.equal(20); + expect(output.textContent).to.equal('20'); + + await act(async () => { + parentStore.set('count', 15); + }); + expect(parentStore.state.count).to.equal(15); + expect(childStore.state.count).to.equal(20); + expect(childStore.select('count')).to.equal(15); + expect(output.textContent).to.equal('15'); + }); + describe('observeSelector', () => { + type CounterState = { count: number; multiplier: number }; + const selectors = { + count: (state: CounterState) => state.count, + doubled: (state: CounterState) => state.count * 2, + multiplied: (state: CounterState) => state.count * state.multiplier, + }; + + it('accepts selector functions', () => { + const store = new ReactStore({ count: 0, multiplier: 1 }); + const calls: Array<{ newValue: boolean; oldValue: boolean }> = []; + + const unsubscribe = store.observe( + (state) => state.count > 1, + (newValue, oldValue) => { + calls.push({ newValue, oldValue }); + }, + ); + + expect(calls).to.have.lengthOf(1); + expect(calls[0]).to.deep.equal({ newValue: false, oldValue: false }); + + store.set('count', 2); + expect(calls).to.have.lengthOf(2); + expect(calls[1]).to.deep.equal({ newValue: true, oldValue: false }); + + store.set('count', 1); + expect(calls).to.have.lengthOf(3); + expect(calls[2]).to.deep.equal({ newValue: false, oldValue: true }); + + unsubscribe(); + + store.set('count', 3); + expect(calls).to.have.lengthOf(3); + }); + + it('calls listener immediately with current selector result on subscription', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + const calls: Array<{ newValue: number; oldValue: number }> = []; + + store.observe('doubled', (newValue: number, oldValue: number) => { + calls.push({ newValue, oldValue }); + }); + + expect(calls).to.have.lengthOf(1); + expect(calls[0]).to.deep.equal({ newValue: 10, oldValue: 10 }); + }); + + it('calls listener when selector result changes', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + const calls: Array<{ newValue: number; oldValue: number }> = []; + + store.observe('doubled', (newValue: number, oldValue: number) => { + calls.push({ newValue, oldValue }); + }); + + store.set('count', 10); + store.set('count', 7); + + expect(calls).to.have.lengthOf(3); + expect(calls[1]).to.deep.equal({ newValue: 20, oldValue: 10 }); + expect(calls[2]).to.deep.equal({ newValue: 14, oldValue: 20 }); + }); + + it('does not call listener when selector result is unchanged', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + const calls: Array<{ newValue: number; oldValue: number }> = []; + + store.observe('doubled', (newValue: number, oldValue: number) => { + calls.push({ newValue, oldValue }); + }); + + store.set('multiplier', 5); + + expect(calls).to.have.lengthOf(1); // Only initial call + }); + + it('calls listener when any dependency of the selector changes', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + const calls: Array<{ newValue: number; oldValue: number }> = []; + + store.observe('multiplied', (newValue: number, oldValue: number) => { + calls.push({ newValue, oldValue }); + }); + + store.set('count', 10); + store.set('multiplier', 2); + + expect(calls).to.have.lengthOf(3); + expect(calls[0]).to.deep.equal({ newValue: 15, oldValue: 15 }); + expect(calls[1]).to.deep.equal({ newValue: 30, oldValue: 15 }); + expect(calls[2]).to.deep.equal({ newValue: 20, oldValue: 30 }); + }); + + it('provides the store instance to the listener', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + let receivedStore!: ReactStore, typeof selectors>; + + store.observe('doubled', (_: number, __: number, storeArg) => { + receivedStore = storeArg; + }); + + expect(receivedStore).to.equal(store); + }); + + it('returns an unsubscribe function that stops observing', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + const calls: Array<{ newValue: number; oldValue: number }> = []; + + const unsubscribe = store.observe('doubled', (newValue: number, oldValue: number) => { + calls.push({ newValue, oldValue }); + }); + + store.set('count', 10); + expect(calls).to.have.lengthOf(2); + + unsubscribe(); + + store.set('count', 15); + expect(calls).to.have.lengthOf(2); // No new calls after unsubscribe + }); + + it('supports multiple observers on the same selector', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + const calls1: number[] = []; + const calls2: number[] = []; + + store.observe('doubled', (newValue: number) => { + calls1.push(newValue); + }); + + store.observe('doubled', (newValue: number) => { + calls2.push(newValue); + }); + + store.set('count', 10); + + expect(calls1).to.deep.equal([10, 20]); + expect(calls2).to.deep.equal([10, 20]); + }); + + it('supports observers on different selectors', () => { + const store = new ReactStore, typeof selectors>( + { count: 5, multiplier: 3 }, + undefined, + selectors, + ); + const doubledCalls: number[] = []; + const multipliedCalls: number[] = []; + + store.observe('doubled', (newValue: number) => { + doubledCalls.push(newValue); + }); + + store.observe('multiplied', (newValue: number) => { + multipliedCalls.push(newValue); + }); + + store.set('count', 10); + store.set('multiplier', 2); + + expect(doubledCalls).to.deep.equal([10, 20]); + expect(multipliedCalls).to.deep.equal([15, 30, 20]); + }); + }); }); diff --git a/packages/utils/src/store/ReactStore.ts b/packages/utils/src/store/ReactStore.ts index 007476eb730..0511bc73cd7 100644 --- a/packages/utils/src/store/ReactStore.ts +++ b/packages/utils/src/store/ReactStore.ts @@ -11,10 +11,17 @@ import { NOOP } from '../empty'; * A Store that supports controlled state keys, non-reactive values and provides utility methods for React. */ export class ReactStore< - State, + State extends object, Context = Record, - Selectors extends Record any> = Record, + Selectors extends Record> = Record, > extends Store { + /** + * Creates a new ReactStore instance. + * + * @param state Initial state of the store. + * @param context Non-reactive context values. + * @param selectors Optional selectors for use with `useState`. + */ constructor(state: State, context: Context = {} as Context, selectors?: Selectors) { super(state); this.context = context; @@ -152,9 +159,14 @@ export class ReactStore< public update(values: Partial): void { const newValues = { ...values }; for (const key in newValues) { + if (!Object.hasOwn(newValues, key)) { + continue; + } + if (this.controlledValues.get(key) === true) { // Ignore updates to controlled values delete newValues[key]; + continue; } } @@ -170,15 +182,29 @@ export class ReactStore< public setState(newState: State) { const newValues = { ...newState }; for (const key in newValues) { + if (!Object.hasOwn(newValues, key)) { + continue; + } + if (this.controlledValues.get(key) === true) { // Ignore updates to controlled values delete newValues[key]; + continue; } } super.setState({ ...this.state, ...newValues }); } + /** Gets the current value from the store using a selector with the provided key. + * + * @param key Key of the selector to use. + */ + public select = ((key: keyof Selectors, a1?: unknown, a2?: unknown, a3?: unknown) => { + const selector = this.selectors![key]; + return selector(this.state, a1, a2, a3); + }) as ReactStoreSelectorMethod; + /** * Returns a value from the store's state using a selector function. * Used to subscribe to specific parts of the state. @@ -186,15 +212,10 @@ export class ReactStore< * * @param key Key of the selector to use. */ - public useState(key: Key): ReturnType { - if (!this.selectors) { - throw new Error('Base UI: selectors are required to call useState.'); - } - return useStore>( - this, - this.selectors[key] as (state: State) => ReturnType, - ); - } + public useState = ((key: keyof Selectors, a1?: unknown, a2?: unknown, a3?: unknown) => { + const selector = this.selectors![key]; + return useStore(this, selector, a1, a2, a3); + }) as ReactStoreSelectorMethod; /** * Wraps a function with `useStableCallback` to ensure it has a stable reference @@ -214,16 +235,63 @@ export class ReactStore< /** * Returns a stable setter function for a specific key in the store's state. * It's commonly used to pass as a ref callback to React elements. + * * @param key Key of the state to set. */ - public getElementSetter(key: keyof State) { + public useStateSetter(key: keyof State) { return React.useCallback( - (element: Value) => { - this.set(key, element); + (value: Value) => { + this.set(key, value); }, [key], ); } + + /** + * Observes changes derived from the store's selectors and calls the listener when the selected value changes. + * + * @param key Key of the selector to observe. + * @param listener Listener function called when the selector result changes. + */ + public observe( + selector: Key, + listener: ( + newValue: ReturnType, + oldValue: ReturnType, + store: this, + ) => void, + ): () => void; + + public observe>( + selector: Selector, + listener: (newValue: ReturnType, oldValue: ReturnType, store: this) => void, + ): () => void; + + public observe( + selector: keyof Selectors | ObserveSelector, + listener: (newValue: any, oldValue: any, store: this) => void, + ) { + let selectFn: ObserveSelector; + + if (typeof selector === 'function') { + selectFn = selector; + } else { + selectFn = this.selectors![selector] as ObserveSelector; + } + + let prevValue = selectFn(this.state); + + listener(prevValue, prevValue, this); + + return this.subscribe((nextState) => { + const nextValue = selectFn(nextState); + if (!Object.is(prevValue, nextValue)) { + const oldValue = prevValue; + prevValue = nextValue; + listener(nextValue, oldValue, this); + } + }); + } } type MaybeCallable = (...args: any[]) => any; @@ -237,3 +305,20 @@ type ContextFunction = Extract = { [Key in keyof State]-?: undefined extends State[Key] ? Key : never; }[keyof State]; + +type ReactStoreSelectorMethod>> = < + Key extends keyof Selectors, +>( + key: Key, + ...args: SelectorArgs +) => ReturnType; + +type ObserveSelector = (state: State) => any; + +type SelectorFunction = (state: State, ...args: any[]) => any; + +type Tail = T extends readonly [any, ...infer Rest] ? Rest : []; + +type SelectorArgs = Selector extends (...params: infer Params) => any + ? Tail + : never; diff --git a/packages/utils/src/store/Store.ts b/packages/utils/src/store/Store.ts index 11f3d80939f..a01bb69047d 100644 --- a/packages/utils/src/store/Store.ts +++ b/packages/utils/src/store/Store.ts @@ -78,7 +78,7 @@ export class Store { public update(changes: Partial) { for (const key in changes) { if (!Object.is(this.state[key], changes[key])) { - this.setState({ ...this.state, ...changes }); + Store.prototype.setState.call(this, { ...this.state, ...changes }); return; } } @@ -92,7 +92,15 @@ export class Store { */ public set(key: keyof State, value: T) { if (!Object.is(this.state[key], value)) { - this.setState({ ...this.state, [key]: value }); + Store.prototype.setState.call(this, { ...this.state, [key]: value }); } } + + /** + * Gives the state a new reference and updates all registered listeners. + */ + public notifyAll() { + const newState = { ...this.state }; + Store.prototype.setState.call(this, newState); + } } diff --git a/packages/react/test/utils.ts b/packages/utils/src/testUtils.ts similarity index 100% rename from packages/react/test/utils.ts rename to packages/utils/src/testUtils.ts diff --git a/packages/utils/tsconfig.build.json b/packages/utils/tsconfig.build.json index 9b3b741e0bd..3651896426a 100644 --- a/packages/utils/tsconfig.build.json +++ b/packages/utils/tsconfig.build.json @@ -10,7 +10,8 @@ "moduleResolution": "bundler", "noEmit": false, "rootDir": "./src", - "outDir": "build/esm" + "outDir": "build/esm", + "lib": ["ES2022", "DOM"] }, "include": ["src/**/*.ts*", "src/**/*.tsx"], "exclude": ["src/**/*.spec.ts*", "src/**/*.test.ts*"]