From c74b2dba6e92da3b07f75b59e5cd32d1e802a6ea Mon Sep 17 00:00:00 2001 From: Noeri Huisman <8823461+mrxz@users.noreply.github.com> Date: Mon, 20 Oct 2025 13:58:04 +0200 Subject: [PATCH] Restructured Splat setup aimed at singular Splat object --- examples/debug-color/index.html | 49 +- examples/depth-of-field/index.html | 62 +- examples/dynamic-lighting/index.html | 113 ++- examples/envmap/index.html | 61 +- examples/glsl/index.html | 108 +- examples/hello-world/carousel.html | 5 +- examples/hello-world/index.html | 5 +- examples/interactivity/index.html | 10 +- src/controls.ts => examples/js/controls.js | 133 +-- examples/js/preloader.js | 11 +- examples/multiple-splats/index.html | 10 +- examples/multiple-viewpoints/index.html | 34 +- examples/particle-animation/index.html | 24 +- examples/particle-simulation/index.html | 35 +- examples/procedural-splats/index.html | 309 +++--- examples/raycasting/index.html | 244 ++--- examples/sogs/index.html | 5 +- examples/splat-reveal-effects/index.html | 278 +++-- examples/stochastic/index.html | 16 +- examples/viewer/index.html | 29 +- examples/webxr/index.html | 91 +- src/BatchedSplat.ts | 251 +++++ src/PackedSplats.ts | 749 -------------- src/Readback.ts | 338 ------ src/RgbaArray.ts | 283 ------ src/SparkRenderer.ts | 1073 -------------------- src/SparkViewpoint.ts | 878 ---------------- src/Splat.ts | 701 +++++++++++++ src/SplatAccumulator.ts | 120 --- src/SplatEdit.ts | 829 --------------- src/SplatGenerator.ts | 233 ----- src/SplatGeometry.ts | 105 +- src/SplatLoader.ts | 445 ++------ src/SplatMesh.ts | 974 ------------------ src/SplatSkinning.ts | 298 ------ src/SplatSorter.ts | 425 ++++++++ src/SplatUtils.ts | 116 +++ src/SplatWorker.ts | 155 +++ src/antisplat.ts | 125 --- src/defines.ts | 21 + src/dyno.ts | 16 - src/dyno/base.ts | 575 ----------- src/dyno/control.ts | 22 - src/dyno/convert.ts | 451 -------- src/dyno/logic.ts | 434 -------- src/dyno/math.ts | 534 ---------- src/dyno/mathTypes.ts | 717 ------------- src/dyno/output.ts | 78 -- src/dyno/program.ts | 117 --- src/dyno/splats.ts | 594 ----------- src/dyno/texture.ts | 239 ----- src/dyno/transform.ts | 155 --- src/dyno/trig.ts | 182 ---- src/dyno/types.ts | 420 -------- src/dyno/uniforms.ts | 826 --------------- src/dyno/util.ts | 441 -------- src/dyno/value.ts | 289 ------ src/dyno/vecmat.ts | 835 --------------- src/encoding/ExtendedSplats.ts | 464 +++++++++ src/encoding/PackedSplats.ts | 549 ++++++++++ src/encoding/encoder.ts | 122 +++ src/formats/antisplat.ts | 52 + src/formats/ksplat.ts | 303 ++++++ src/{ => formats}/pcsogs.ts | 179 ++-- src/{ => formats}/ply.ts | 145 +-- src/formats/spz.ts | 433 ++++++++ src/generators.ts | 2 - src/generators/snow.ts | 276 ----- src/generators/static.ts | 117 --- src/hands.ts | 472 --------- src/index.ts | 98 +- src/ksplat.ts | 636 ------------ src/modifiers.ts | 2 - src/modifiers/depthColor.ts | 60 -- src/modifiers/normalColor.ts | 46 - src/procedural.ts | 566 +++++++++++ src/raycast.ts | 84 ++ src/shaders.ts | 27 +- src/shaders/computeUvec4.glsl | 36 - src/shaders/computeVec4.glsl | 36 - src/shaders/extendedSplat.glsl | 131 +++ src/shaders/identityVertex.glsl | 7 + src/shaders/packedSplat.glsl | 220 ++++ src/shaders/splatDefines.glsl | 333 +----- src/shaders/splatDistanceFragment.glsl | 72 ++ src/shaders/splatFragment.glsl | 52 +- src/shaders/splatVertex.glsl | 162 +-- src/splatConstructors.ts | 419 -------- src/splatWorker.ts | 128 --- src/spz.ts | 833 --------------- src/transcode.ts | 331 ++++++ src/utils.ts | 1027 +------------------ src/vrButton.ts | 164 --- src/worker.ts | 681 ------------- src/worker/sort.ts | 242 +++++ src/worker/worker.ts | 222 ++++ tsconfig.json | 4 +- 97 files changed, 6745 insertions(+), 19664 deletions(-) rename src/controls.ts => examples/js/controls.js (82%) create mode 100644 src/BatchedSplat.ts delete mode 100644 src/PackedSplats.ts delete mode 100644 src/Readback.ts delete mode 100644 src/RgbaArray.ts delete mode 100644 src/SparkRenderer.ts delete mode 100644 src/SparkViewpoint.ts create mode 100644 src/Splat.ts delete mode 100644 src/SplatAccumulator.ts delete mode 100644 src/SplatEdit.ts delete mode 100644 src/SplatGenerator.ts delete mode 100644 src/SplatMesh.ts delete mode 100644 src/SplatSkinning.ts create mode 100644 src/SplatSorter.ts create mode 100644 src/SplatUtils.ts create mode 100644 src/SplatWorker.ts delete mode 100644 src/antisplat.ts delete mode 100644 src/dyno.ts delete mode 100644 src/dyno/base.ts delete mode 100644 src/dyno/control.ts delete mode 100644 src/dyno/convert.ts delete mode 100644 src/dyno/logic.ts delete mode 100644 src/dyno/math.ts delete mode 100644 src/dyno/mathTypes.ts delete mode 100644 src/dyno/output.ts delete mode 100644 src/dyno/program.ts delete mode 100644 src/dyno/splats.ts delete mode 100644 src/dyno/texture.ts delete mode 100644 src/dyno/transform.ts delete mode 100644 src/dyno/trig.ts delete mode 100644 src/dyno/types.ts delete mode 100644 src/dyno/uniforms.ts delete mode 100644 src/dyno/util.ts delete mode 100644 src/dyno/value.ts delete mode 100644 src/dyno/vecmat.ts create mode 100644 src/encoding/ExtendedSplats.ts create mode 100644 src/encoding/PackedSplats.ts create mode 100644 src/encoding/encoder.ts create mode 100644 src/formats/antisplat.ts create mode 100644 src/formats/ksplat.ts rename src/{ => formats}/pcsogs.ts (74%) rename src/{ => formats}/ply.ts (91%) create mode 100644 src/formats/spz.ts delete mode 100644 src/generators.ts delete mode 100644 src/generators/snow.ts delete mode 100644 src/generators/static.ts delete mode 100644 src/hands.ts delete mode 100644 src/ksplat.ts delete mode 100644 src/modifiers.ts delete mode 100644 src/modifiers/depthColor.ts delete mode 100644 src/modifiers/normalColor.ts create mode 100644 src/procedural.ts create mode 100644 src/raycast.ts delete mode 100644 src/shaders/computeUvec4.glsl delete mode 100644 src/shaders/computeVec4.glsl create mode 100644 src/shaders/extendedSplat.glsl create mode 100644 src/shaders/identityVertex.glsl create mode 100644 src/shaders/packedSplat.glsl create mode 100644 src/shaders/splatDistanceFragment.glsl delete mode 100644 src/splatConstructors.ts delete mode 100644 src/splatWorker.ts delete mode 100644 src/spz.ts create mode 100644 src/transcode.ts delete mode 100644 src/vrButton.ts delete mode 100644 src/worker.ts create mode 100644 src/worker/sort.ts create mode 100644 src/worker/worker.ts diff --git a/examples/debug-color/index.html b/examples/debug-color/index.html index 4707b38..34328f9 100644 --- a/examples/debug-color/index.html +++ b/examples/debug-color/index.html @@ -26,7 +26,7 @@ diff --git a/examples/particle-animation/index.html b/examples/particle-animation/index.html index 5e42467..5e3452a 100644 --- a/examples/particle-animation/index.html +++ b/examples/particle-animation/index.html @@ -32,14 +32,11 @@ - - - - + + + + + + + Spark • Procedural Splats + + + + + + + + + diff --git a/examples/raycasting/index.html b/examples/raycasting/index.html index b1426c9..542a517 100644 --- a/examples/raycasting/index.html +++ b/examples/raycasting/index.html @@ -1,117 +1,127 @@ - - - - - - - Spark • Raycasting - - - - -
Click to select
- - - - - + + + + + + + Spark • Raycasting + + + + +
Click to select
+ + + + + diff --git a/examples/sogs/index.html b/examples/sogs/index.html index 17305f7..bc7648f 100644 --- a/examples/sogs/index.html +++ b/examples/sogs/index.html @@ -26,7 +26,7 @@ import * as THREE from "three"; import { OrbitControls } from 'three/addons/controls/OrbitControls.js'; import { Sky } from 'three/addons/objects/Sky.js'; - import { SplatMesh } from "@sparkjsdev/spark"; + import { Splat, SplatLoader } from "@sparkjsdev/spark"; import { getAssetFileURL } from "/examples/js/get-asset-url.js"; const scene = new THREE.Scene(); @@ -46,7 +46,8 @@ } const splatURL = await getAssetFileURL("sutro.zip"); - const sutroTower = new SplatMesh({ url: splatURL }); + const loader = new SplatLoader(); + const sutroTower = await loader.loadAsync(splatURL); sutroTower.quaternion.set(1, 0, 0, 0); scene.add(sutroTower); diff --git a/examples/splat-reveal-effects/index.html b/examples/splat-reveal-effects/index.html index ea12735..61e2c49 100644 --- a/examples/splat-reveal-effects/index.html +++ b/examples/splat-reveal-effects/index.html @@ -27,8 +27,9 @@ diff --git a/src/BatchedSplat.ts b/src/BatchedSplat.ts new file mode 100644 index 0000000..aa6cdb8 --- /dev/null +++ b/src/BatchedSplat.ts @@ -0,0 +1,251 @@ +import * as THREE from "three"; +import { + type IterableSplatData, + type SortContext, + Splat, + type SplatData, +} from "./Splat"; +import { CpuSplatSorter, type SplatOrdering } from "./SplatSorter"; +import { isIterableSplatData } from "./SplatUtils"; +import type { TransformRange } from "./defines"; +import { DefaultSplatEncoding } from "./encoding/encoder"; + +/** + * Specialized Splat class for combining multiple splats in one draw call. + * All splats are sorted allowing for overlapping splats, while each instance + * retains its own transform matrix. + */ +export class BatchedSplat extends Splat { + readonly maxInstanceCount: number; + private readonly batchedSplatData: BatchedSplatData; + + private matricesArray: Float32Array; + private matricesTexture: THREE.DataTexture; + + constructor(maxInstanceCount: number) { + const batchingTextureUniform: THREE.IUniform = { + value: null as THREE.Texture | null, + }; + const batchedSplatData = new BatchedSplatData(batchingTextureUniform); + super(batchedSplatData, { sorter: new CpuSplatSorter() }); + this.batchedSplatData = batchedSplatData; + + this.maxInstanceCount = maxInstanceCount; + let size = Math.sqrt(this.maxInstanceCount * 4); // 4 pixels needed for 1 matrix + size = Math.ceil(size / 4) * 4; + size = Math.max(size, 4); + + this.matricesArray = new Float32Array(size * size * 4); // 4 floats per RGBA pixel + this.matricesTexture = new THREE.DataTexture( + this.matricesArray, + size, + size, + THREE.RGBAFormat, + THREE.FloatType, + ); + batchingTextureUniform.value = this.matricesTexture; + + // Disable frustum culling as the transform of BatchedSplat is ignored + // in favour of the individual instance transform matrices. + this.frustumCulled = false; + } + + addSplat(splat: Splat) { + const splatData = splat.splatData; + if (!isIterableSplatData(splatData)) { + throw new Error( + "Splat can't be added to BatchedSplat as its splat data is not iterable", + ); + } + + this.addSplatData(splatData); + const index = this.batchedSplatData.instanceCount - 1; + splat.updateMatrixWorld(); + this.setMatrixAt(index, splat.matrixWorld); + } + + addSplatData(splatData: IterableSplatData) { + this.batchedSplatData.addSplatData(splatData); + this.batchedSplatData.setupMaterial(this.material); + this.needsUpdate = true; + } + + removeSplatData(splatData: IterableSplatData) { + this.batchedSplatData.removeSplatData(splatData); + this.batchedSplatData.setupMaterial(this.material); + this.needsUpdate = true; + } + + setMatrixAt(instanceId: number, matrix: THREE.Matrix4) { + matrix.toArray(this.matricesArray, instanceId * 16); + this.matricesTexture.needsUpdate = true; + this.needsUpdate = true; + return this; + } + + getTransformRanges(): Array { + const result: Array = []; + + let start = 0; + for (let i = 0; i < this.batchedSplatData.instanceCount; i++) { + const numSplats = this.batchedSplatData.sources[i].numSplats; + result.push({ + start, + end: start + numSplats, + matrix: [...this.matricesArray.slice(i * 16, (i + 1) * 16)], + }); + start += numSplats; + } + + return result; + } + + protected onSortComplete(context: SortContext, result: SplatOrdering) { + // Include object index into ordering array + for (let i = 0; i < result.activeSplats; i++) { + const splatIndex = result.ordering[i]; + const objectIndex = this.batchedSplatData.getInstanceIndexFor(splatIndex); + result.ordering[i] = splatIndex | (objectIndex << 26); + } + super.onSortComplete(context, result); + } + + dispose(): void { + super.dispose(); + this.batchedSplatData.dispose(); + } +} + +/** + * SplatData implementation that allows combining multiple individual + * splat data sources into one for batched draw calls. + */ +class BatchedSplatData implements SplatData { + private splatData: SplatData; + sources: Array = []; + private batchingTextureUniform: THREE.IUniform; + + constructor(batchingTextureUniform: THREE.IUniform) { + this.splatData = this.recreate(); + this.batchingTextureUniform = batchingTextureUniform; + } + + private recreate(): SplatData { + const numSh = this.sources[0]?.numSh ?? 0; + const numSplats = this.sources.reduce( + (sum, source) => sum + source.numSplats, + 0, + ); + + const splatEncoder = DefaultSplatEncoding.createSplatEncoder(); + splatEncoder.allocate(numSplats, numSh); + + let splatIndex = 0; + for (const source of this.sources) { + source.iterateSplats( + ( + _, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + sh, + ) => { + splatEncoder.setSplat( + splatIndex, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ); + if (sh) { + splatEncoder.setSplatSh(splatIndex, sh); + } + splatIndex++; + }, + ); + } + + this.splatData?.dispose(); + this.splatData = splatEncoder.close(); + + return this.splatData; + } + + get instanceCount() { + return this.sources.length; + } + + getInstanceIndexFor(splatIndex: number): number { + let instanceIndex = 0; + let instanceEnd = this.sources[instanceIndex].numSplats; + while (splatIndex >= instanceEnd) { + instanceIndex++; + instanceEnd += this.sources[instanceIndex].numSplats; + } + return instanceIndex; + } + + addSplatData(source: IterableSplatData) { + this.sources.push(source); + this.recreate(); + } + + removeSplatData(source: IterableSplatData) { + if (this.sources.indexOf(source)) { + this.sources.splice(this.sources.indexOf(source), 1); + this.recreate(); + } + } + + get maxSplats() { + return this.splatData.maxSplats; + } + + get numSplats() { + return this.splatData.numSplats; + } + + get numSh() { + return this.splatData.numSh; + } + + setupMaterial(material: THREE.ShaderMaterial) { + this.splatData.setupMaterial(material); + if (!("batchingTexture" in material.uniforms)) { + material.uniforms.batchingTexture = this.batchingTextureUniform; + } + material.defines.USE_BATCHING = true; + } + + iterateCenters( + callback: (index: number, x: number, y: number, z: number) => void, + ) { + this.splatData.iterateCenters(callback); + } + + dispose(): void { + // Only dispose the combined splat. The other splat sources aren't owned by this instance. + this.splatData.dispose(); + } +} diff --git a/src/PackedSplats.ts b/src/PackedSplats.ts deleted file mode 100644 index 327d318..0000000 --- a/src/PackedSplats.ts +++ /dev/null @@ -1,749 +0,0 @@ -import * as THREE from "three"; -import { FullScreenQuad } from "three/addons/postprocessing/Pass.js"; - -import type { GsplatGenerator } from "./SplatGenerator"; -import { type SplatFileType, SplatLoader, unpackSplats } from "./SplatLoader"; -import { - LN_SCALE_MAX, - LN_SCALE_MIN, - SPLAT_TEX_HEIGHT, - SPLAT_TEX_WIDTH, -} from "./defines"; -import { - DynoProgram, - DynoProgramTemplate, - DynoUniform, - DynoVec2, - DynoVec4, - dynoBlock, - outputPackedSplat, -} from "./dyno"; -import { TPackedSplats, definePackedSplats } from "./dyno/splats"; -import computeUvec4Template from "./shaders/computeUvec4.glsl"; -import { getTextureSize, setPackedSplat, unpackSplat } from "./utils"; - -export type SplatEncoding = { - rgbMin?: number; - rgbMax?: number; - lnScaleMin?: number; - lnScaleMax?: number; - sh1Min?: number; - sh1Max?: number; - sh2Min?: number; - sh2Max?: number; - sh3Min?: number; - sh3Max?: number; -}; - -export const DEFAULT_SPLAT_ENCODING: SplatEncoding = { - rgbMin: 0, - rgbMax: 1, - lnScaleMin: LN_SCALE_MIN, - lnScaleMax: LN_SCALE_MAX, - sh1Min: -1, - sh1Max: 1, - sh2Min: -1, - sh2Max: 1, - sh3Min: -1, - sh3Max: 1, -}; - -// Initialize a PackedSplats collection from source data via -// url, fileBytes, or packedArray. Creates an empty array if none are set, -// and splat data can be constructed using pushSplat()/setSplat(). The maximum -// splat size allocation will grow automatically, starting from maxSplats. -export type PackedSplatsOptions = { - // URL to fetch a Gaussian splat file from (supports .ply, .splat, .ksplat, - // .spz formats). (default: undefined) - url?: string; - // Raw bytes of a Gaussian splat file to decode directly instead of fetching - // from URL. (default: undefined) - fileBytes?: Uint8Array | ArrayBuffer; - // Override the file type detection for formats that can't be reliably - // auto-detected (.splat, .ksplat). (default: undefined auto-detects other - // formats from file contents) - fileType?: SplatFileType; - // File name to use for type detection. (default: undefined) - fileName?: string; - // Reserve space for at least this many splats when constructing the collection - // initially. The array will automatically resize past maxSplats so setting it is - // an optional optimization. (default: 0) - maxSplats?: number; - // Use provided packed data array, where each 4 consecutive uint32 values - // encode one "packed" Gsplat. (default: undefined) - packedArray?: Uint32Array; - // Override number of splats in packed array to use only a subset. - // (default: length of packed array / 4) - numSplats?: number; - // Callback function to programmatically create splats at initialization. - // (default: undefined) - construct?: (splats: PackedSplats) => Promise | void; - // Additional splat data, such as spherical harmonics components (sh1, sh2, sh3). (default: {}) - extra?: Record; - // Override the default splat encoding ranges for the PackedSplats. - // (default: undefined) - splatEncoding?: SplatEncoding; -}; - -// A PackedSplats is a collection of Gaussian splats, packed into a format that -// takes exactly 16 bytes per Gsplat to maximize memory and cache efficiency. -// The center xyz coordinates are encoded as float16 (3 x 2 bytes), scale xyz -// as 3 x uint8 that encode a log scale from e^-12 to e^9, rgba as 4 x uint8, -// and quaternion encoded via axis+angle using 2 x uint8 for octahedral encoding -// of the axis direction and a uint8 to encode rotation amount from 0..Pi. - -export class PackedSplats { - maxSplats = 0; - numSplats = 0; - packedArray: Uint32Array | null = null; - extra: Record; - splatEncoding?: SplatEncoding; - - initialized: Promise; - isInitialized = false; - - // Either target or source will be non-null, depending on whether the PackedSplats - // is being used as a data source or generated to. - target: THREE.WebGLArrayRenderTarget | null = null; - source: THREE.DataArrayTexture | null = null; - // Set to true if source packedArray is updated to have it upload to GPU - needsUpdate = true; - - // A PackedSplats can be used in a dyno graph using the below property dyno: - // const gsplat = dyno.readPackedSplats(this.dyno, dynoIndex); - dyno: DynoUniform; - dynoRgbMinMaxLnScaleMinMax: DynoUniform<"vec4", "rgbMinMaxLnScaleMinMax">; - dynoSh1MinMax: DynoUniform<"vec2", "sh1MinMax">; - dynoSh2MinMax: DynoUniform<"vec2", "sh2MinMax">; - dynoSh3MinMax: DynoUniform<"vec2", "sh3MinMax">; - - constructor(options: PackedSplatsOptions = {}) { - this.extra = {}; - this.dyno = new DynoPackedSplats({ packedSplats: this }); - this.dynoRgbMinMaxLnScaleMinMax = new DynoVec4({ - key: "rgbMinMaxLnScaleMinMax", - value: new THREE.Vector4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX), - update: (value) => { - value.set( - this.splatEncoding?.rgbMin ?? 0.0, - this.splatEncoding?.rgbMax ?? 1.0, - this.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, - this.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, - ); - return value; - }, - }); - this.dynoSh1MinMax = new DynoVec2({ - key: "sh1MinMax", - value: new THREE.Vector2(-1, 1), - update: (value) => { - value.set( - this.splatEncoding?.sh1Min ?? -1, - this.splatEncoding?.sh1Max ?? 1, - ); - return value; - }, - }); - this.dynoSh2MinMax = new DynoVec2({ - key: "sh2MinMax", - value: new THREE.Vector2(-1, 1), - update: (value) => { - value.set( - this.splatEncoding?.sh2Min ?? -1, - this.splatEncoding?.sh2Max ?? 1, - ); - return value; - }, - }); - this.dynoSh3MinMax = new DynoVec2({ - key: "sh3MinMax", - value: new THREE.Vector2(-1, 1), - update: (value) => { - value.set( - this.splatEncoding?.sh3Min ?? -1, - this.splatEncoding?.sh3Max ?? 1, - ); - return value; - }, - }); - - // The following line will be overridden by reinitialize() - this.initialized = Promise.resolve(this); - this.reinitialize(options); - } - - reinitialize(options: PackedSplatsOptions) { - this.isInitialized = false; - - this.extra = {}; - this.splatEncoding = options.splatEncoding; - - if (options.url || options.fileBytes || options.construct) { - // We need to initialize asynchronously given the options - this.initialized = this.asyncInitialize(options).then(() => { - this.isInitialized = true; - return this; - }); - } else { - this.initialize(options); - this.isInitialized = true; - this.initialized = Promise.resolve(this); - } - } - - initialize(options: PackedSplatsOptions) { - if (options.packedArray) { - this.packedArray = options.packedArray; - // Calculate number of horizontal texture rows that could fit in array. - // A properly initialized packedArray should already take into account the - // width and height of the texture and be rounded up with padding. - this.maxSplats = Math.floor(this.packedArray.length / 4); - this.maxSplats = - Math.floor(this.maxSplats / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; - this.numSplats = Math.min( - this.maxSplats, - options.numSplats ?? Number.POSITIVE_INFINITY, - ); - } else { - this.maxSplats = options.maxSplats ?? 0; - this.numSplats = 0; - } - this.extra = options.extra ?? {}; - } - - async asyncInitialize(options: PackedSplatsOptions) { - const { url, fileBytes, construct } = options; - if (url) { - const loader = new SplatLoader(); - loader.packedSplats = this; - await loader.loadAsync(url); - } else if (fileBytes) { - const unpacked = await unpackSplats({ - input: fileBytes, - fileType: options.fileType, - pathOrUrl: options.fileName ?? url, - splatEncoding: options.splatEncoding ?? DEFAULT_SPLAT_ENCODING, - }); - this.initialize(unpacked); - } - - if (construct) { - const maybePromise = construct(this); - // If construct returns a promise, wait for it to complete - if (maybePromise instanceof Promise) { - await maybePromise; - } - } - } - - // Call this when you are finished with the PackedSplats and want to free - // any buffers it holds. - dispose() { - if (this.target) { - this.target.dispose(); - this.target = null; - } - if (this.source) { - this.source.dispose(); - this.source = null; - } - } - - // Ensures that this.packedArray can fit numSplats Gsplats. If it's too small, - // resize exponentially and copy over the original data. - // - // Typically you don't need to call this, because calling this.setSplat(index, ...) - // and this.pushSplat(...) will automatically call ensureSplats() so we have - // enough splats. - ensureSplats(numSplats: number): Uint32Array { - const targetSize = - numSplats <= this.maxSplats - ? this.maxSplats - : // Grow exponentially to avoid frequent reallocations - Math.max(numSplats, 2 * this.maxSplats); - const currentSize = !this.packedArray ? 0 : this.packedArray.length / 4; - - if (!this.packedArray || targetSize > currentSize) { - this.maxSplats = getTextureSize(targetSize).maxSplats; - const newArray = new Uint32Array(this.maxSplats * 4); - if (this.packedArray) { - // Copy over existing data - newArray.set(this.packedArray); - } - this.packedArray = newArray; - } - return this.packedArray; - } - - // Ensure the extra array for the given level is large enough to hold numSplats - ensureSplatsSh(level: number, numSplats: number): Uint32Array { - let wordsPerSplat: number; - let key: string; - if (level === 0) { - return this.ensureSplats(numSplats); - } - if (level === 1) { - // 3 x 3 uint7 = 63 bits = 2 uint32 - wordsPerSplat = 2; - key = "sh1"; - } else if (level === 2) { - // 5 x 3 uint8 = 120 bits = 4 uint32 - wordsPerSplat = 4; - key = "sh2"; - } else if (level === 3) { - // 7 x 3 uint6 = 126 bits = 4 uint32 - wordsPerSplat = 4; - key = "sh3"; - } else { - throw new Error(`Invalid level: ${level}`); - } - - // Figure out our current and desired maxSplats - let maxSplats: number = !this.extra[key] - ? 0 - : (this.extra[key] as Uint32Array).length / wordsPerSplat; - const targetSize = - numSplats <= maxSplats ? maxSplats : Math.max(numSplats, 2 * maxSplats); - - if (!this.extra[key] || targetSize > maxSplats) { - // Reallocate the array - maxSplats = getTextureSize(targetSize).maxSplats; - const newArray = new Uint32Array(maxSplats * wordsPerSplat); - if (this.extra[key]) { - // Copy over existing data - newArray.set(this.extra[key] as Uint32Array); - } - this.extra[key] = newArray; - } - return this.extra[key] as Uint32Array; - } - - // Unpack the 16-byte Gsplat data at index into the Three.js components - // center: THREE.Vector3, scales: THREE.Vector3, quaternion: THREE.Quaternion, - // opacity: number 0..1, color: THREE.Color 0..1. - getSplat(index: number): { - center: THREE.Vector3; - scales: THREE.Vector3; - quaternion: THREE.Quaternion; - opacity: number; - color: THREE.Color; - } { - if (!this.packedArray || index >= this.numSplats) { - throw new Error("Invalid index"); - } - return unpackSplat(this.packedArray, index, this.splatEncoding); - } - - // Set all PackedSplat components at index with the provided Gsplat attributes - // (can be the same objects returned by getSplat). Ensures there is capacity - // for at least index+1 Gsplats. - setSplat( - index: number, - center: THREE.Vector3, - scales: THREE.Vector3, - quaternion: THREE.Quaternion, - opacity: number, - color: THREE.Color, - ) { - const packedSplats = this.ensureSplats(index + 1); - setPackedSplat( - packedSplats, - index, - center.x, - center.y, - center.z, - scales.x, - scales.y, - scales.z, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - opacity, - color.r, - color.g, - color.b, - ); - this.numSplats = Math.max(this.numSplats, index + 1); - } - - // Effectively calls this.setSplat(this.numSplats++, center, ...), useful on - // construction where you just want to iterate and create a collection of Gsplats. - pushSplat( - center: THREE.Vector3, - scales: THREE.Vector3, - quaternion: THREE.Quaternion, - opacity: number, - color: THREE.Color, - ) { - const packedSplats = this.ensureSplats(this.numSplats + 1); - setPackedSplat( - packedSplats, - this.numSplats, - center.x, - center.y, - center.z, - scales.x, - scales.y, - scales.z, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - opacity, - color.r, - color.g, - color.b, - ); - ++this.numSplats; - } - - // Iterate over Gsplats index 0..=(this.numSplats-1), unpack each Gsplat - // and invoke the callback function with the Gsplat attributes. - forEachSplat( - callback: ( - index: number, - center: THREE.Vector3, - scales: THREE.Vector3, - quaternion: THREE.Quaternion, - opacity: number, - color: THREE.Color, - ) => void, - ) { - if (!this.packedArray || !this.numSplats) { - return; - } - for (let i = 0; i < this.numSplats; ++i) { - const unpacked = unpackSplat(this.packedArray, i, this.splatEncoding); - callback( - i, - unpacked.center, - unpacked.scales, - unpacked.quaternion, - unpacked.opacity, - unpacked.color, - ); - } - } - - // Ensures our PackedSplats.target render target has enough space to generate - // maxSplats total Gsplats, and reallocate if not large enough. - ensureGenerate(maxSplats: number): boolean { - if (this.target && (maxSplats ?? 1) <= this.maxSplats) { - return false; - } - this.dispose(); - - const textureSize = getTextureSize(maxSplats ?? 1); - const { width, height, depth } = textureSize; - this.maxSplats = textureSize.maxSplats; - - // The packed Gsplats are stored in a 2D array texture of max size - // 2048 x 2048 x 2048, one RGBA32UI pixel = 4 uint32 = one Gsplat - this.target = new THREE.WebGLArrayRenderTarget(width, height, depth, { - depthBuffer: false, - stencilBuffer: false, - generateMipmaps: false, - magFilter: THREE.NearestFilter, - minFilter: THREE.NearestFilter, - }); - this.target.texture.format = THREE.RGBAIntegerFormat; - this.target.texture.type = THREE.UnsignedIntType; - this.target.texture.internalFormat = "RGBA32UI"; - this.target.scissorTest = true; - return true; - } - - // Given an array of splatCounts (.numSplats for each - // SplatGenerator/SplatMesh in the scene), compute a - // "mapping layout" in the composite array of generated outputs. - generateMapping(splatCounts: number[]): { - maxSplats: number; - mapping: { base: number; count: number }[]; - } { - let maxSplats = 0; - const mapping = splatCounts.map((numSplats) => { - const base = maxSplats; - // Generation happens in horizontal row chunks, so round up to full width - const rounded = Math.ceil(numSplats / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; - maxSplats += rounded; - return { base, count: numSplats }; - }); - return { maxSplats, mapping }; - } - - // Returns a THREE.DataArrayTexture representing the PackedSplats content as - // a Uint32x4 data array texture (2048 x 2048 x depth in size) - getTexture(): THREE.DataArrayTexture { - if (this.target) { - // Return the render target's texture - return this.target.texture; - } - if (this.source || this.packedArray) { - // Update source texture if needed and return - const source = this.maybeUpdateSource(); - return source; - } - - return PackedSplats.getEmpty(); - } - - // Check if source texture needs to be created/updated - private maybeUpdateSource(): THREE.DataArrayTexture { - if (!this.packedArray) { - throw new Error("No packed splats"); - } - - if (this.needsUpdate || !this.source) { - this.needsUpdate = false; - - if (this.source) { - const { width, height, depth } = this.source.image; - if (this.maxSplats !== width * height * depth) { - // The existing source texture isn't the right size, so dispose it - this.source.dispose(); - this.source = null; - } - } - if (!this.source) { - // Allocate a new source texture of the right size - const { width, height, depth } = getTextureSize(this.maxSplats); - this.source = new THREE.DataArrayTexture( - this.packedArray, - width, - height, - depth, - ); - this.source.format = THREE.RGBAIntegerFormat; - this.source.type = THREE.UnsignedIntType; - this.source.internalFormat = "RGBA32UI"; - this.source.needsUpdate = true; - } else if (this.packedArray.buffer !== this.source.image.data.buffer) { - // The source texture is the right size, update the data - this.source.image.data = new Uint8Array(this.packedArray.buffer); - } - // Indicate to Three.js that the source texture needs to be uploaded to the GPU - this.source.needsUpdate = true; - } - return this.source; - } - - private static emptySource: THREE.DataArrayTexture | null = null; - - // Can be used where you need an uninitialized THREE.DataArrayTexture like - // a uniform you will update with the result of this.getTexture() later. - static getEmpty(): THREE.DataArrayTexture { - if (!PackedSplats.emptySource) { - const { width, height, depth, maxSplats } = getTextureSize(1); - const emptyArray = new Uint32Array(maxSplats * 4); - PackedSplats.emptySource = new THREE.DataArrayTexture( - emptyArray, - width, - height, - depth, - ); - PackedSplats.emptySource.format = THREE.RGBAIntegerFormat; - PackedSplats.emptySource.type = THREE.UnsignedIntType; - PackedSplats.emptySource.internalFormat = "RGBA32UI"; - PackedSplats.emptySource.needsUpdate = true; - } - return PackedSplats.emptySource; - } - - // Get a program and THREE.RawShaderMaterial for a given GsplatGenerator, - // generating it if necessary and caching the result. - prepareProgramMaterial(generator: GsplatGenerator): { - program: DynoProgram; - material: THREE.RawShaderMaterial; - } { - let program = PackedSplats.generatorProgram.get(generator); - if (!program) { - // A Gsplat needs to be turned into a packed uvec4 for the dyno graph - const graph = dynoBlock( - { index: "int" }, - { output: "uvec4" }, - ({ index }) => { - generator.inputs.index = index; - const gsplat = generator.outputs.gsplat; - const output = outputPackedSplat( - gsplat, - this.dynoRgbMinMaxLnScaleMinMax, - ); - return { output }; - }, - ); - if (!PackedSplats.programTemplate) { - PackedSplats.programTemplate = new DynoProgramTemplate( - computeUvec4Template, - ); - } - // Create a program from the template and graph - program = new DynoProgram({ - graph, - inputs: { index: "index" }, - outputs: { output: "target" }, - template: PackedSplats.programTemplate, - }); - Object.assign(program.uniforms, { - targetLayer: { value: 0 }, - targetBase: { value: 0 }, - targetCount: { value: 0 }, - }); - PackedSplats.generatorProgram.set(generator, program); - } - - // Prepare and update our material we'll use to render the Gsplats - const material = program.prepareMaterial(); - PackedSplats.fullScreenQuad.material = material; - return { program, material }; - } - - private saveRenderState(renderer: THREE.WebGLRenderer) { - return { - xrEnabled: renderer.xr.enabled, - autoClear: renderer.autoClear, - }; - } - - private resetRenderState( - renderer: THREE.WebGLRenderer, - state: { - xrEnabled: boolean; - autoClear: boolean; - }, - ) { - renderer.setRenderTarget(null); - renderer.xr.enabled = state.xrEnabled; - renderer.autoClear = state.autoClear; - } - - // Executes a dyno program specified by generator which is any DynoBlock that - // maps { index: "int" } to { gsplat: Gsplat }. This is called in - // SparkRenderer.updateInternal() to re-generate Gsplats in the scene for - // SplatGenerator instances whose version is newer than what was generated - // for it last time. - generate({ - generator, - base, - count, - renderer, - }: { - generator: GsplatGenerator; - base: number; - count: number; - renderer: THREE.WebGLRenderer; - }): { nextBase: number } { - if (!this.target) { - throw new Error("Target must be initialized with ensureSplats"); - } - if (base + count > this.maxSplats) { - throw new Error("Base + count exceeds maxSplats"); - } - - const { program, material } = this.prepareProgramMaterial(generator); - program.update(); - - const renderState = this.saveRenderState(renderer); - - // Generate the Gsplats in "layer" chunks, in horizontal row ranges, - // that cover the total count of Gsplats. - const nextBase = - Math.ceil((base + count) / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; - const layerSize = SPLAT_TEX_WIDTH * SPLAT_TEX_HEIGHT; - material.uniforms.targetBase.value = base; - material.uniforms.targetCount.value = count; - - // Keep generating layers until we've reached the next generation's base - while (base < nextBase) { - const layer = Math.floor(base / layerSize); - material.uniforms.targetLayer.value = layer; - - const layerBase = layer * layerSize; - const layerYStart = Math.floor((base - layerBase) / SPLAT_TEX_WIDTH); - const layerYEnd = Math.min( - SPLAT_TEX_HEIGHT, - Math.ceil((nextBase - layerBase) / SPLAT_TEX_WIDTH), - ); - - // Render the desired portion of the layer - this.target.scissor.set( - 0, - layerYStart, - SPLAT_TEX_WIDTH, - layerYEnd - layerYStart, - ); - renderer.setRenderTarget(this.target, layer); - renderer.xr.enabled = false; - renderer.autoClear = false; - PackedSplats.fullScreenQuad.render(renderer); - - base += SPLAT_TEX_WIDTH * (layerYEnd - layerYStart); - } - - this.resetRenderState(renderer, renderState); - return { nextBase }; - } - - static programTemplate: DynoProgramTemplate | null = null; - - // Cache for GsplatGenerator programs - static generatorProgram = new Map(); - - // Static full-screen quad for pseudo-compute shader rendering - static fullScreenQuad = new FullScreenQuad( - new THREE.RawShaderMaterial({ visible: false }), - ); -} - -// You can use a PackedSplats as a dyno block using the function -// dyno.readPackedSplats(packedSplats.dyno, dynoIndex) where -// dynoIndex is of type DynoVal<"int">. If you need to be able to change -// the input PackedSplats dynamically, however, you should create a -// DynoPackedSplats, whose property packedSplats you can change to any -// PackedSplats and that will be used in the dyno shader program. - -export const dynoPackedSplats = (packedSplats?: PackedSplats) => - new DynoPackedSplats({ packedSplats }); - -export class DynoPackedSplats extends DynoUniform< - typeof TPackedSplats, - "packedSplats", - { - texture: THREE.DataArrayTexture; - numSplats: number; - rgbMinMaxLnScaleMinMax: THREE.Vector4; - } -> { - packedSplats?: PackedSplats; - - constructor({ packedSplats }: { packedSplats?: PackedSplats } = {}) { - super({ - key: "packedSplats", - type: TPackedSplats, - globals: () => [definePackedSplats], - value: { - texture: PackedSplats.getEmpty(), - numSplats: 0, - rgbMinMaxLnScaleMinMax: new THREE.Vector4( - 0, - 1, - LN_SCALE_MIN, - LN_SCALE_MAX, - ), - }, - update: (value) => { - value.texture = - this.packedSplats?.getTexture() ?? PackedSplats.getEmpty(); - value.numSplats = this.packedSplats?.numSplats ?? 0; - value.rgbMinMaxLnScaleMinMax.set( - this.packedSplats?.splatEncoding?.rgbMin ?? 0, - this.packedSplats?.splatEncoding?.rgbMax ?? 1, - this.packedSplats?.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, - this.packedSplats?.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, - ); - return value; - }, - }); - this.packedSplats = packedSplats; - } -} diff --git a/src/Readback.ts b/src/Readback.ts deleted file mode 100644 index a7a7578..0000000 --- a/src/Readback.ts +++ /dev/null @@ -1,338 +0,0 @@ -import * as THREE from "three"; -import { FullScreenQuad } from "three/addons/postprocessing/Pass.js"; - -import { SPLAT_TEX_HEIGHT, SPLAT_TEX_WIDTH } from "./defines"; -import { type Dyno, OutputRgba8, dynoBlock } from "./dyno"; -import { DynoProgram, DynoProgramTemplate } from "./dyno/program"; -import computeVec4Template from "./shaders/computeVec4.glsl"; -import { getTextureSize } from "./utils"; - -// Readback can be used to run a Dyno program that maps an index to a 32-bit -// RGBA8 value, which is the only allowed, portable readback format for WebGL2. -// Using data packing and conversion you can read back any 32-bit value, which -// Spark uses to read back 2 float16 Gsplat distance values per index. - -export type Rgba8Readback = Dyno<{ index: "int" }, { rgba8: "vec4" }>; - -// Readback can be performed with various typed buffers, making it convenient -// to encode readback data in a variety of formats. - -export type ReadbackBuffer = - | ArrayBuffer - | Uint8Array - | Int8Array - | Uint16Array - | Int16Array - | Uint32Array - | Int32Array - | Float32Array; - -export class Readback { - renderer?: THREE.WebGLRenderer; - target?: THREE.WebGLArrayRenderTarget; - capacity: number; - count: number; - - constructor({ renderer }: { renderer?: THREE.WebGLRenderer } = {}) { - this.renderer = renderer; - this.capacity = 0; - this.count = 0; - } - - dispose() { - if (this.target) { - this.target.dispose(); - this.target = undefined; - } - } - - // Ensure we have a buffer large enough for the readback of count indices. - // Pass in previous bufer of the desired type. - ensureBuffer(count: number, buffer: B): B { - // Readback is performed in a 2D array of pixels, so round up with SPLAT_TEX_WIDTH - const roundedCount = - Math.ceil(Math.max(1, count) / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; - const bytes = roundedCount * 4; - if (buffer.byteLength >= bytes) { - return buffer; - } - - // Need a larger buffer, create a new one of the same type - const newBuffer = new ArrayBuffer(bytes); - if (buffer instanceof ArrayBuffer) { - return newBuffer as B; - } - - const ctor = buffer.constructor as { new (arrayBuffer: ArrayBuffer): B }; - return new ctor(newBuffer) as B; - } - - // Ensure our render target is large enough for the readback of capacity indices. - ensureCapacity(capacity: number) { - const { width, height, depth, maxSplats } = getTextureSize(capacity); - if (!this.target || maxSplats > this.capacity) { - this.dispose(); - this.capacity = maxSplats; - - // The only portable readback format for WebGL2 is RGBA8 - this.target = new THREE.WebGLArrayRenderTarget(width, height, depth, { - depthBuffer: false, - stencilBuffer: false, - generateMipmaps: false, - magFilter: THREE.NearestFilter, - minFilter: THREE.NearestFilter, - }); - this.target.texture.format = THREE.RGBAFormat; - this.target.texture.type = THREE.UnsignedByteType; - this.target.texture.internalFormat = "RGBA8"; - this.target.scissorTest = true; - } - } - - // Get a program and THREE.RawShaderMaterial for a given Rgba8Readback, - // generating it if necessary and caching the result. - prepareProgramMaterial(reader: Rgba8Readback): { - program: DynoProgram; - material: THREE.RawShaderMaterial; - } { - let program = Readback.readbackProgram.get(reader); - if (!program) { - const graph = dynoBlock( - { index: "int" }, - { rgba8: "vec4" }, - ({ index }) => { - reader.inputs.index = index; - const rgba8 = new OutputRgba8({ rgba8: reader.outputs.rgba8 }); - return { rgba8 }; - }, - ); - if (!Readback.programTemplate) { - Readback.programTemplate = new DynoProgramTemplate(computeVec4Template); - } - // Create a program from the template and graph - program = new DynoProgram({ - graph, - inputs: { index: "index" }, - outputs: { rgba8: "target" }, - template: Readback.programTemplate, - }); - Object.assign(program.uniforms, { - targetLayer: { value: 0 }, - targetBase: { value: 0 }, - targetCount: { value: 0 }, - }); - Readback.readbackProgram.set(reader, program); - } - - const material = program.prepareMaterial(); - Readback.fullScreenQuad.material = material; - return { program, material }; - } - - private saveRenderState(renderer: THREE.WebGLRenderer) { - return { - xrEnabled: renderer.xr.enabled, - autoClear: renderer.autoClear, - }; - } - - private resetRenderState( - renderer: THREE.WebGLRenderer, - state: { - xrEnabled: boolean; - autoClear: boolean; - }, - ) { - renderer.setRenderTarget(null); - renderer.xr.enabled = state.xrEnabled; - renderer.autoClear = state.autoClear; - } - - private process({ - count, - material, - }: { count: number; material: THREE.RawShaderMaterial }) { - const renderer = this.renderer; - if (!renderer) { - throw new Error("No renderer"); - } - if (!this.target) { - throw new Error("No target"); - } - - // Run the program in "layer" chunks, in horizontal row ranges, - // that cover the total count of indices. - const layerSize = SPLAT_TEX_WIDTH * SPLAT_TEX_HEIGHT; - material.uniforms.targetBase.value = 0; - material.uniforms.targetCount.value = count; - let baseIndex = 0; - - // Keep generating layers until completed count items - while (baseIndex < count) { - const layer = Math.floor(baseIndex / layerSize); - const layerBase = layer * layerSize; - const layerYEnd = Math.min( - SPLAT_TEX_HEIGHT, - Math.ceil((count - layerBase) / SPLAT_TEX_WIDTH), - ); - material.uniforms.targetLayer.value = layer; - - // Render the desired portion of the layer - this.target.scissor.set(0, 0, SPLAT_TEX_WIDTH, layerYEnd); - renderer.setRenderTarget(this.target, layer); - renderer.xr.enabled = false; - renderer.autoClear = false; - Readback.fullScreenQuad.render(renderer); - - baseIndex += SPLAT_TEX_WIDTH * layerYEnd; - } - - this.count = count; - } - - private async read({ - readback, - }: { readback: B }): Promise { - const renderer = this.renderer; - if (!renderer) { - throw new Error("No renderer"); - } - if (!this.target) { - throw new Error("No target"); - } - - const roundedCount = - Math.ceil(this.count / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; - if (readback.byteLength < roundedCount * 4) { - throw new Error( - `Readback buffer too small: ${readback.byteLength} < ${roundedCount * 4}`, - ); - } - const readbackUint8 = new Uint8Array( - readback instanceof ArrayBuffer ? readback : readback.buffer, - ); - - // We can only read back one 2D array layer of pixels at a time, - // so loop through them, initiate the readback, and collect the - // completion promises. - - const layerSize = SPLAT_TEX_WIDTH * SPLAT_TEX_HEIGHT; - let baseIndex = 0; - const promises = []; - - while (baseIndex < this.count) { - const layer = Math.floor(baseIndex / layerSize); - const layerBase = layer * layerSize; - const layerYEnd = Math.min( - SPLAT_TEX_HEIGHT, - Math.ceil((this.count - layerBase) / SPLAT_TEX_WIDTH), - ); - - renderer.setRenderTarget(this.target, layer); - - // Compute the subarray that this layer of readback corresponds to - const readbackSize = SPLAT_TEX_WIDTH * layerYEnd * 4; - const subReadback = readbackUint8.subarray( - layerBase * 4, - layerBase * 4 + readbackSize, - ); - const promise = renderer?.readRenderTargetPixelsAsync( - this.target, - 0, - 0, - SPLAT_TEX_WIDTH, - layerYEnd, - subReadback, - ); - promises.push(promise); - - baseIndex += SPLAT_TEX_WIDTH * layerYEnd; - } - return Promise.all(promises).then(() => readback); - } - - // Perform render operation to run the Rgba8Readback program - // but don't perform the readback yet. - render({ - reader, - count, - renderer, - }: { reader: Rgba8Readback; count: number; renderer?: THREE.WebGLRenderer }) { - this.renderer = renderer || this.renderer; - if (!this.renderer) { - throw new Error("No renderer"); - } - - this.ensureCapacity(count); - - const { program, material } = this.prepareProgramMaterial(reader); - program.update(); - - const renderState = this.saveRenderState(this.renderer); - this.process({ count, material }); - this.resetRenderState(this.renderer, renderState); - } - - // Perform a readback of the render target, returning a buffer of the - // given type. - async readback({ - readback, - }: { readback: B }): Promise { - if (!this.renderer) { - throw new Error("No renderer"); - } - const renderState = this.saveRenderState(this.renderer); - const promise = this.read({ readback }); - this.resetRenderState(this.renderer, renderState); - return promise; - } - - // Perform a render and readback operation for the given Rgba8Readback, - // and readback buffer (call ensureBuffer first). - async renderReadback({ - reader, - count, - renderer, - readback, - }: { - reader: Rgba8Readback; - count: number; - renderer?: THREE.WebGLRenderer; - readback: B; - }): Promise { - this.renderer = renderer || this.renderer; - if (!this.renderer) { - throw new Error("No renderer"); - } - - this.ensureCapacity(count); - - const { program, material } = this.prepareProgramMaterial(reader); - program.update(); - - const renderState = this.saveRenderState(this.renderer); - - // Generate output - this.process({ count, material }); - - // Initiate readback - const promise = this.read({ readback }); - - this.resetRenderState(this.renderer, renderState); - return promise; - } - - getTexture(): THREE.DataArrayTexture | undefined { - return this.target?.texture; - } - - static programTemplate: DynoProgramTemplate | null = null; - - // Cache for Rgba8Readback programs - static readbackProgram = new Map(); - - // Static full-screen quad for pseudo-compute shader rendering - static fullScreenQuad = new FullScreenQuad( - new THREE.RawShaderMaterial({ visible: false }), - ); -} diff --git a/src/RgbaArray.ts b/src/RgbaArray.ts deleted file mode 100644 index 49446df..0000000 --- a/src/RgbaArray.ts +++ /dev/null @@ -1,283 +0,0 @@ -import * as THREE from "three"; - -import { DynoPackedSplats, type PackedSplats } from "./PackedSplats"; -import { Readback, type Rgba8Readback } from "./Readback"; -import { SPLAT_TEX_WIDTH } from "./defines"; -import { - Dyno, - type DynoBlock, - DynoInt, - DynoUniform, - type DynoVal, - add, - dynoBlock, - readPackedSplatRange, - splitGsplat, - unindent, - unindentLines, -} from "./dyno"; -import { getTextureSize } from "./utils"; - -// An RgbaArray is a collection of ordered RGBA8 values, which can be used as a dyno -// data source, for example for recoloring Gsplats via SplatMesh.splatRgba. -// It can be instantiated from a Uint8Array of RGBA8 values, or it can be -// generated using a Rgba8Readback dyno program. - -export type RgbaArrayOptions = { - // Reserve space for at least this many RGBA values. - capacity?: number; - // Use the provided array of RGBA8 values as the source. - array?: Uint8Array; - // The number of actual RGBA8 values in the array. - count?: number; -}; - -export class RgbaArray { - capacity = 0; - count = 0; - array: Uint8Array | null = null; - - readback: Readback | null = null; - source: THREE.DataArrayTexture | null = null; - // Set to true if source array is updated to have it upload to GPU - needsUpdate = true; - - // Use this as a TRgbaArray in a dyno graph - dyno: DynoUniform; - - constructor(options: RgbaArrayOptions = {}) { - this.dyno = new DynoUniform({ - key: "rgbaArray", - type: TRgbaArray, - globals: () => [defineRgbaArray], - value: { - texture: RgbaArray.getEmpty(), - count: 0, - }, - update: (value) => { - value.texture = - this.readback?.getTexture() ?? this.source ?? RgbaArray.getEmpty(); - value.count = this.count; - return value; - }, - }); - - if (options.array) { - // Initialize with given array - this.array = options.array; - this.capacity = Math.floor(this.array.length / 4); - this.capacity = - Math.floor(this.capacity / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; - this.count = Math.min( - this.capacity, - options.count ?? Number.POSITIVE_INFINITY, - ); - } else { - this.capacity = options.capacity ?? 0; - this.count = 0; - } - } - - // Free up resources - dispose() { - if (this.readback) { - this.readback.dispose(); - this.readback = null; - } - if (this.source) { - this.source.dispose(); - this.source = null; - } - } - - // Ensure that our array is large enough to hold capacity RGBA8 values. - ensureCapacity(capacity: number): Uint8Array { - if (!this.array || capacity > (this.array?.length ?? 0) / 4) { - this.capacity = getTextureSize(capacity).maxSplats; - const newArray = new Uint8Array(this.capacity * 4); - if (this.array) { - // Copy over existing data - newArray.set(this.array); - } - this.array = newArray; - } - return this.array; - } - - // Get the THREE.DataArrayTexture from either the readback or the source. - getTexture(): THREE.DataArrayTexture { - let texture = this.readback?.getTexture(); - if (this.source || this.array) { - texture = this.maybeUpdateSource(); - } - return texture ?? RgbaArray.getEmpty(); - } - - // Create or get a THREE.DataArrayTexture from the data array. - private maybeUpdateSource(): THREE.DataArrayTexture { - if (!this.array) { - throw new Error("No array"); - } - - if (this.needsUpdate || !this.source) { - this.needsUpdate = false; - - if (this.source) { - const { width, height, depth } = this.source.image; - if (this.capacity !== width * height * depth) { - this.source.dispose(); - this.source = null; - } - } - if (!this.source) { - const { width, height, depth } = getTextureSize(this.capacity); - this.source = new THREE.DataArrayTexture( - this.array, - width, - height, - depth, - ); - this.source.format = THREE.RGBAFormat; - this.source.type = THREE.UnsignedByteType; - this.source.internalFormat = "RGBA8"; - this.source.needsUpdate = true; - } else if (this.array.buffer !== this.source.image.data.buffer) { - this.source.image.data = new Uint8Array(this.array.buffer); - } - this.source.needsUpdate = true; - } - return this.source; - } - - // Generate the RGBA8 values from a Rgba8Readback dyno program. - render({ - reader, - count, - renderer, - }: { reader: Rgba8Readback; count: number; renderer: THREE.WebGLRenderer }) { - if (!this.readback) { - this.readback = new Readback({ renderer }); - } - this.readback.render({ reader, count, renderer }); - this.capacity = this.readback.capacity; - this.count = this.readback.count; - } - - // Extract the RGBA8 values from a PackedSplats collection. - fromPackedSplats({ - packedSplats, - base, - count, - renderer, - }: { - packedSplats: PackedSplats; - base: number; - count: number; - renderer: THREE.WebGLRenderer; - }) { - const { dynoSplats, dynoBase, dynoCount, reader } = RgbaArray.makeDynos(); - dynoSplats.packedSplats = packedSplats; - dynoBase.value = base; - dynoCount.value = count; - this.render({ reader, count, renderer }); - return this; - } - - // Read back the RGBA8 values from the readback buffer. - async read(): Promise { - if (!this.readback) { - throw new Error("No readback"); - } - if (!this.array || this.array.length < this.count * 4) { - this.array = new Uint8Array(this.capacity * 4); - } - const result = await this.readback.readback({ readback: this.array }); - return result.subarray(0, this.count * 4); - } - - private static emptySource: THREE.DataArrayTexture | null = null; - - // Can be used where you need an uninitialized THREE.DataArrayTexture like - // a uniform you will update with the result of this.getTexture() later. - static getEmpty(): THREE.DataArrayTexture { - if (!RgbaArray.emptySource) { - const emptyArray = new Uint8Array(1 * 4); - RgbaArray.emptySource = new THREE.DataArrayTexture(emptyArray, 1, 1, 1); - RgbaArray.emptySource.format = THREE.RGBAFormat; - RgbaArray.emptySource.type = THREE.UnsignedByteType; - RgbaArray.emptySource.internalFormat = "RGBA8"; - RgbaArray.emptySource.needsUpdate = true; - } - return RgbaArray.emptySource; - } - - private static dynos: { - dynoSplats: DynoPackedSplats; - dynoBase: DynoInt; - dynoCount: DynoInt; - reader: DynoBlock<{ index: "int" }, { rgba8: "vec4" }>; - } | null = null; - - // Create a dyno program that can extract RGBA8 values from a PackedSplats - private static makeDynos() { - if (!RgbaArray.dynos) { - const dynoSplats = new DynoPackedSplats(); - const dynoBase = new DynoInt({ value: 0 }); - const dynoCount = new DynoInt({ value: 0 }); - const reader = dynoBlock( - { index: "int" }, - { rgba8: "vec4" }, - ({ index }) => { - if (!index) { - throw new Error("index is undefined"); - } - index = add(index, dynoBase); - const gsplat = readPackedSplatRange( - dynoSplats, - index, - dynoBase, - dynoCount, - ); - return { rgba8: splitGsplat(gsplat).outputs.rgba }; - }, - ); - RgbaArray.dynos = { dynoSplats, dynoBase, dynoCount, reader }; - } - return RgbaArray.dynos; - } -} - -// Dyno types and definitions - -export const TRgbaArray = { type: "RgbaArray" } as { type: "RgbaArray" }; - -export const defineRgbaArray = unindent(` - struct RgbaArray { - sampler2DArray texture; - int count; - }; -`); - -export function readRgbaArray( - rgba: DynoVal, - index: DynoVal<"int">, -): DynoVal<"vec4"> { - const dyno = new Dyno< - { rgba: typeof TRgbaArray; index: "int" }, - { rgba: "vec4" } - >({ - inTypes: { rgba: TRgbaArray, index: "int" }, - outTypes: { rgba: "vec4" }, - inputs: { rgba, index }, - globals: () => [defineRgbaArray], - statements: ({ inputs, outputs }) => - unindentLines(` - if ((index >= 0) && (index < ${inputs.rgba}.count)) { - ${outputs.rgba} = texelFetch(${inputs.rgba}.texture, splatTexCoord(index), 0); - } else { - ${outputs.rgba} = vec4(0.0, 0.0, 0.0, 0.0); - } - `), - }); - return dyno.outputs.rgba; -} diff --git a/src/SparkRenderer.ts b/src/SparkRenderer.ts deleted file mode 100644 index 3b6368b..0000000 --- a/src/SparkRenderer.ts +++ /dev/null @@ -1,1073 +0,0 @@ -import * as THREE from "three"; - -import { - DEFAULT_SPLAT_ENCODING, - PackedSplats, - type SplatEncoding, -} from "./PackedSplats"; -import { RgbaArray } from "./RgbaArray"; -import { SparkViewpoint, type SparkViewpointOptions } from "./SparkViewpoint"; -import { type GeneratorMapping, SplatAccumulator } from "./SplatAccumulator"; -import { SplatEdit } from "./SplatEdit"; -import { SplatGenerator, SplatModifier } from "./SplatGenerator"; -import { SplatGeometry } from "./SplatGeometry"; -import { SplatMesh } from "./SplatMesh"; -import { LN_SCALE_MAX, LN_SCALE_MIN } from "./defines"; -import { - DynoVec3, - DynoVec4, - Gsplat, - TPackedSplats, - dynoBlock, - readPackedSplat, - transformGsplat, -} from "./dyno"; -import { getShaders } from "./shaders"; -import { - averagePositions, - averageQuaternions, - cloneClock, - withinCoorientDist, -} from "./utils"; - -// SparkRenderer aggregates splats from multiple generators into a single -// accumulated collection per frame. In normal operation we only need a -// maximum of 3 accumulators: One currently being viewed, one currently -// being sorted, and one more for generating the next frame. Accumulators -// must be "released" by each viewpoint using it, so in unusual cases -// such as slow render-outs, we may want to allow more than 3 so the -// pipeline can continue generating new frames, but we limit to a maximum -// of 5 to avoid excessive memory usage. -const MAX_ACCUMULATORS = 5; - -export type SparkRendererOptions = { - /** - * Pass in your THREE.WebGLRenderer instance so Spark can perform work - * outside the usual render loop. Should be created with antialias: false - * (default setting) as WebGL anti-aliasing doesn't improve Gaussian Splatting - * rendering and significantly reduces performance. - */ - renderer: THREE.WebGLRenderer; - /** - * Whether to use premultiplied alpha when accumulating splat RGB - * @default true - */ - premultipliedAlpha?: boolean; - /** - * Pass in a THREE.Clock to synchronize time-based effects across different - * systems. Alternatively, you can set the SparkRenderer properties time and - * deltaTime directly. (default: new THREE.Clock) - */ - clock?: THREE.Clock; - /** - * Controls whether to check and automatically update Gsplat collection after - * each frame render. - * @default true - */ - autoUpdate?: boolean; - /** - * Controls whether to update the Gsplats before or after rendering. For WebXR - * this must be false in order to complete rendering as soon as possible. - * @default false - */ - preUpdate?: boolean; - /** - * Distance threshold for SparkRenderer movement triggering a Gsplat update at - * the new origin. - * @default 1.0 - */ - originDistance?: number; - /** - * Maximum standard deviations from the center to render Gaussians. Values - * Math.sqrt(5)..Math.sqrt(8) produce good results and can be tweaked for - * performance. - * @default Math.sqrt(8) - */ - maxStdDev?: number; - /** - * Minimum pixel radius for splat rendering. - * @default 0.0 - */ - minPixelRadius?: number; - /** - * Maximum pixel radius for splat rendering. - * @default 512.0 - */ - maxPixelRadius?: number; - /** - * Minimum alpha value for splat rendering. - * @default 0.5 * (1.0 / 255.0) - */ - minAlpha?: number; - /** - * Enable 2D Gaussian splatting rendering ability. When this mode is enabled, - * any scale x/y/z component that is exactly 0 (minimum quantized value) results - * in the other two non-0 axis being interpreted as an oriented 2D Gaussian Splat, - * rather instead of the usual projected 3DGS Z-slice. When reading PLY files, - * scale values less than e^-30 will be interpreted as 0. - * @default false - */ - enable2DGS?: boolean; - /** - * Scalar value to add to 2D splat covariance diagonal, effectively blurring + - * enlarging splats. In scenes trained without the Gsplat anti-aliasing tweak - * this value was typically 0.3, but with anti-aliasing it is 0.0 - * @default 0.0 - */ - preBlurAmount?: number; - /** - * Scalar value to add to 2D splat covarianve diagonal, with opacity adjustment - * to correctly account for "blurring" when anti-aliasing. Typically 0.3 - * (equivalent to approx 0.5 pixel radius) in scenes trained with anti-aliasing. - */ - blurAmount?: number; - /** - * Depth-of-field distance to focal plane - */ - focalDistance?: number; - /** - * Full-width angle of aperture opening (in radians), 0.0 to disable - * @default 0.0 - */ - apertureAngle?: number; - /** - * Modulate Gaussian kernel falloff. 0 means "no falloff, flat shading", - * while 1 is the normal Gaussian kernel. - * @default 1.0 - */ - falloff?: number; - /** - * X/Y clipping boundary factor for Gsplat centers against view frustum. - * 1.0 clips any centers that are exactly out of bounds, while 1.4 clips - * centers that are 40% beyond the bounds. - * @default 1.4 - */ - clipXY?: number; - /** - * Parameter to adjust projected splat scale calculation to match other renderers, - * similar to the same parameter in the MKellogg 3DGS renderer. Higher values will - * tend to sharpen the splats. A value 2.0 can be used to match the behavior of - * the PlayCanvas renderer. - * @default 1.0 - */ - focalAdjustment?: number; - /** - * Configures the SparkViewpointOptions for the default SparkViewpoint - * associated with this SparkRenderer. Notable option: sortRadial (sort by - * radial distance or Z-depth) - */ - view?: SparkViewpointOptions; - /** - * Override the default splat encoding ranges for the PackedSplats. - * (default: undefined) - */ - splatEncoding?: SplatEncoding; -}; - -export class SparkRenderer extends THREE.Mesh { - renderer: THREE.WebGLRenderer; - premultipliedAlpha: boolean; - material: THREE.ShaderMaterial; - uniforms: ReturnType; - - autoUpdate: boolean; - preUpdate: boolean; - needsUpdate: boolean; - originDistance: number; - maxStdDev: number; - minPixelRadius: number; - maxPixelRadius: number; - minAlpha: number; - enable2DGS: boolean; - preBlurAmount: number; - blurAmount: number; - focalDistance: number; - apertureAngle: number; - falloff: number; - clipXY: number; - focalAdjustment: number; - splatEncoding: SplatEncoding; - - splatTexture: null | { - enable?: boolean; - texture?: THREE.Data3DTexture; - multiply?: THREE.Matrix2; - add?: THREE.Vector2; - near?: number; - far?: number; - mid?: number; - } = null; - - time?: number; - deltaTime?: number; - clock: THREE.Clock; - - // Latest Gsplat collection being displayed - active: SplatAccumulator; - // Free list of accumulators for reuse - private freeAccumulators: SplatAccumulator[]; - // Total number of accumulators currently allocated - private accumulatorCount: number; - // Default SparkViewpoint used for rendering to the canvas - defaultView: SparkViewpoint; - // List of SparkViewpoints with autoUpdate enabled - autoViewpoints: SparkViewpoint[] = []; - - // Dynos used to transform Gsplats to the accumulator coordinate system - private rotateToAccumulator = new DynoVec4({ value: new THREE.Quaternion() }); - private translateToAccumulator = new DynoVec3({ value: new THREE.Vector3() }); - private modifier: SplatModifier; - - // Last rendered frame number so we know when we're rendering a new frame - private lastFrame = -1; - // Last update timestamp to compute deltaTime - private lastUpdateTime: number | null = null; - // List of cameras used for the current viewpoint (for WebXR) - private defaultCameras: THREE.Matrix4[] = []; - private lastStochastic: boolean | null = null; - - // Should be set to the defaultView, but can be temporarily changed to another - // viewpoint using prepareViewpoint() for rendering from a different viewpoint. - viewpoint: SparkViewpoint; - - // Holds data needed to perform a scheduled Gsplat update. - private pendingUpdate = { - scene: null as THREE.Scene | null, - originToWorld: new THREE.Matrix4(), - timeoutId: -1, - }; - - // Internal SparkViewpoint used for environment map rendering. - private envViewpoint: SparkViewpoint | null = null; - - // Data and buffers used for environment map rendering - private static cubeRender: { - target: THREE.WebGLCubeRenderTarget; - camera: THREE.CubeCamera; - near: number; - far: number; - } | null = null; - private static pmrem: THREE.PMREMGenerator | null = null; - - static EMPTY_SPLAT_TEXTURE = new THREE.Data3DTexture(); - - constructor(options: SparkRendererOptions) { - const uniforms = SparkRenderer.makeUniforms(); - const shaders = getShaders(); - const premultipliedAlpha = options.premultipliedAlpha ?? true; - const material = new THREE.ShaderMaterial({ - glslVersion: THREE.GLSL3, - vertexShader: shaders.splatVertex, - fragmentShader: shaders.splatFragment, - uniforms, - premultipliedAlpha, - transparent: true, - depthTest: true, - depthWrite: false, - side: THREE.DoubleSide, - }); - - super(EMPTY_GEOMETRY, material); - // Disable frustum culling because we want to always draw them all - // and cull Gsplats individually in the shader - this.frustumCulled = false; - - this.renderer = options.renderer; - this.material = material; - this.uniforms = uniforms; - - // Create a Gsplat modifier that takes the output of any SplatGenerator - // and transforms them into the accumulator's coordinate system - const modifier = dynoBlock( - { gsplat: Gsplat }, - { gsplat: Gsplat }, - ({ gsplat }) => { - if (!gsplat) { - throw new Error("gsplat not defined"); - } - gsplat = transformGsplat(gsplat, { - rotate: this.rotateToAccumulator, - translate: this.translateToAccumulator, - }); - return { gsplat }; - }, - ); - this.modifier = new SplatModifier(modifier); - - this.premultipliedAlpha = premultipliedAlpha; - this.autoUpdate = options.autoUpdate ?? true; - this.preUpdate = options.preUpdate ?? false; - this.needsUpdate = false; - this.originDistance = options.originDistance ?? 1; - this.maxStdDev = options.maxStdDev ?? Math.sqrt(8.0); - this.minPixelRadius = options.minPixelRadius ?? 0.0; - this.maxPixelRadius = options.maxPixelRadius ?? 512.0; - this.minAlpha = options.minAlpha ?? 0.5 * (1.0 / 255.0); - this.enable2DGS = options.enable2DGS ?? false; - this.preBlurAmount = options.preBlurAmount ?? 0.0; - this.blurAmount = options.blurAmount ?? 0.3; - this.focalDistance = options.focalDistance ?? 0.0; - this.apertureAngle = options.apertureAngle ?? 0.0; - this.falloff = options.falloff ?? 1.0; - this.clipXY = options.clipXY ?? 1.4; - this.focalAdjustment = options.focalAdjustment ?? 1.0; - this.splatEncoding = options.splatEncoding ?? { ...DEFAULT_SPLAT_ENCODING }; - - this.active = new SplatAccumulator(); - this.active.refCount = 1; - this.accumulatorCount = 1; - this.freeAccumulators = []; - // Start with the minimum of 2 total accumulators - for (let count = 0; count < 1; ++count) { - this.freeAccumulators.push(new SplatAccumulator()); - this.accumulatorCount += 1; - } - - // Create a default SparkViewpoint that is used when we call render() - // on the scene and has the sorted Gsplat collection from that viewpoint. - this.defaultView = new SparkViewpoint({ - ...options.view, - autoUpdate: true, - spark: this, - }); - this.viewpoint = this.defaultView; - this.prepareViewpoint(this.viewpoint); - - this.clock = options.clock ? cloneClock(options.clock) : new THREE.Clock(); - } - - static makeUniforms() { - // Create uniforms used for Gsplat vertex and fragment shaders - const uniforms = { - // Size of render viewport in pixels - renderSize: { value: new THREE.Vector2() }, - // Near and far plane distances - near: { value: 0.1 }, - far: { value: 1000.0 }, - // Total number of Gsplats in packedSplats to render - numSplats: { value: 0 }, - // SplatAccumulator to view transformation quaternion - renderToViewQuat: { value: new THREE.Quaternion() }, - // SplatAccumulator to view transformation translation - renderToViewPos: { value: new THREE.Vector3() }, - // Maximum distance (in stddevs) from Gsplat center to render - maxStdDev: { value: 1.0 }, - // Minimum pixel radius for splat rendering - minPixelRadius: { value: 0.0 }, - // Maximum pixel radius for splat rendering - maxPixelRadius: { value: 512.0 }, - // Minimum alpha value for splat rendering - minAlpha: { value: 0.5 * (1.0 / 255.0) }, - // Enable stochastic splat rendering - stochastic: { value: false }, - // Enable interpreting 0-thickness Gsplats as 2DGS - enable2DGS: { value: false }, - // Add to projected 2D splat covariance diagonal (thickens and brightens) - preBlurAmount: { value: 0.0 }, - // Add to 2D splat covariance diagonal and adjust opacity (anti-aliasing) - blurAmount: { value: 0.3 }, - // Depth-of-field distance to focal plane - focalDistance: { value: 0.0 }, - // Full-width angle of aperture opening (in radians) - apertureAngle: { value: 0.0 }, - // Modulate Gaussian kernal falloff. 0 means "no falloff, flat shading", - // 1 is normal e^-x^2 falloff. - falloff: { value: 1.0 }, - // Clip Gsplats that are clipXY times beyond the +-1 frustum bounds - clipXY: { value: 1.4 }, - // Debug renderSize scale factor - focalAdjustment: { value: 1.0 }, - // Enable splat texture rendering - splatTexEnable: { value: false }, - // Splat texture to render - splatTexture: { type: "t", value: SparkRenderer.EMPTY_SPLAT_TEXTURE }, - // Splat texture UV transform (multiply) - splatTexMul: { value: new THREE.Matrix2() }, - // Splat texture UV transform (add) - splatTexAdd: { value: new THREE.Vector2() }, - // Splat texture near plane distance - splatTexNear: { value: 0.1 }, - // Splat texture far plane distance - splatTexFar: { value: 1000.0 }, - // Splat texture mid plane distance, or 0.0 to disable - splatTexMid: { value: 0.0 }, - // Gsplat collection to render - packedSplats: { type: "t", value: PackedSplats.getEmpty() }, - // Splat encoding ranges - rgbMinMaxLnScaleMinMax: { value: new THREE.Vector4() }, - // Time in seconds for time-based effects - time: { value: 0 }, - // Delta time in seconds since last frame - deltaTime: { value: 0 }, - // Whether to encode Gsplat with linear RGB (for environment mapping) - encodeLinear: { value: false }, - // Debug flag that alternates each frame - debugFlag: { value: false }, - }; - return uniforms; - } - - private canAllocAccumulator(): boolean { - // Returns true if can allocate an accumulator immediately - return ( - this.freeAccumulators.length > 0 || - this.accumulatorCount < MAX_ACCUMULATORS - ); - } - - private maybeAllocAccumulator(): SplatAccumulator | null { - // Allocate an accumulator immediately if possible, else return null - let accumulator = this.freeAccumulators.pop(); - if (accumulator === undefined) { - if (this.accumulatorCount >= MAX_ACCUMULATORS) { - return null; - } - accumulator = new SplatAccumulator(); - this.accumulatorCount += 1; - } - accumulator.refCount = 1; - return accumulator; - } - - releaseAccumulator(accumulator: SplatAccumulator) { - // Decrement reference count and recycle if no longer in use - accumulator.refCount -= 1; - if (accumulator.refCount === 0) { - this.freeAccumulators.push(accumulator); - } - } - - newViewpoint(options: SparkViewpointOptions) { - // Create a new SparkViewpoint for this SparkRenderer. - // Note that every SparkRenderer has an initial spark.defaultView: SparkViewpoint - // from construction, which is used for the default canvas render loop. - // Calling this method allows you to create additional viewpoints, which can be - // updated automatically each frame (performing Gsplat sorting every time there - // is an update), or updated on-demand for controlled rendering for video render - // or similar applications. - return new SparkViewpoint({ ...options, spark: this }); - } - - onBeforeRender( - renderer: THREE.WebGLRenderer, - scene: THREE.Scene, - camera: THREE.Camera, - ) { - // Called by Three.js before rendering this SparkRenderer. - // At this point we can't modify the geometry or material, all these must - // be set in the scene already before this is called. Update the uniforms - // to render the Gsplats from the current active viewpoint. - const time = this.time ?? this.clock.getElapsedTime(); - const deltaTime = time - (this.viewpoint.lastTime ?? time); - this.viewpoint.lastTime = time; - - const frame = renderer.info.render.frame; - const isNewFrame = frame !== this.lastFrame; - this.lastFrame = frame; - - const viewpoint = this.viewpoint; - if (viewpoint === this.defaultView) { - // When rendering is triggered on the default viewpoint, - // perform automatic updates. - if (isNewFrame) { - if (!renderer.xr.isPresenting) { - // Non-WebXR mode, just a single camera - this.defaultView.viewToWorld = camera.matrixWorld.clone(); - this.defaultCameras = [this.defaultView.viewToWorld]; - } else { - // In WebXR mode we are called multiple times, once for each eye, - // so use their average to compute the sort center. - const cameras = renderer.xr.getCamera().cameras; - this.defaultCameras = cameras.map((camera) => camera.matrixWorld); - this.defaultView.viewToWorld = - averageOriginToWorlds(this.defaultCameras) ?? new THREE.Matrix4(); - } - } - - if (this.autoUpdate) { - this.update({ scene, viewToWorld: this.defaultView.viewToWorld }); - } - } - - // Update uniforms for rendering - - if (isNewFrame) { - // Keep these uniforms the same for both eyes if in WebXR - if (this.material.premultipliedAlpha !== this.premultipliedAlpha) { - this.material.premultipliedAlpha = this.premultipliedAlpha; - this.material.needsUpdate = true; - } - this.uniforms.time.value = time; - this.uniforms.deltaTime.value = deltaTime; - // Alternating debug flag that can aid in visual debugging - this.uniforms.debugFlag.value = (performance.now() / 1000.0) % 2.0 < 1.0; - - if (viewpoint.display && viewpoint.stochastic) { - (this.geometry as SplatGeometry).instanceCount = - this.uniforms.numSplats.value; - } - } - - if (viewpoint.target) { - // Rendering to a texture target, so its dimensions - this.uniforms.renderSize.value.set( - viewpoint.target.width, - viewpoint.target.height, - ); - } else { - // Rendering to the canvas or WebXR - const renderSize = renderer.getDrawingBufferSize( - this.uniforms.renderSize.value, - ); - if (renderSize.x === 1 && renderSize.y === 1) { - // WebXR mode on Apple Vision Pro returns 1x1 when presenting. - // Use a different means to figure out the render size. - const baseLayer = renderer.xr.getSession()?.renderState.baseLayer; - if (baseLayer) { - renderSize.x = baseLayer.framebufferWidth; - renderSize.y = baseLayer.framebufferHeight; - } - } - } - - // Update uniforms from instance properties - const typedCamera = camera as - | THREE.PerspectiveCamera - | THREE.OrthographicCamera; - this.uniforms.near.value = typedCamera.near; - this.uniforms.far.value = typedCamera.far; - this.uniforms.encodeLinear.value = viewpoint.encodeLinear; - this.uniforms.maxStdDev.value = this.maxStdDev; - this.uniforms.minPixelRadius.value = this.minPixelRadius; - this.uniforms.maxPixelRadius.value = this.maxPixelRadius; - this.uniforms.minAlpha.value = this.minAlpha; - this.uniforms.stochastic.value = viewpoint.stochastic; - this.uniforms.enable2DGS.value = this.enable2DGS; - this.uniforms.preBlurAmount.value = this.preBlurAmount; - this.uniforms.blurAmount.value = this.blurAmount; - this.uniforms.focalDistance.value = this.focalDistance; - this.uniforms.apertureAngle.value = this.apertureAngle; - this.uniforms.falloff.value = this.falloff; - this.uniforms.clipXY.value = this.clipXY; - this.uniforms.focalAdjustment.value = this.focalAdjustment; - - if (this.lastStochastic !== !viewpoint.stochastic) { - this.lastStochastic = !viewpoint.stochastic; - this.material.transparent = !viewpoint.stochastic; - this.material.depthWrite = viewpoint.stochastic; - this.material.needsUpdate = true; - } - - if (this.splatTexture) { - const { enable, texture, multiply, add, near, far, mid } = - this.splatTexture; - if (enable && texture) { - this.uniforms.splatTexEnable.value = true; - this.uniforms.splatTexture.value = texture; - if (multiply) { - this.uniforms.splatTexMul.value.fromArray(multiply.elements); - } else { - this.uniforms.splatTexMul.value.set( - 0.5 / this.maxStdDev, - 0, - 0, - 0.5 / this.maxStdDev, - ); - } - this.uniforms.splatTexAdd.value.set(add?.x ?? 0.5, add?.y ?? 0.5); - this.uniforms.splatTexNear.value = near ?? this.uniforms.near.value; - this.uniforms.splatTexFar.value = far ?? this.uniforms.far.value; - this.uniforms.splatTexMid.value = mid ?? 0.0; - } else { - this.uniforms.splatTexEnable.value = false; - this.uniforms.splatTexture.value = SparkRenderer.EMPTY_SPLAT_TEXTURE; - } - } else { - this.uniforms.splatTexEnable.value = false; - this.uniforms.splatTexture.value = SparkRenderer.EMPTY_SPLAT_TEXTURE; - } - - // Calculate the transform from the accumulator to the current camera - const accumToWorld = - viewpoint.display?.accumulator.toWorld ?? new THREE.Matrix4(); - const worldToCamera = camera.matrixWorld.clone().invert(); - const originToCamera = accumToWorld.clone().premultiply(worldToCamera); - originToCamera.decompose( - this.uniforms.renderToViewPos.value, - this.uniforms.renderToViewQuat.value, - new THREE.Vector3(), - ); - } - - // Update the uniforms for the given viewpoint. - // Note that the client expects to be able to call render() at any point - // to update the canvas, so we must switch the viewpoint back to - // defaultView when we're finished. - prepareViewpoint(viewpoint?: SparkViewpoint) { - this.viewpoint = viewpoint ?? this.viewpoint; - - if (this.viewpoint.display) { - const { accumulator, geometry } = this.viewpoint.display; - this.uniforms.numSplats.value = accumulator.splats.numSplats; - this.uniforms.packedSplats.value = accumulator.splats.getTexture(); - this.uniforms.rgbMinMaxLnScaleMinMax.value.set( - accumulator.splats.splatEncoding?.rgbMin ?? 0.0, - accumulator.splats.splatEncoding?.rgbMax ?? 1.0, - accumulator.splats.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, - accumulator.splats.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, - ); - this.geometry = geometry; - this.material.transparent = !this.viewpoint.stochastic; - this.material.depthWrite = this.viewpoint.stochastic; - this.material.needsUpdate = true; - } else { - // No Gsplats to display for this viewpoint yet - this.uniforms.numSplats.value = 0; - this.uniforms.packedSplats.value = PackedSplats.getEmpty(); - this.geometry = EMPTY_GEOMETRY; - } - } - - // If spark.autoUpdate is false then you must manually call - // spark.update({ scene }) to have the scene Gsplats be re-generated. - update({ - scene, - viewToWorld, - }: { scene: THREE.Scene; viewToWorld?: THREE.Matrix4 }) { - // Compute the transform for the SparkRenderer to use as origin - // for Gsplat generation and accumulation. - const originToWorld = this.matrixWorld; - - // Either do the update now, or in the next "tick" depending on preUpdate - if (this.preUpdate) { - this.updateInternal({ - scene, - originToWorld: originToWorld.clone(), - viewToWorld, - }); - } else { - // Pass the update parameters to be performed on the next tick - this.pendingUpdate.scene = scene; - this.pendingUpdate.originToWorld.copy(originToWorld); - - // Schedule a timeout if there isn't one already - if (this.pendingUpdate.timeoutId === -1) { - this.pendingUpdate.timeoutId = setTimeout(() => { - const { scene, originToWorld } = this.pendingUpdate; - this.pendingUpdate.scene = null; - this.pendingUpdate.timeoutId = -1; - const updated = this.updateInternal({ - scene: scene as THREE.Scene, - originToWorld, - viewToWorld, - }); - - if (updated) { - // Flush to encourage eager execution - const gl = this.renderer.getContext() as WebGL2RenderingContext; - gl.flush(); - } - }, 1); - } - } - } - - updateInternal({ - scene, - originToWorld, - viewToWorld, - }: { - scene: THREE.Scene; - originToWorld?: THREE.Matrix4; - viewToWorld?: THREE.Matrix4; - }): boolean { - if (!this.canAllocAccumulator()) { - // We don't have any available accumulators because of sorting - // back pressure, so don't update this time but try again next time. - // Signal update not attempted. - return false; - } - - // Figure out the frame of the SparkRenderer and current view - if (!originToWorld) { - originToWorld = this.active.toWorld; - } - viewToWorld = viewToWorld ?? originToWorld.clone(); - - const time = this.time ?? this.clock.getElapsedTime(); - const deltaTime = time - (this.lastUpdateTime ?? time); - this.lastUpdateTime = time; - - // Create a lookup from last active SplatGenerator to Gsplat mapping record - const activeMapping = this.active.mapping.reduce((map, record) => { - map.set(record.node, record); - return map; - }, new Map()); - - // Traverse visible scene to find all SplatGenerators and global SplatEdits - const { generators, visibleGenerators, globalEdits } = - this.compileScene(scene); - - // Let all SplatGenerators run their frameUpdate() method - for (const object of generators) { - object.frameUpdate?.({ - object, - time, - deltaTime, - viewToWorld, - globalEdits, - }); - } - - const visibleGenHash = new Set(visibleGenerators.map((g) => g.uuid)); - - // Make sure we have new version numbers for any objects with either - // generator or numSplats that have changed since the last frame. - for (const object of generators) { - const current = activeMapping.get(object); - const isVisible = object.generator && visibleGenHash.has(object.uuid); - const numSplats = isVisible ? object.numSplats : 0; - if ( - this.needsUpdate || - object.generator !== current?.generator || - numSplats !== current?.count - ) { - object.updateVersion(); - } - } - - // Check if the origin is within the maximum allowed distance before - // we trigger an update. - const originUpdate = !withinCoorientDist({ - matrix1: originToWorld, - matrix2: this.active.toWorld, - maxDistance: this.originDistance, - }); - - // Check if we need any update at all - const needsUpdate = - this.needsUpdate || - originUpdate || - generators.length !== activeMapping.size || - generators.some((g) => g.version !== activeMapping.get(g)?.version); - this.needsUpdate = false; - - let accumulator: SplatAccumulator | null = null; - if (needsUpdate) { - // Need to update, so allocate an accumulator - accumulator = this.maybeAllocAccumulator(); - if (!accumulator) { - // This should never happen since we checked canAllocAccumulator() above - throw new Error("Unreachable"); - } - - // Compute whether our view frame has changed enough to warrant - // doing a Gsplat sort. Check both distance epsilon and - // minimum co-orientation (dot product of quaternions) - const originChanged = !withinCoorientDist({ - matrix1: originToWorld, - matrix2: accumulator.toWorld, - maxDistance: 0.00001, - minCoorient: 0.99999, - }); - - // Compute an ordering of the generators with the rough goal - // of keeping unchanging generators near the front to minimize - // the number of Gsplats that need to be regenerated. - const sorted = visibleGenerators - .map((g, gIndex): [number, number, SplatGenerator] => { - const lastGen = activeMapping.get(g); - // If no previous generator, sort by absolute version, which will - // tend to push frequently updated generators toward the end - return !lastGen - ? [Number.POSITIVE_INFINITY, g.version, g] - : // Sort by version deltas then by previous ordering in the mapping, - // attempting to keep unchanging generators near the front - // to improve our chances of avoiding a re-generation. - [g.version - lastGen.version, lastGen.base, g]; - }) - .sort((a, b) => { - // Sort by first then second element of the tuple - if (a[0] !== b[0]) { - return a[0] - b[0]; - } - return a[1] - b[1]; - }); - const genOrder = sorted.map(([_version, _seq, g]) => g); - - // Compute sequential layout of generated splats - const splatCounts = genOrder.map((g) => g.numSplats); - const { maxSplats, mapping } = - accumulator.splats.generateMapping(splatCounts); - const newGenerators = genOrder.map((node, gIndex) => { - const { base, count } = mapping[gIndex]; - return { - node, - generator: node.generator, - version: node.version, - base, - count, - }; - }); - - // Compute worldToAccumulator origin transform (no scale) - originToWorld - .clone() - .invert() - .decompose( - this.translateToAccumulator.value, - this.rotateToAccumulator.value, - new THREE.Vector3(), - ); - - // Generate the Gsplats according to the mapping that need updating - accumulator.ensureGenerate(maxSplats); - accumulator.splats.splatEncoding = { ...this.splatEncoding }; - const generated = accumulator.generateSplats({ - renderer: this.renderer, - modifier: this.modifier, - generators: newGenerators, - forceUpdate: originChanged, - originToWorld, - }); - - // Update splat version number - accumulator.splatsVersion = this.active.splatsVersion + 1; - // Increment the mapping version if the mapping isn't identical to before - const hasCorrespondence = accumulator.hasCorrespondence(this.active); - accumulator.mappingVersion = - this.active.mappingVersion + (hasCorrespondence ? 0 : 1); - - // Release the old accumulator and make the new one active - this.releaseAccumulator(this.active); - this.active = accumulator; - this.prepareViewpoint(); - } - - // Let the system breath before potentially triggering sorts - setTimeout(() => { - // Notify all auto-updating viewpoints that we updated the Gsplats - for (const view of this.autoViewpoints) { - view.autoPoll({ accumulator: accumulator ?? undefined }); - } - }, 1); - - // Signal update was performed - return true; - } - - private compileScene(scene: THREE.Scene): { - generators: SplatGenerator[]; - visibleGenerators: SplatGenerator[]; - globalEdits: SplatEdit[]; - } { - // Take a snapshot of the SplatGenerators and SplatEdits in the scene - // to be used to run an update. - const generators: SplatGenerator[] = []; - // Collect all SplatGenerators, even if not visible, because we want to - // be able to call their update functions every frame. - scene.traverse((node) => { - if (node instanceof SplatGenerator) { - generators.push(node); - } - }); - - const visibleGenerators: SplatGenerator[] = []; - scene.traverseVisible((node) => { - if (node instanceof SplatGenerator) { - visibleGenerators.push(node); - } - }); - - const globalEdits = new Set(); - scene.traverseVisible((node) => { - if (node instanceof SplatEdit) { - let ancestor = node.parent; - while (ancestor != null && !(ancestor instanceof SplatMesh)) { - ancestor = ancestor.parent; - } - if (ancestor == null) { - // Not part of a SplatMesh so it's a global edit - globalEdits.add(node); - } - } - }); - return { - generators, - visibleGenerators, - globalEdits: Array.from(globalEdits), - }; - } - - // Renders out the scene to an environment map that can be used for - // Image-based lighting or similar applications. First optionally updates Gsplats, - // sorts them with respect to the provided worldCenter, renders 6 cube faces, - // then pre-filters them using THREE.PMREMGenerator and returns a THREE.Texture - // that can assigned directly to a THREE.MeshStandardMaterial.envMap property. - async renderEnvMap({ - renderer, - scene, - worldCenter, - size = 256, - near = 0.1, - far = 1000, - hideObjects = [], - update = false, - }: { - renderer?: THREE.WebGLRenderer; - scene: THREE.Scene; - worldCenter: THREE.Vector3; - size?: number; - near?: number; - far?: number; - hideObjects?: THREE.Object3D[]; - update?: boolean; - }): Promise { - if (!this.envViewpoint) { - this.envViewpoint = this.newViewpoint({ sort360: true }); - } - if ( - !SparkRenderer.cubeRender || - SparkRenderer.cubeRender.target.width !== size || - SparkRenderer.cubeRender.near !== near || - SparkRenderer.cubeRender.far !== far - ) { - if (SparkRenderer.cubeRender) { - SparkRenderer.cubeRender.target.dispose(); - } - const target = new THREE.WebGLCubeRenderTarget(size, { - format: THREE.RGBAFormat, - generateMipmaps: true, - minFilter: THREE.LinearMipMapLinearFilter, - }); - const camera = new THREE.CubeCamera(near, far, target); - SparkRenderer.cubeRender = { target, camera, near, far }; - } - - if (!SparkRenderer.pmrem) { - SparkRenderer.pmrem = new THREE.PMREMGenerator(renderer ?? this.renderer); - } - - // Prepare the viewpoint, sorting Gsplats for this view origin. - const viewToWorld = new THREE.Matrix4().setPosition(worldCenter); - await this.envViewpoint?.prepare({ scene, viewToWorld, update }); - - const { target, camera } = SparkRenderer.cubeRender; - camera.position.copy(worldCenter); - - // Save the visibility state of objects we want to hide before render - const objectVisibility = new Map(); - for (const object of hideObjects) { - objectVisibility.set(object, object.visible); - object.visible = false; - } - - // Update the CubeCamera, which performs 6 cube face renders - this.prepareViewpoint(this.envViewpoint); - camera.update(renderer ?? this.renderer, scene); - - // Restore viewpoint to default and object visibility - this.prepareViewpoint(this.defaultView); - for (const [object, visible] of objectVisibility.entries()) { - object.visible = visible; - } - - // Pre-filter the cube map using THREE.PMREMGenerator - return SparkRenderer.pmrem?.fromCubemap(target.texture).texture; - } - - // Utility function to recursively set the envMap property for any - // THREE.MeshStandardMaterial within the subtree of root. - recurseSetEnvMap(root: THREE.Object3D, envMap: THREE.Texture) { - root.traverse((node) => { - if (node instanceof THREE.Mesh) { - if (Array.isArray(node.material)) { - for (const material of node.material) { - if (material instanceof THREE.MeshStandardMaterial) { - material.envMap = envMap; - } - } - } else { - if (node.material instanceof THREE.MeshStandardMaterial) { - node.material.envMap = envMap; - } - } - } - }); - } - - // Utility function that helps extract the Gsplat RGBA values from a - // SplatGenerator, including the result of any real-time RGBA SDF edits applied - // to a SplatMesh. This effectively "bakes" any computed RGBA values, which can - // now be used as a pipeline input via SplatMesh.splatRgba to inject these - // baked values into the Gsplat data. - getRgba({ - generator, - rgba, - }: { generator: SplatGenerator; rgba?: RgbaArray }): RgbaArray { - const mapping = this.active.mapping.find(({ node }) => node === generator); - if (!mapping) { - throw new Error("Generator not found"); - } - - rgba = rgba ?? new RgbaArray(); - rgba.fromPackedSplats({ - packedSplats: this.active.splats, - base: mapping.base, - count: mapping.count, - renderer: this.renderer, - }); - return rgba; - } - - // Utility function that builds on getRgba({ generator }) and additionally - // reads back the RGBA values to the CPU in a Uint8Array with packed RGBA - // in that byte order. - async readRgba({ - generator, - rgba, - }: { generator: SplatGenerator; rgba?: RgbaArray }): Promise { - rgba = this.getRgba({ generator, rgba }); - return rgba.read(); - } -} - -const EMPTY_GEOMETRY = new SplatGeometry(new Uint32Array(1), 0); - -const reorderSplats = dynoBlock( - { packedSplats: TPackedSplats, index: "int" }, - { gsplat: Gsplat }, - ({ packedSplats, index }) => { - if (!packedSplats || !index) { - throw new Error("Invalid input"); - } - const gsplat = readPackedSplat(packedSplats, index); - return { gsplat }; - }, -); - -function averageOriginToWorlds( - originToWorlds: THREE.Matrix4[], -): THREE.Matrix4 | null { - if (originToWorlds.length === 0) { - return null; - } - - const position = new THREE.Vector3(); - const quaternion = new THREE.Quaternion(); - const scale = new THREE.Vector3(); - - const positions: THREE.Vector3[] = []; - const quaternions: THREE.Quaternion[] = []; - for (const matrix of originToWorlds) { - matrix.decompose(position, quaternion, scale); - positions.push(position); - quaternions.push(quaternion); - } - - return new THREE.Matrix4().compose( - averagePositions(positions), - averageQuaternions(quaternions), - new THREE.Vector3(1, 1, 1), - ); -} diff --git a/src/SparkViewpoint.ts b/src/SparkViewpoint.ts deleted file mode 100644 index 43d960c..0000000 --- a/src/SparkViewpoint.ts +++ /dev/null @@ -1,878 +0,0 @@ -import * as THREE from "three"; - -import { DynoPackedSplats } from "./PackedSplats"; -import { Readback } from "./Readback"; -import type { SparkRenderer } from "./SparkRenderer"; -import type { SplatAccumulator } from "./SplatAccumulator"; -import { SplatGeometry } from "./SplatGeometry"; -import { - type DynoBlock, - DynoBool, - DynoFloat, - type DynoVal, - DynoVec3, - Gsplat, - add, - combine, - defineGsplat, - dyno, - dynoBlock, - dynoConst, - floatBitsToUint, - mul, - packHalf2x16, - readPackedSplat, - uintToRgba8, - unindent, - unindentLines, -} from "./dyno"; -import { withWorker } from "./splatWorker"; -import { FreeList, withinCoorientDist } from "./utils"; - -export type SparkViewpointOptions = { - /** - * Controls whether to auto-update its sort order whenever the SparkRenderer - * updates the Gsplats. If you expect to render/display from this viewpoint - * most frames, set this to true. - * @default false - */ - autoUpdate?: boolean; - /** - * Set a THREE.Camera for this viewpoint to follow. - * @default undefined - */ - camera?: THREE.Camera; - /** - * Set an explicit view-to-world transformation matrix for this viewpoint (equivalent - * to camera.matrixWorld), overrides any camera setting. - * @default undefined - */ - viewToWorld?: THREE.Matrix4; - /** - * Configure viewpoint with an off-screen render target. - * @default undefined - */ - target?: { - /** - * Width of the render target in pixels. - */ - width: number; - /** - * Height of the render target in pixels. - */ - height: number; - /** - * If you want to be able to render a scene that depends on this target's - * output (for example, a recursive viewport), set this to true to enable - * double buffering. - * @default false - */ - doubleBuffer?: boolean; - /** - * Super-sampling factor for the render target. Values 1-4 are supported. - * Note that re-sampling back down to .width x .height is done on the CPU - * with simple averaging only when calling readTarget(). - * @default 1 - */ - superXY?: number; - }; - /** - * Callback function that is called when the render target texture is updated. - * Receives the texture as a parameter. Use this to update a viewport with - * the latest viewpoint render each frame. - * @default undefined - */ - onTextureUpdated?: (texture: THREE.Texture) => void; - /** - * Whether to sort splats radially (geometric distance) from the viewpoint (true) - * or by Z-depth (false). Most scenes are trained with the Z-depth sort metric - * and will render more accurately at certain viewpoints. However, radial sorting - * is more stable under viewpoint rotations. - * @default true - */ - sortRadial?: boolean; - /** - * Distance threshold for re-sorting splats. If the viewpoint moves more than - * this distance, splats will be re-sorted. - * @default 0.01 units - */ - sortDistance?: number; - /** - * View direction dot product threshold for re-sorting splats. For - * sortRadial: true we use 0.99 while sortRadial: false uses 0.999 because it is - * more sensitive to view direction. - * @default 0.99 if sortRadial else 0.999 - */ - sortCoorient?: boolean; - /** - * Constant added to Z-depth to bias values into the positive range for - * sortRadial: false, but also used for culling Gsplats "well behind" - * the viewpoint origin - * @default 1.0 - */ - depthBias?: number; - /** - * Set this to true if rendering a 360 to disable "behind the viewpoint" - * culling during sorting. This is set automatically when rendering 360 envMaps - * using the SparkRenderer.renderEnvMap() utility function. - * @default false - */ - sort360?: boolean; - /* - * Set this to true to sort with float32 precision with two-pass sort. - * @default true - */ - sort32?: boolean; - /* - * Set this to true to enable sort-free stochastic splat rendering. - * @default false - */ - stochastic?: boolean; -}; - -// A SparkViewpoint is created from and tied to a SparkRenderer, and represents -// an independent viewpoint of all the scene Gsplats and their sort order. Making -// these viewpoints explicit allows us to have multiple, simultaneous viewpoint -// renders, for example for camera preview panes or overhead map views. -// -// When creating a SparkRenderer it automatically creates a default viewpoint -// .defaultView that is used in the normal render loop when drawing to the canvas, -// and is automatically updated whenever the camera moves. Additional viewpoints -// can be created and configured separately. - -export class SparkViewpoint { - spark: SparkRenderer; - autoUpdate: boolean; - camera?: THREE.Camera; - viewToWorld: THREE.Matrix4; - lastTime: number | null = null; - - target?: THREE.WebGLRenderTarget; - private back?: THREE.WebGLRenderTarget; - onTextureUpdated?: (texture: THREE.Texture) => void; - encodeLinear = false; - superXY = 1; - private superPixels?: Uint8Array; - private pixels?: Uint8Array; - - sortRadial: boolean; - sortDistance?: number; - sortCoorient?: boolean; - depthBias?: number; - sort360?: boolean; - sort32?: boolean; - stochastic: boolean; - - display: { - accumulator: SplatAccumulator; - viewToWorld: THREE.Matrix4; - geometry: SplatGeometry; - } | null = null; - - private sorting: { viewToWorld: THREE.Matrix4 } | null = null; - private pending: { - accumulator?: SplatAccumulator; - viewToWorld: THREE.Matrix4; - displayed: boolean; - } | null = null; - private sortingCheck = false; - - private readback16: Uint16Array = new Uint16Array(0); - private readback32: Uint32Array = new Uint32Array(0); - private orderingFreelist: FreeList; - - constructor(options: SparkViewpointOptions & { spark: SparkRenderer }) { - this.spark = options.spark; - this.camera = options.camera; - this.viewToWorld = options.viewToWorld ?? new THREE.Matrix4(); - - if (options.target) { - const { width, height, doubleBuffer } = options.target; - const superXY = Math.max(1, Math.min(4, options.target.superXY ?? 1)); - this.superXY = superXY; - if (width * superXY > 8192 || height * superXY > 8192) { - throw new Error("Target size too large"); - } - - this.target = new THREE.WebGLRenderTarget( - width * superXY, - height * superXY, - { - format: THREE.RGBAFormat, - type: THREE.UnsignedByteType, - colorSpace: THREE.SRGBColorSpace, - }, - ); - if (doubleBuffer) { - this.back = new THREE.WebGLRenderTarget( - width * superXY, - height * superXY, - { - format: THREE.RGBAFormat, - type: THREE.UnsignedByteType, - colorSpace: THREE.SRGBColorSpace, - }, - ); - } - this.encodeLinear = true; - } - this.onTextureUpdated = options.onTextureUpdated; - - this.sortRadial = options.sortRadial ?? true; - this.sortDistance = options.sortDistance; - this.sortCoorient = options.sortCoorient; - this.depthBias = options.depthBias; - this.sort360 = options.sort360; - this.sort32 = options.sort32; - this.stochastic = options.stochastic ?? false; - - this.orderingFreelist = new FreeList({ - allocate: (maxSplats) => new Uint32Array(maxSplats), - valid: (ordering, maxSplats) => ordering.length === maxSplats, - }); - - this.autoUpdate = false; - this.setAutoUpdate(options.autoUpdate ?? false); - } - - // Call this when you are done with the SparkViewpoint and want to - // free up its resources (GPU targets, pixel buffers, etc.) - dispose() { - this.setAutoUpdate(false); - if (this.target) { - this.target.dispose(); - this.target = undefined; - } - if (this.back) { - this.back.dispose(); - this.back = undefined; - } - if (this.display) { - this.spark.releaseAccumulator(this.display.accumulator); - this.display.geometry.dispose(); - this.display = null; - } - if (this.pending?.accumulator) { - this.spark.releaseAccumulator(this.pending.accumulator); - this.pending = null; - } - } - - // Use this function to change whether this viewpoint will auto-update - // its sort order whenever the attached SparkRenderer updates the Gsplats. - // Turn this on or off depending on whether you expect to do renders from - // this viewpoint most frames. - setAutoUpdate(autoUpdate: boolean) { - if (!this.autoUpdate && autoUpdate) { - this.spark.autoViewpoints.push(this); - } else if (this.autoUpdate && !autoUpdate) { - this.spark.autoViewpoints = this.spark.autoViewpoints.filter( - (v) => v !== this, - ); - } - this.autoUpdate = autoUpdate; - } - - // See below async prepareRenderPixels() for explanation of parameters. - // Awaiting this method updates the Gsplats in the scene and performs a sort of the - // Gsplats from this viewpoint, preparing it for a subsequent this.renderTarget() - // call in the same tick. - async prepare({ - scene, - camera, - viewToWorld, - update, - forceOrigin, - }: { - scene: THREE.Scene; - camera?: THREE.Camera; - viewToWorld?: THREE.Matrix4; - update?: boolean; - forceOrigin?: boolean; - }) { - if (viewToWorld) { - this.viewToWorld = viewToWorld; - } else { - this.camera = camera ?? this.camera; - if (this.camera) { - this.camera.updateMatrixWorld(); - this.viewToWorld = this.camera.matrixWorld.clone(); - } - } - while (update ?? true) { - // Force an update, possibly with origin centered at this camera - // to yield the best quality output. - const originToWorld = forceOrigin - ? this.viewToWorld - : this.spark.matrixWorld; - const updated = this.spark.updateInternal({ scene, originToWorld }); - if (updated) { - break; - } - // A bit of a hack, but try again. We shouldn't be starved for long. - await new Promise((resolve) => setTimeout(resolve, 10)); - } - - const accumulator = this.spark.active; - // Hold reference to accumulator while sorting - accumulator.refCount += 1; - await this.sortUpdate({ accumulator, viewToWorld: this.viewToWorld }); - // Release accumulator reference - this.spark.releaseAccumulator(accumulator); - } - - // Render out the viewpoint to the view target RGBA buffer. - // Swaps buffers if doubleBuffer: true was set. - // Calls onTextureUpdated(texture) with the resulting texture. - renderTarget({ - scene, - camera, - }: { scene: THREE.Scene; camera?: THREE.Camera }) { - const target = this.back ?? this.target; - if (!target) { - throw new Error("Must initialize SparkViewpoint with target"); - } - - camera = camera ?? this.camera; - if (!camera) { - throw new Error("Must provide camera"); - } - if (camera instanceof THREE.PerspectiveCamera) { - const newCam = new THREE.PerspectiveCamera().copy(camera, false); - newCam.aspect = target.width / target.height; - newCam.updateProjectionMatrix(); - camera = newCam; - } - this.viewToWorld = camera.matrixWorld.clone(); - - try { - this.spark.renderer.setRenderTarget(target); - this.spark.prepareViewpoint(this); - - this.spark.renderer.render(scene, camera); - } finally { - this.spark.prepareViewpoint(this.spark.defaultView); - this.spark.renderer.setRenderTarget(null); - } - - if (target !== this.target) { - // Swap back buffer and target - [this.target, this.back] = [this.back, this.target]; - } - this.onTextureUpdated?.(target.texture); - } - - // Read back the previously rendered target image as a Uint8Array of packed - // RGBA values (in that order). If superXY was set greater than 1 then - // downsampling is performed in the target pixel array with simple averaging - // to derive the returned pixel values. Subsequent calls to this.readTarget() - // will reuse the same buffers to minimize memory allocations. - async readTarget(): Promise { - if (!this.target) { - throw new Error("Must initialize SparkViewpoint with target"); - } - const { width, height } = this.target; - const byteSize = width * height * 4; - if (!this.superPixels || this.superPixels.length < byteSize) { - this.superPixels = new Uint8Array(byteSize); - } - await this.spark.renderer.readRenderTargetPixelsAsync( - this.target, - 0, - 0, - width, - height, - this.superPixels, - ); - - const { superXY } = this; - if (superXY === 1) { - return this.superPixels; - } - - const subWidth = width / superXY; - const subHeight = height / superXY; - const subSize = subWidth * subHeight * 4; - if (!this.pixels || this.pixels.length < subSize) { - this.pixels = new Uint8Array(subSize); - } - - const { superPixels, pixels } = this; - const super2 = superXY * superXY; - for (let y = 0; y < subHeight; y++) { - const row = y * subWidth; - for (let x = 0; x < subWidth; x++) { - const superCol = x * superXY; - let r = 0; - let g = 0; - let b = 0; - let a = 0; - for (let sy = 0; sy < superXY; sy++) { - const superRow = (y * superXY + sy) * this.target.width; - for (let sx = 0; sx < superXY; sx++) { - const superIndex = (superRow + superCol + sx) * 4; - r += superPixels[superIndex]; - g += superPixels[superIndex + 1]; - b += superPixels[superIndex + 2]; - a += superPixels[superIndex + 3]; - } - } - const pixelIndex = (row + x) * 4; - pixels[pixelIndex] = r / super2; - pixels[pixelIndex + 1] = g / super2; - pixels[pixelIndex + 2] = b / super2; - pixels[pixelIndex + 3] = a / super2; - } - } - return pixels; - } - - // Render out a viewpoint as a Uint8Array of RGBA values for the provided scene - // and any camera/viewToWorld viewpoint overrides. By default update is true, - // which triggers its SparkRenderer to check and potentially update the Gsplats. - // Setting update to false disables this and sorts the Gsplats as they are. - // Setting forceOrigin (default: false) to true forces the view update to - // recalculate the splats with this view origin, potentially altering any - // view-dependent effects. If you expect view-dependent effects to play a role - // in the rendering quality, enable this. - // - // Underneath, prepareRenderPixels() simply calls await this.prepare(...), - // this.renderTarget(...), and finally returns the result this.readTarget(), - // a Promise to a Uint8Array with RGBA values for all the pixels (potentially - // downsampled if the superXY parameter was used). These steps can also be called - // manually, for example if you need to alter the scene before and after - // this.renderTarget(...) to hide UI elements from being rendered. - async prepareRenderPixels({ - scene, - camera, - viewToWorld, - update, - forceOrigin, - }: { - scene: THREE.Scene; - camera?: THREE.Camera; - viewToWorld?: THREE.Matrix4; - update?: boolean; - forceOrigin?: boolean; - }) { - await this.prepare({ scene, camera, viewToWorld, update, forceOrigin }); - this.renderTarget({ scene, camera }); - return this.readTarget(); - } - - // This is called automatically by SparkRenderer, there is no need to call it! - // The method cannot be private because then SparkRenderer would - // not be able to call it. - autoPoll({ accumulator }: { accumulator?: SplatAccumulator }) { - if (this.camera) { - this.camera.updateMatrixWorld(); - this.viewToWorld = this.camera.matrixWorld.clone(); - } - - let needsSort = false; - let displayed = false; - - if (!this.display) { - // Need to do first sort - needsSort = true; - } else if (accumulator) { - needsSort = true; - const { mappingVersion } = this.display.accumulator; - if (accumulator.mappingVersion === mappingVersion) { - // Splat mapping has not changed, so reuse the existing sorted - // geometry to show updates faster. We will still fire off - // a re-sort if necessary. First release old accumulator. - accumulator.refCount += 1; - this.spark.releaseAccumulator(this.display.accumulator); - this.display.accumulator = accumulator; - this.display.viewToWorld.copy(this.viewToWorld); - displayed = true; - - if (this.spark.viewpoint === this) { - this.spark.prepareViewpoint(this); - } - } - } - - const latestView = this.sorting?.viewToWorld ?? this.display?.viewToWorld; - if ( - latestView && - !withinCoorientDist({ - matrix1: this.viewToWorld, - matrix2: latestView, - // By default update sort each 1 cm - maxDistance: this.sortDistance ?? 0.01, - // By default for radial sort, update for intermittent movement so that - // we bring back splats culled by being behind the camera. - // For depth sort, small rotations can change sort order a lot, so - // update sort for even small rotations. - minCoorient: (this.sortCoorient ?? this.sortRadial) ? 0.99 : 0.999, - }) - ) { - needsSort = true; - } - - if (!needsSort) { - // Stop here, no sort necessary - return; - } - - if (accumulator) { - // Hold a reference to the accumulator for sorting - accumulator.refCount += 1; - } - - if (this.pending?.accumulator) { - // Release the reference of the pending accumulator - this.spark.releaseAccumulator(this.pending.accumulator); - } - this.pending = { accumulator, viewToWorld: this.viewToWorld, displayed }; - - // Don't await this, just trigger the sort if necessary - this.driveSort(); - } - - private async driveSort() { - while (true) { - if (this.sorting || !this.pending) { - return; // Sort already in process or nothing to sort - } - - const { viewToWorld, displayed } = this.pending; - let accumulator = this.pending.accumulator; - if (!accumulator) { - // Hold a reference to the accumulator while sorting - accumulator = this.display?.accumulator ?? this.spark.active; - accumulator.refCount += 1; - } - this.pending = null; - if (!accumulator) { - throw new Error("No accumulator to sort"); - } - - this.sorting = { viewToWorld }; - await this.sortUpdate({ accumulator, viewToWorld, displayed }); - this.sorting = null; - - // Release the reference to the accumulator - this.spark.releaseAccumulator(accumulator); - - // Continue in loop with any queued sort - } - } - - private async sortUpdate({ - accumulator, - viewToWorld, - displayed = false, - }: { - accumulator?: SplatAccumulator; - viewToWorld: THREE.Matrix4; - displayed?: boolean; - }) { - if (this.sortingCheck) { - throw new Error("Only one sort at a time"); - } - this.sortingCheck = true; - - accumulator = accumulator ?? this.spark.active; - const { numSplats, maxSplats } = accumulator.splats; - let activeSplats = 0; - let ordering = this.orderingFreelist.alloc(maxSplats); - - if (this.stochastic) { - activeSplats = numSplats; - // Render all splats in order since the Z-buffer - // will handle ordering. - for (let i = 0; i < numSplats; ++i) { - ordering[i] = i; - } - } else if (numSplats > 0) { - const { - reader, - doubleSortReader, - sort32Reader, - dynoSortRadial, - dynoOrigin, - dynoDirection, - dynoDepthBias, - dynoSort360, - dynoSplats, - } = SparkViewpoint.makeSorter(); - const sort32 = this.sort32 ?? false; - let readback: Uint16Array | Uint32Array; - if (sort32) { - this.readback32 = reader.ensureBuffer(maxSplats, this.readback32); - readback = this.readback32; - } else { - const halfMaxSplats = Math.ceil(maxSplats / 2); - this.readback16 = reader.ensureBuffer(halfMaxSplats, this.readback16); - readback = this.readback16; - } - - const worldToOrigin = accumulator.toWorld.clone().invert(); - const viewToOrigin = viewToWorld.clone().premultiply(worldToOrigin); - - dynoSortRadial.value = this.sort360 ? true : this.sortRadial; - dynoOrigin.value.set(0, 0, 0).applyMatrix4(viewToOrigin); - dynoDirection.value - .set(0, 0, -1) - .applyMatrix4(viewToOrigin) - .sub(dynoOrigin.value) - .normalize(); - dynoDepthBias.value = this.depthBias ?? 1.0; - dynoSort360.value = this.sort360 ?? false; - dynoSplats.packedSplats = accumulator.splats; - - const sortReader = sort32 ? sort32Reader : doubleSortReader; - const count = sort32 ? numSplats : Math.ceil(numSplats / 2); - await reader.renderReadback({ - renderer: this.spark.renderer, - reader: sortReader, - count, - readback, - }); - - const result = (await withWorker(async (worker) => { - const rpcName = sort32 ? "sort32Splats" : "sortDoubleSplats"; - return worker.call(rpcName, { - maxSplats, - numSplats, - readback, - ordering, - }); - })) as { - readback: Uint16Array | Uint32Array; - ordering: Uint32Array; - activeSplats: number; - }; - if (sort32) { - this.readback32 = result.readback as Uint32Array; - } else { - this.readback16 = result.readback as Uint16Array; - } - ordering = result.ordering; - activeSplats = result.activeSplats; - } - - this.updateDisplay({ - accumulator, - viewToWorld, - ordering, - activeSplats, - displayed, - }); - this.sortingCheck = false; - } - - private updateDisplay({ - accumulator, - viewToWorld, - ordering, - activeSplats, - displayed = false, - }: { - accumulator: SplatAccumulator; - viewToWorld: THREE.Matrix4; - ordering: Uint32Array; - activeSplats: number; - displayed?: boolean; - }) { - if (!this.display) { - // Hold a reference to the accumulator while part of display - accumulator.refCount += 1; - this.display = { - accumulator, - viewToWorld, - geometry: new SplatGeometry(ordering, activeSplats), - }; - } else { - if (!displayed && accumulator !== this.display.accumulator) { - // Hold a reference to the new accumulator being displayed - accumulator.refCount += 1; - // Release the reference to the previously displayed accumulator - this.spark.releaseAccumulator(this.display.accumulator); - this.display.accumulator = accumulator; - } - - this.display.viewToWorld = viewToWorld; - - const oldOrdering = this.display.geometry.ordering; - if (oldOrdering.length === ordering.length) { - this.display.geometry.update(ordering, activeSplats); - } else { - this.display.geometry.dispose(); - // console.log("*** alloc SplatGeometry", ordering.length); - this.display.geometry = new SplatGeometry(ordering, activeSplats); - } - this.orderingFreelist.free(oldOrdering); - } - if (this.spark.viewpoint === this) { - this.spark.prepareViewpoint(this); - } - } - - // If you need an empty THREE.Texture to use to initialize a uniform that is - // updated via onTextureUpdated(texture), this static texture can be handy. - static EMPTY_TEXTURE = new THREE.Texture(); - - private static dynos: { - dynoSortRadial: DynoBool; - dynoOrigin: DynoVec3; - dynoDirection: DynoVec3; - dynoDepthBias: DynoFloat; - dynoSort360: DynoBool; - dynoSplats: DynoPackedSplats; - reader: Readback; - doubleSortReader: DynoBlock<{ index: "int" }, { rgba8: "vec4" }>; - sort32Reader: DynoBlock<{ index: "int" }, { rgba8: "vec4" }>; - } | null = null; - - private static makeSorter() { - if (!SparkViewpoint.dynos) { - const dynoSortRadial = new DynoBool({ value: true }); - const dynoOrigin = new DynoVec3({ value: new THREE.Vector3() }); - const dynoDirection = new DynoVec3({ value: new THREE.Vector3() }); - const dynoDepthBias = new DynoFloat({ value: 1.0 }); - const dynoSort360 = new DynoBool({ value: false }); - const dynoSplats = new DynoPackedSplats(); - - const reader = new Readback(); - const doubleSortReader = dynoBlock( - { index: "int" }, - { rgba8: "vec4" }, - ({ index }) => { - if (!index) { - throw new Error("No index"); - } - const sortParams = { - sortRadial: dynoSortRadial, - sortOrigin: dynoOrigin, - sortDirection: dynoDirection, - sortDepthBias: dynoDepthBias, - sort360: dynoSort360, - }; - const index2 = mul(index, dynoConst("int", 2)); - - const gsplat0 = readPackedSplat(dynoSplats, index2); - const metric0 = computeSortMetric({ gsplat: gsplat0, ...sortParams }); - - const gsplat1 = readPackedSplat( - dynoSplats, - add(index2, dynoConst("int", 1)), - ); - const metric1 = computeSortMetric({ gsplat: gsplat1, ...sortParams }); - - const combined = combine({ - vectorType: "vec2", - x: metric0, - y: metric1, - }); - const rgba8 = uintToRgba8(packHalf2x16(combined)); - return { rgba8 }; - }, - ); - - const sort32Reader = dynoBlock( - { index: "int" }, - { rgba8: "vec4" }, - ({ index }) => { - if (!index) { - throw new Error("No index"); - } - const sortParams = { - sortRadial: dynoSortRadial, - sortOrigin: dynoOrigin, - sortDirection: dynoDirection, - sortDepthBias: dynoDepthBias, - sort360: dynoSort360, - }; - - const gsplat = readPackedSplat(dynoSplats, index); - const metric = computeSortMetric({ gsplat, ...sortParams }); - const rgba8 = uintToRgba8(floatBitsToUint(metric)); - return { rgba8 }; - }, - ); - - SparkViewpoint.dynos = { - dynoSortRadial, - dynoOrigin, - dynoDirection, - dynoDepthBias, - dynoSort360, - dynoSplats, - reader, - doubleSortReader, - sort32Reader, - }; - } - return SparkViewpoint.dynos; - } -} - -const defineComputeSortMetric = unindent(` - float computeSort(Gsplat gsplat, bool sortRadial, vec3 sortOrigin, vec3 sortDirection, float sortDepthBias, bool sort360) { - if (!isGsplatActive(gsplat.flags)) { - return INFINITY; - } - - vec3 center = gsplat.center - sortOrigin; - float biasedDepth = dot(center, sortDirection) + sortDepthBias; - if (!sort360 && (biasedDepth <= 0.0)) { - return INFINITY; - } - - return sortRadial ? length(center) : biasedDepth; - } -`); - -function computeSortMetric({ - gsplat, - sortRadial, - sortOrigin, - sortDirection, - sortDepthBias, - sort360, -}: { - gsplat: DynoVal; - sortRadial: DynoVal<"bool">; - sortOrigin: DynoVal<"vec3">; - sortDirection: DynoVal<"vec3">; - sortDepthBias: DynoVal<"float">; - sort360: DynoVal<"bool">; -}) { - return dyno({ - inTypes: { - gsplat: Gsplat, - sortRadial: "bool", - sortOrigin: "vec3", - sortDirection: "vec3", - sortDepthBias: "float", - sort360: "bool", - }, - outTypes: { metric: "float" }, - globals: () => [defineGsplat, defineComputeSortMetric], - inputs: { - gsplat, - sortRadial, - sortOrigin, - sortDirection, - sortDepthBias, - sort360, - }, - statements: ({ inputs, outputs }) => { - const { - gsplat, - sortRadial, - sortOrigin, - sortDirection, - sortDepthBias, - sort360, - } = inputs; - return unindentLines(` - ${outputs.metric} = computeSort(${gsplat}, ${sortRadial}, ${sortOrigin}, ${sortDirection}, ${sortDepthBias}, ${sort360}); - `); - }, - }).outputs.metric; -} diff --git a/src/Splat.ts b/src/Splat.ts new file mode 100644 index 0000000..ae7a76c --- /dev/null +++ b/src/Splat.ts @@ -0,0 +1,701 @@ +import * as THREE from "three"; +import { SplatGeometry } from "./SplatGeometry"; +import { + ReadbackSplatSorter, + type SplatOrdering, + type SplatSorter, +} from "./SplatSorter"; +import type { TransformRange } from "./defines"; +import { simpleRaycastMethod } from "./raycast"; +import { getShaders } from "./shaders"; +import { withinCoorientDist } from "./utils"; + +/** + * Global counter used to generate unique ordering IDs + */ +let globalOrderingId = 0; + +/** + * Interface providing the properties for the individual splats. + * The data is considered to be read-only for the purpose of this interface. + */ +export interface SplatData { + /** + * The maximum number of splats this SplatData can hold. + */ + readonly maxSplats: number; + /** + * The actual number of splats, must be less than maxSplats. + */ + readonly numSplats: number; + /** + * The number of spherical harmonic degrees for each splat. + */ + readonly numSh: number; + + /** + * Adjusts a given ShaderMaterial so it can read from this SplatData. + * This generally includes setting up the right uniforms and shader chunks. + * @param material The material to setup. + */ + setupMaterial: (material: THREE.ShaderMaterial) => void; + + /** + * Method for iterating over the raw splat centers. + * @param callback Callback to call with each center + */ + iterateCenters: ( + callback: (index: number, x: number, y: number, z: number) => void, + ) => void; + + /** + * Dispose any resources + */ + dispose(): void; +} + +export type SplatCallback = ( + i: number, + x: number, + y: number, + z: number, + scaleX: number, + scaleY: number, + scaleZ: number, + quatX: number, + quatY: number, + quatZ: number, + quatW: number, + opacity: number, + r: number, + g: number, + b: number, + sh?: ArrayLike, +) => void; + +/** + * Extended SplatData interface that allows the splat properties + * to be decoded and read back. + */ +export interface IterableSplatData extends SplatData { + iterateSplats: (callback: SplatCallback) => void; +} + +export type SortContext = { + lastOriginToCamera: THREE.Matrix4; + lastWorldTransform: THREE.Matrix4; + sortJob: Promise | null; + ordering: Uint32Array; + pendingOrdering: Uint32Array; + activeSplats: number; + orderingId: number; + splatVersion: number; +}; + +export interface SplatOptions { + sorter?: SplatSorter; + premultipliedAlpha?: boolean; +} + +/** + * Object representing a collection of Gaussian Splats in a scene. + */ +export class Splat extends THREE.Mesh { + /** + * The underlying splat data. + */ + readonly splatData: SplatData; + + /** + * Collection of (shared) uniforms for the splat shader. + * Additional uniforms might be provided by the SplatData for de-/encoding + * and by user-provided shader hooks. + */ + private readonly uniforms: ReturnType; + + /** + * Set of user-provided shader hooks to + */ + private shaderHooks: ShaderHooks | null = null; + + /** + * The sort implementation to use to sort the splats. + * Only used when stochastic flag is false. + */ + readonly sorter: SplatSorter; + /** + * Mapping from camera to sort context. + * This allows multiple viewpoints from different cameras. + */ + private readonly sortContext: WeakMap = + new WeakMap(); + /** + * Id of the current ordering used by the SplatGeometry. + */ + private currentOrderingId = -1; + + /** + * The current version of the splat. While splat data is intended to be + * static, changes through shader hooks or transform ranges can require + * sorting to performed again. + */ + private splatVersion = 0; + + /** + * The raycast method to use when raycasting against this splat object. + * Initial tests against bounding sphere and box take place, regardless of + * the chosen method. + */ + raycastMethod: + | (( + splat: Splat, + raycaster: THREE.Raycaster, + intersects: THREE.Intersection[], + ) => void) + | null = simpleRaycastMethod; + + /** + * Maximum standard deviations from the center to render Gaussians. + * Values `Math.sqrt(5)..Math.sqrt(9)` produce good results and can be tweaked for performance. + * @default Math.sqrt(8) + */ + maxStdDev: number = Math.sqrt(8); + /** + * Minimum pixel radius for splat rendering. + * @default 0.0 + */ + minPixelRadius = 0; + /** + * Maximum pixel radius for splat rendering. + * @default 512.0 + */ + maxPixelRadius = 512; + /** + * Minimum alpha value for splat rendering. + * @default 0.5 / 255.0 + */ + minAlpha: number = 0.5 / 255.0; + preBlurAmount = 0.0; + blurAmount = 0.3; + falloff = 1.0; + clipXY = 1.4; + focalAdjustment = 2.0; + + /** + * Maximum Spherical Harmonics level to use. Spark supports up to SH3. + * + * @default 3 + */ + maxSh = 3; + + /** + * Whether or not sorting should happen automatically. + * @default true + */ + autoSort = true; + /** + * Whether or not to use sort-free stochastic rendering. + * @default false + */ + stochastic = false; + enable2DGS = false; + + /** + * Distance threshold in world units for re-sorting splats. + * If the viewpoint moves more than this distance, splats will be re-sorted. + * @default 0.01 + */ + sortDistance = 0.01; + /** + * View direction dot product threshold for re-sorting splats. For `sortRadial: true` + * it defaults to 0.99 while `sortRadial: false` uses 0.999 because it is more + * sensitive to view direction. + * @default 0.99 if sortRadial else 0.999 + */ + sortCoorient = 0.999; // FIXME: Depend on sortRadial :-/ + + constructor(splatData: SplatData, options: SplatOptions = {}) { + const uniforms = Splat.makeUniforms(); + const shaders = getShaders(); + const premultipliedAlpha = options.premultipliedAlpha ?? true; + const material = new THREE.ShaderMaterial({ + name: "SplatShader", + glslVersion: THREE.GLSL3, + vertexShader: shaders.splatVertex, + fragmentShader: shaders.splatFragment, + uniforms, + premultipliedAlpha, + transparent: true, + depthTest: true, + depthWrite: false, + side: THREE.DoubleSide, + defines: { + STOCHASTIC: false, + SPLAT_DECODE_FN: "", + SPLAT_SH_DECODE_FN: "", + NUM_SH: 0, + }, + }); + + super(new SplatGeometry(), material); + + // Use a high render order to ensure being rendered at the end of the transparent queue. + this.renderOrder = 9999; + this.frustumCulled = false; + + this.uniforms = uniforms; + + this.sorter = options.sorter ?? new ReadbackSplatSorter(); + + this.splatData = splatData; + this.splatData.setupMaterial(material); + + this.geometry.updateBounds(this.splatData); + } + + onBeforeRender( + renderer: THREE.WebGLRenderer, + scene: THREE.Scene, + camera: THREE.Camera, + ) { + // Keep track of the camera to use for sorting. + // Generally this is the same camera as used for rendering, though during + // WebXR sessions this will be the XRCamera instead. + let sortCamera = camera; + + // During immersive WebXR sessions this method can be called multiple times. + // Only act on the first one. + if (renderer.xr.isPresenting) { + const xrCamera = renderer.xr.getCamera(); + const cameraIndex = xrCamera.cameras.indexOf(camera as THREE.WebXRCamera); + if (cameraIndex === 0) { + // First camera, use the main xrCamera for sorting. + sortCamera = xrCamera; + } else if (cameraIndex > 0) { + // This is not the first camera (index 0) nor a different camera (index -1). + // Material should already be prepared and sorting kicked off, nothing to do. + return; + } + } + + const currentRenderTarget = renderer.getRenderTarget(); + if (currentRenderTarget) { + // Rendering to a texture target, so its dimensions + this.uniforms.renderSize.value.set( + currentRenderTarget.width, + currentRenderTarget.height, + ); + } else { + // Rendering to the canvas or WebXR + const renderSize = renderer.getDrawingBufferSize( + this.uniforms.renderSize.value, + ); + if (renderSize.x === 1 && renderSize.y === 1) { + // WebXR mode on Apple Vision Pro returns 1x1 when presenting. + // Use a different means to figure out the render size. + const baseLayer = renderer.xr.getSession()?.renderState.baseLayer; + if (baseLayer) { + renderSize.x = baseLayer.framebufferWidth; + renderSize.y = baseLayer.framebufferHeight; + } + } + } + + // Check for stochastic rendering + if (this.material.defines.STOCHASTIC !== this.stochastic) { + this.material.defines.STOCHASTIC = this.stochastic; + this.material.transparent = !this.stochastic; + this.material.depthWrite = this.stochastic; + this.material.needsUpdate = true; + } + + // Update the number of SH to evaluate + const numSh = Math.min(this.maxSh, this.splatData.numSh); + if (this.material.defines.NUM_SH !== numSh) { + this.material.defines.NUM_SH = numSh; + this.material.needsUpdate = true; + } + + // Update uniforms + this.uniforms.numSplats.value = this.splatData.numSplats; + this.uniforms.maxStdDev.value = this.maxStdDev; + this.uniforms.minPixelRadius.value = this.minPixelRadius; + this.uniforms.maxPixelRadius.value = this.maxPixelRadius; + this.uniforms.minAlpha.value = this.minAlpha; + this.uniforms.time.value = performance.now() / 1000; + this.uniforms.enable2DGS.value = this.enable2DGS; + this.uniforms.preBlurAmount.value = this.preBlurAmount; + this.uniforms.blurAmount.value = this.blurAmount; + this.uniforms.falloff.value = this.falloff; + this.uniforms.clipXY.value = this.clipXY; + this.uniforms.focalAdjustment.value = this.focalAdjustment; + this.uniforms.opacity.value = this.material.opacity; + const outputColorSpace = + currentRenderTarget === null + ? renderer.outputColorSpace + : "isXRRenderTarget" in currentRenderTarget && + currentRenderTarget.isXRRenderTarget === true + ? currentRenderTarget.texture.colorSpace + : THREE.LinearSRGBColorSpace; + this.uniforms.encodeLinear.value = + outputColorSpace !== THREE.SRGBColorSpace; + + // Perform sorting if needed + if (this.autoSort && !this.stochastic) { + this.sortFor(renderer, sortCamera, false); + } + + // Ensure geometry uses the correct order + if (this.stochastic) { + // Ordering does not apply for stochastic and is camera independent + this.geometry.instanceCount = this.splatData.numSplats; + } else { + // Fetch sorting context for this camera + const context = this.sortContext.get(sortCamera); + if (context && this.currentOrderingId !== context.orderingId) { + this.currentOrderingId = context.orderingId; + this.geometry.update(renderer, context.ordering, context.activeSplats); + } + } + } + + async sortFor( + renderer: THREE.WebGLRenderer, + camera: THREE.Camera, + updateOrdering = true, + ) { + // Calculate the transform from the accumulator to the current camera + const worldToCamera = camera.matrixWorld.clone().invert(); + const originToCamera = this.matrixWorld.clone().premultiply(worldToCamera); + + // Check if sorting is needed + let context = this.sortContext.get(camera); + let needsSort = false; + if (!context) { + context = { + lastOriginToCamera: new THREE.Matrix4(), + lastWorldTransform: new THREE.Matrix4(), + sortJob: null, + ordering: new Uint32Array(this.splatData.maxSplats), + pendingOrdering: new Uint32Array(this.splatData.maxSplats), + activeSplats: 0, + orderingId: globalOrderingId++, + splatVersion: this.splatVersion, + }; + needsSort = true; + this.sortContext.set(camera, context); + } + + // Check if the underlying splat version matches + if (context.splatVersion !== this.splatVersion) { + needsSort = true; + } + + // Check if the camera moved, requiring a new sort + if ( + !needsSort && + !withinCoorientDist({ + matrix1: originToCamera, + matrix2: context.lastOriginToCamera, + // By default update sort each 1 cm + maxDistance: this.sortDistance ?? 0.01, + // By default for radial sort, update for intermittent movement so that + // we bring back splats culled by being behind the camera. + // For depth sort, small rotations can change sort order a lot, so + // update sort for even small rotations. + minCoorient: 0.999, // FIXME + }) + ) { + needsSort = true; + } + + if (!this.matrixWorld.equals(context.lastWorldTransform)) { + needsSort = true; + } + + // Prepare next sort when needed + if (needsSort && !context.sortJob) { + context.lastOriginToCamera.copy(originToCamera); + context.lastWorldTransform.copy(this.matrixWorld); + context.splatVersion = this.splatVersion; + context.sortJob = this.sorter.sort( + camera, + this, + renderer, + context.pendingOrdering, + ); + context.sortJob.then((result) => this.onSortComplete(context, result)); + } + + if (context.sortJob) { + await context.sortJob; + } + + if (updateOrdering) { + this.currentOrderingId = context.orderingId; + this.geometry.update(renderer, context.ordering, context.activeSplats); + } + } + + protected onSortComplete(context: SortContext, result: SplatOrdering) { + context.sortJob = null; + // Swap ordering arrays + context.pendingOrdering = context.ordering; + context.ordering = result.ordering; + + context.activeSplats = result.activeSplats; + context.orderingId = globalOrderingId++; + } + + /** + * Returns an array of splat ranges with their corresponding (world) transform. + * This allows rigid transforms to apply to subsets of the splats. + */ + getTransformRanges(): Array { + return [ + { + start: 0, + end: this.splatData.numSplats, + matrix: this.matrixWorld.toArray(), + }, + ]; + } + + setShaderHooks(hooks: ShaderHooks | null): ShaderHooks | null { + const previousShaderHooks = this.shaderHooks; + this.shaderHooks = hooks; + + // Add additional uniforms + if (hooks?.vertex?.uniforms) { + for (const uniform in hooks.vertex.uniforms) { + this.material.uniforms[uniform] = hooks.vertex.uniforms[uniform]; + } + } + if (hooks?.fragment?.uniforms) { + for (const uniform in hooks.fragment.uniforms) { + this.material.uniforms[uniform] = hooks.fragment.uniforms[uniform]; + } + } + + // Prepare compile hook + this.material.onBeforeCompile = (program, renderer) => { + if (!program.defines) { + program.defines = {}; + } + + if (this.shaderHooks?.vertex) { + program.defines.HOOK_UNIFORMS = !!this.shaderHooks.vertex.uniforms; + if (this.shaderHooks.vertex.uniforms) { + // Generate uniform code block + const uniforms = Object.entries(this.shaderHooks.vertex.uniforms) + .map( + (entry) => + `uniform ${entry[1].type} ${entry[0]}${Array.isArray(entry[1].value) ? `[${entry[1].value.length}]` : ""};`, + ) + .join("\n"); + program.vertexShader = program.vertexShader.replace( + "{{HOOK_UNIFORMS}}", + uniforms, + ); + } + program.defines.HOOK_GLOBAL = !!this.shaderHooks.vertex.global; + program.vertexShader = program.vertexShader.replace( + "{{HOOK_GLOBAL}}", + this.shaderHooks.vertex.global ?? "", + ); + program.defines.HOOK_OBJECT_MODIFIER = + !!this.shaderHooks.vertex.objectModifier; + program.vertexShader = program.vertexShader.replace( + "{{HOOK_OBJECT_MODIFIER}}", + this.shaderHooks.vertex.objectModifier ?? "", + ); + program.defines.HOOK_WORLD_MODIFIER = + !!this.shaderHooks.vertex.worldModifier; + program.vertexShader = program.vertexShader.replace( + "{{HOOK_WORLD_MODIFIER}}", + this.shaderHooks.vertex.worldModifier ?? "", + ); + program.defines.HOOK_SPLAT_COLOR = !!this.shaderHooks.vertex.splatColor; + program.vertexShader = program.vertexShader.replace( + "{{HOOK_SPLAT_COLOR}}", + this.shaderHooks.vertex.splatColor ?? "", + ); + } + + if (this.shaderHooks?.fragment) { + if (this.shaderHooks.fragment.uniforms) { + // Generate uniform code block + const uniforms = Object.entries(this.shaderHooks.fragment.uniforms) + .map( + (entry) => + `uniform ${entry[1].type} ${entry[0]}${Array.isArray(entry[1].value) ? `[${entry[1].value.length}]` : ""};`, + ) + .join("\n"); + program.vertexShader = program.vertexShader.replace( + "#define HOOK_UNIFORMS", + uniforms, + ); + } + } + + if (this.shaderHooks?.onBeforeCompile) { + this.shaderHooks.onBeforeCompile(program, renderer); + } + }; + + // Material is specific to instance. + // FIXME: Maybe hash the shader hooks struct? + this.material.customProgramCacheKey = () => this.uuid; + + // Make sure the material recompiles + this.material.needsUpdate = true; + + return previousShaderHooks; + } + + get opacity(): number { + return this.material.opacity; + } + + set opacity(value: number) { + this.material.opacity = value; + } + + dispose() { + this.geometry.dispose(); + this.material.dispose(); + this.splatData.dispose(); + } + + set needsUpdate(value: boolean) { + if (value === true) this.splatVersion++; + } + + // NOTE: Override _computeIntersections to allow base implementation of THREE.Mesh to check + // against bounding sphere and bounding box. + _computeIntersections( + raycaster: THREE.Raycaster, + intersects: THREE.Intersection[], + ): void { + if (this.raycastMethod) { + this.raycastMethod(this, raycaster, intersects); + } + } + + clone() { + const ctor = this.constructor as new (splatData: SplatData) => this; + const clone = new ctor(this.splatData).copy(this); + return clone; + } + + copy(source: this, recursive?: boolean): this { + // Avoid copying the material and geometry as these are unique to the Splat + const material = this.material; + const geometry = this.geometry; + + super.copy(source, recursive); + // TODO: Copy over Splat specific properties + + this.material = material; + this.geometry = geometry; + return this; + } + + static makeUniforms() { + // Create uniforms used for Gsplat vertex and fragment shaders + const uniforms = { + // Opacity of the splat + opacity: { value: 1.0 }, + + // Total number of splats, active or not + numSplats: { value: 0 }, + + // Maximum distance (in stddevs) from Gsplat center to render + maxStdDev: { value: 1.0 }, + // Minimum pixel radius for splat rendering + minPixelRadius: { value: 0.0 }, + // Maximum pixel radius for splat rendering + maxPixelRadius: { value: 512.0 }, + // Minimum alpha value for splat rendering + minAlpha: { value: 0.5 * (1.0 / 255.0) }, + // Enable stochastic splat rendering + stochastic: { value: false }, + // Enable interpreting 0-thickness Gsplats as 2DGS + enable2DGS: { value: false }, + // Add to projected 2D splat covariance diagonal (thickens and brightens) + preBlurAmount: { value: 0.0 }, + // Add to 2D splat covariance diagonal and adjust opacity (anti-aliasing) + blurAmount: { value: 0.3 }, + + // Modulate Gaussian kernel falloff. 0 means "no falloff, flat shading", + // 1 is normal e^-x^2 falloff. + falloff: { value: 1.0 }, + // Clip Gsplats that are clipXY times beyond the +-1 frustum bounds + clipXY: { value: 1.4 }, + // Size of render viewport in pixels + renderSize: { value: new THREE.Vector2() }, + // Debug renderSize scale factor + focalAdjustment: { value: 1.0 }, + + // Time in seconds for time-based effects + time: { value: 0 }, + // Whether to encode Gsplat with linear RGB (for environment mapping) + encodeLinear: { value: false }, + }; + return uniforms; + } +} + +/** + * Shader hooks for customizing the shader used for rendering splats. + * This allows modifying the splats in object space, in world space + * as well as adjusting the splats color, opacity and shading. + */ +export type ShaderHooks = { + /** + * Hooks for the vertex shader. + */ + vertex?: { + /** + * Additional uniforms to add to the vertex shader. + */ + uniforms?: { [key: string]: THREE.IUniform & { type: string } }; + /** + * Shader chunk to include at the start of the vertex shader. + * This can be used to define additional methods and constant + * that can be used in the other hooks. + */ + global?: string; + /** + * Shader chunk for adjusting the splat in object space. + */ + objectModifier?: string; + /** + * Shader chunk for adjusting the splat in world space. + */ + worldModifier?: string; + /** + * Shader chunk for changing the color of the splat. + */ + splatColor?: string; + }; + /** + * Hooks for the fragment shader. + */ + fragment?: { + /** + * Additional uniforms to add to the fragment shader. + */ + uniforms?: { [key: string]: THREE.IUniform & { type: string } }; + /** + * Shader chunk to include at the start of the fragment shader. + * This can be used to define additional methods and constant + * that can be used in the other hooks. + */ + global?: string; + }; + /** + * Custom onBeforeCompile allowing the full shader code to be adjusted. + */ + onBeforeCompile?: typeof THREE.Material.prototype.onBeforeCompile; +}; diff --git a/src/SplatAccumulator.ts b/src/SplatAccumulator.ts deleted file mode 100644 index 1a0bd66..0000000 --- a/src/SplatAccumulator.ts +++ /dev/null @@ -1,120 +0,0 @@ -import * as THREE from "three"; - -import { PackedSplats } from "./PackedSplats"; -import type { - GsplatGenerator, - SplatGenerator, - SplatModifier, -} from "./SplatGenerator"; - -// SplatAccumulator helps manage the generation of splats from multiple -// SplatGenerators, keeping track of the splat mapping, coordinate system, -// and reference count. - -// A GeneratorMapping describes a Gsplat range that was generated, including -// which generator and its version number. -export type GeneratorMapping = { - node: SplatGenerator; - generator?: GsplatGenerator; - version: number; - base: number; - count: number; -}; - -export class SplatAccumulator { - splats = new PackedSplats(); - // The transform from Accumulator coordinate system to world coordinates. - toWorld = new THREE.Matrix4(); - // An array of all Gsplat mappings that were used for generation - mapping: GeneratorMapping[] = []; - // Number of SparkViewpoints (or other) that reference this accumulator, used - // to figure out when it can be recycled for use - refCount = 0; - - // Incremented every time the splats are updated/generated. - splatsVersion = -1; - // Incremented every time the splat mapping/layout is updated. - // Splat sort order can be reused between equivalent mapping versions. - mappingVersion = -1; - - ensureGenerate(maxSplats: number) { - if (this.splats.ensureGenerate(maxSplats)) { - // If we had to resize our PackedSplats then clear all previous mappings - this.mapping = []; - } - } - - // Generate all Gsplats from an array of generators - generateSplats({ - renderer, - modifier, - generators, - forceUpdate, - originToWorld, - }: { - renderer: THREE.WebGLRenderer; - modifier: SplatModifier; - generators: GeneratorMapping[]; - forceUpdate?: boolean; - originToWorld: THREE.Matrix4; - }) { - // Create a lookup from last SplatGenerator - const mapping = this.mapping.reduce((map, record) => { - map.set(record.node, record); - return map; - }, new Map()); - - // Run generators that are different from existing mapping - let updated = 0; - let numSplats = 0; - for (const { node, generator, version, base, count } of generators) { - const current = mapping.get(node); - if ( - forceUpdate || - generator !== current?.generator || - version !== current?.version || - base !== current?.base || - count !== current?.count - ) { - // Something is different from before so we should generate these Gsplats - if (generator && count > 0) { - const modGenerator = modifier.apply(generator); - try { - this.splats.generate({ - generator: modGenerator, - base, - count, - renderer, - }); - } catch (error) { - node.generator = undefined; - node.generatorError = error; - } - updated += 1; - } - } - numSplats = Math.max(numSplats, base + count); - } - - this.splats.numSplats = numSplats; - this.toWorld.copy(originToWorld); - this.mapping = generators; - return updated !== 0; - } - - // Check if this accumulator has exactly the same generator mapping as - // the previous one. If so, we can reuse the Gsplat sort order. - hasCorrespondence(other: SplatAccumulator) { - if (this.mapping.length !== other.mapping.length) { - return false; - } - return this.mapping.every(({ node, base, count }, i) => { - const { - node: otherNode, - base: otherBase, - count: otherCount, - } = other.mapping[i]; - return node === otherNode && base === otherBase && count === otherCount; - }); - } -} diff --git a/src/SplatEdit.ts b/src/SplatEdit.ts deleted file mode 100644 index 4591de8..0000000 --- a/src/SplatEdit.ts +++ /dev/null @@ -1,829 +0,0 @@ -import * as THREE from "three"; - -import { - Dyno, - DynoInt, - DynoUniform, - type DynoVal, - Gsplat, - unindent, - unindentLines, -} from "./dyno"; -import { newArray } from "./utils"; - -// Spark provides the ability to apply "edits" to Gsplats as part of the standard -// SplatMesh pipeline. These edits take the form of a sequence of operations, -// applied one at a time to the set of Gsplats in its packedSplats. Each operation -// evaluates a 7-dimensional field (RGBA and XYZ displacement) at each point in -// space that derives from N=1 or more Signed Distance Field shapes (such as spheres, -// boxes, planes, etc.), blended together and across inside-outisde boundaries. - -// The result is a an RGBA,XYZ value for each point in space, which combined with -// SplatEditRgbaBlendMode.MULTIPLY/SET_RGB/ADD_RGBA can be used to create special -// effects, for example simulating simple lighting or applying deformations in space, -// whose parameters can be updated each frame to create animated effects. - -// RGBA-XYZ values are computed by blending together values from all SDF shapes using -// the exponential "softmax" function, which is commutative (so blending order within -// a SplatEdit operation doesn't matter). The parameter SplatEdit.sdfSmooth controls -// the blending scale between SDF shapes, while SplatEdit.softEdge controls the scale -// of soft inside-outside shape edit blending. Their default values start at 0.0 and -// should be increased to soften the effect. - -// Note that XYZ displacement values are blended in the same way as RGBA, with a -// resulting displacement field that can be quite complex but "softly" blending -// between shapes. These RGBA-XYZ edits, along with time-based and overlapping -// fields can create many interesting animations and special effects, such as -// rippling leaves in the wind, an angry fire, or a looping water effects. Simply -// update the SplatEdit and SplatEditSdf objects and the operations will be applied -// immediately to the Gsplats in the scene. - -export enum SplatEditSdfType { - // ALL: Affects all points in space - ALL = "all", - // PLANE: Infinite plane (position, rotation) - PLANE = "plane", - // SPHERE: Sphere (position, radius) - SPHERE = "sphere", - // BOX: Rounded box (position, rotation, sizes, radius) - BOX = "box", - // ELLIPSOID: Ellipsoid (position, rotation, sizes) - ELLIPSOID = "ellipsoid", - // CYLINDER: Cylinder (position, rotation, radius, size_y) - CYLINDER = "cylinder", - // CAPSULE: Capsule (position, rotation, radius, size_y) - CAPSULE = "capsule", - // INFINITE_CONE: Infinite cone (position, rotation, radius=angle) - INFINITE_CONE = "infinite_cone", -} - -function sdfTypeToNumber(type: SplatEditSdfType) { - switch (type) { - case SplatEditSdfType.ALL: - return 0; - case SplatEditSdfType.PLANE: - return 1; - case SplatEditSdfType.SPHERE: - return 2; - case SplatEditSdfType.BOX: - return 3; - case SplatEditSdfType.ELLIPSOID: - return 4; - case SplatEditSdfType.CYLINDER: - return 5; - case SplatEditSdfType.CAPSULE: - return 6; - case SplatEditSdfType.INFINITE_CONE: - return 7; - default: - throw new Error(`Unknown SDF type: ${type}`); - } -} - -export enum SplatEditRgbaBlendMode { - // The RGBA of the splat is multiplied component-wise by the SDF’s - // RGBA value at that point in space. - MULTIPLY = "multiply", - // Ignore the Alpha value in the SDF, but set the splat’s RGB to - // equal the SDF’s RGB value at that point. - SET_RGB = "set_rgb", - // Add the SDF’s RGBA value at that point to the RGBA value of - // the Gsplat. This can produce hyper-saturated results, but is useful - // to easily “light up” areas. - ADD_RGBA = "add_rgba", -} - -function rgbaBlendModeToNumber(mode: SplatEditRgbaBlendMode) { - switch (mode) { - case SplatEditRgbaBlendMode.MULTIPLY: - return 0; - case SplatEditRgbaBlendMode.SET_RGB: - return 1; - case SplatEditRgbaBlendMode.ADD_RGBA: - return 2; - default: - throw new Error(`Unknown blend mode: ${mode}`); - } -} - -export type SplatEditSdfOptions = { - // The SDF shape type: ALL, PLANE, SPHERE, BOX, ELLIPSOID, CYLINDER, CAPSULE, - // or INFINITE_CONE. (default: SplatEditSdfType.SPHERE) - type?: SplatEditSdfType; - // Invert the SDF evaluation, swapping inside and outside regions. (default: false) - invert?: boolean; - // Opacity / "alpha" value used differently by blending modes (default: 1.0) - opacity?: number; - // RGB color applied within the shape. (default: new THREE.Color(1.0, 1.0, 1.0)) - color?: THREE.Color; - // XYZ displacement applied to splat positions inside the shape. - // (default: new THREE.Vector3(0.0, 0.0, 0.0)) - displace?: THREE.Vector3; - // Shape-specific size parameter: sphere radius, box corner rounding, - // cylinder/capsule radius, or for the infinite cone the angle factor - // (opening half-angle = π/4 × radius). - radius?: number; -}; - -export class SplatEditSdf extends THREE.Object3D { - type: SplatEditSdfType; - invert: boolean; - opacity: number; - color: THREE.Color; - displace: THREE.Vector3; - radius: number; - - constructor(options: SplatEditSdfOptions = {}) { - super(); - const { type, invert, opacity, color, displace, radius } = options; - this.type = type ?? SplatEditSdfType.SPHERE; - this.invert = invert ?? false; - this.opacity = opacity ?? 1.0; - this.color = color ?? new THREE.Color(1.0, 1.0, 1.0); - this.displace = displace ?? new THREE.Vector3(0.0, 0.0, 0.0); - this.radius = radius ?? 0.0; - } -} - -export type SplatEditOptions = { - // Name of this edit operation. If you omit it, a default "Edit 1", "Edit 2", ... - // is assigned. - name?: string; - // How the SDF’s RGBA modifies each splat’s RGBA: multiply, overwrite RGB, - // or add RGBA. (default: MULTIPLY) - rgbaBlendMode?: SplatEditRgbaBlendMode; - // Smoothing (in world‐space units) for blending between multiple SDF shapes - // at their boundaries. (default: 0.0) - sdfSmooth?: number; - // Soft‐edge falloff radius (in world‐space units) around each SDF shape’s surface. - // (default: 0.0) - softEdge?: number; - // Invert the SDF evaluation (inside/outside swap). (default: false) - invert?: boolean; - // Explicit array of SplatEditSdf objects to include. If null, any child - // SplatEditSdf instances are used. - sdfs?: SplatEditSdf[]; -}; - -export class SplatEdit extends THREE.Object3D { - // ordering used to apply SplatEdit operations to Gsplats. This is implicitly - // increased with each new SplatEdit. Reassigning ordering can be used to - // reorder the operations. - ordering: number; - rgbaBlendMode: SplatEditRgbaBlendMode; - sdfSmooth: number; - softEdge: number; - invert: boolean; - - // Optional list of explicit SDFs to including in this edit. If it is null, then - // any SplatEditSdf children in the scene graph will be added automatically. - sdfs: SplatEditSdf[] | null; - - // The next ordering number to use for a new SplatEdit, auto-incremented - static nextOrdering = 1; - - constructor(options: SplatEditOptions = {}) { - const { - name, - rgbaBlendMode = SplatEditRgbaBlendMode.MULTIPLY, - sdfSmooth = 0.0, - softEdge = 0.0, - invert = false, - sdfs = null, - } = options; - - super(); - this.rgbaBlendMode = rgbaBlendMode; - this.sdfSmooth = sdfSmooth; - this.softEdge = softEdge; - this.invert = invert; - this.sdfs = sdfs; - // Assign and auto-increment unique ordering number for this edit - this.ordering = SplatEdit.nextOrdering++; - // Automatically assign a default name if not provided - this.name = name ?? `Edit ${this.ordering}`; - } - - addSdf(sdf: SplatEditSdf) { - if (this.sdfs == null) { - this.sdfs = []; - } - if (!this.sdfs.includes(sdf)) { - this.sdfs.push(sdf); - } - } - - removeSdf(sdf: SplatEditSdf) { - if (this.sdfs == null) { - return; - } - this.sdfs = this.sdfs.filter((s) => s !== sdf); - } -} - -// Dyno implementation of RGBA-XYZ SDF editing. -// The SDFs are encoded in a texture while the edits are encoded -// as a uniform uvec4 array. - -export class SplatEdits { - // Maximum number of SDFs allocated - maxSdfs: number; - // Number of SDFs currently in use - numSdfs: number; - // Encoded SDF data - sdfData: Uint32Array; - // Float interpretation of SDF data - sdfFloatData: Float32Array; - // Texture with encoded SDF data - sdfTexture: THREE.DataTexture; - // An SdfArray dyno uniform - dynoSdfArray: DynoUniform; - - // Maximum number of edits allocated - maxEdits: number; - // Number of edits currently in use - numEdits: number; - // Encoded edit data - editData: Uint32Array; - // Float interpretation of edit data - editFloatData: Float32Array; - // A dyno uniform for the number of edits - dynoNumEdits: DynoUniform<"int", "numEdits">; - // A dyno uniform for the encoded edits, one uvec4 per edit - dynoEdits: DynoUniform<"uvec4", "edits">; - - constructor({ maxSdfs, maxEdits }: { maxSdfs?: number; maxEdits?: number }) { - // Allocate at least 16 SDFs for efficiency - this.maxSdfs = Math.max(16, maxSdfs ?? 0); - this.numSdfs = 0; - - // Allocate space: 8 x (u)vec4 values per SDF, Uint32 and Float32 arrays - this.sdfData = new Uint32Array(this.maxSdfs * 8 * 4); - this.sdfFloatData = new Float32Array(this.sdfData.buffer); - this.sdfTexture = this.newSdfTexture(this.sdfData, this.maxSdfs); - this.dynoSdfArray = new DynoUniform({ - key: "sdfArray", - type: SdfArray, - globals: () => [defineSdfArray], - value: { - numSdfs: 0, - sdfTexture: this.sdfTexture, - }, - update: (uniform) => { - uniform.numSdfs = this.numSdfs; - uniform.sdfTexture = this.sdfTexture; - return uniform; - }, - }); - - // Allocate at least 16 edits slots for efficiency - this.maxEdits = Math.max(16, maxEdits ?? 0); - this.numEdits = 0; - // Allocate space: 1 uvec4 per edit - this.editData = new Uint32Array(this.maxEdits * 4); - this.editFloatData = new Float32Array(this.editData.buffer); - this.dynoNumEdits = new DynoInt({ value: 0 }); - this.dynoEdits = this.newEdits(this.editData, this.maxEdits); - } - - private newSdfTexture(data: Uint32Array, maxSdfs: number) { - const texture = new THREE.DataTexture( - data, - 8, - maxSdfs, - THREE.RGBAIntegerFormat, - THREE.UnsignedIntType, - ); - texture.internalFormat = "RGBA32UI"; - texture.needsUpdate = true; - return texture; - } - - private newEdits(data: Uint32Array, maxEdits: number) { - return new DynoUniform({ - key: "edits", - type: "uvec4", - count: maxEdits, - globals: () => [defineEdit], - value: data, - }); - } - - // Ensure our SDF texture and edits uniform array have enough capacity. - // Reallocate if not. - private ensureCapacity({ - maxSdfs, - maxEdits, - }: { maxSdfs: number; maxEdits: number }): boolean { - let dynoUpdated = false; - if (maxSdfs > this.sdfTexture.image.height) { - this.sdfTexture.dispose(); - // At least double the size to avoid frequent reallocations - this.maxSdfs = Math.max(this.maxSdfs * 2, maxSdfs); - this.sdfData = new Uint32Array(this.maxSdfs * 8 * 4); - this.sdfFloatData = new Float32Array(this.sdfData.buffer); - this.sdfTexture = this.newSdfTexture(this.sdfData, this.maxSdfs); - } - if (maxEdits > (this.dynoEdits.count ?? 0)) { - // At least double the size to avoid frequent reallocations - this.maxEdits = Math.max(this.maxEdits * 2, maxEdits); - this.editData = new Uint32Array(this.maxEdits * 4); - this.editFloatData = new Float32Array(this.editData.buffer); - this.dynoEdits = this.newEdits(this.editData, this.maxEdits); - dynoUpdated = true; - } - return dynoUpdated; - } - - private updateEditData(offset: number, value: number): boolean { - // Update an edit uint32 value and return true if it changed - const updated = this.editData[offset] !== value; - this.editData[offset] = value; - return updated; - } - - private updateEditFloatData(offset: number, value: number): boolean { - // Update an edit float32 value and return true if it changed - tempFloat32[0] = value; - const updated = this.editFloatData[offset] !== tempFloat32[0]; - if (updated) { - this.editFloatData[offset] = tempFloat32[0]; - } - return updated; - } - - private encodeEdit( - editIndex: number, - { - sdfFirst, - sdfCount, - invert, - rgbaBlendMode, - softEdge, - sdfSmooth, - }: { - sdfFirst: number; - sdfCount: number; - invert: boolean; - rgbaBlendMode: number; - softEdge: number; - sdfSmooth: number; - }, - ): boolean { - const base = editIndex * 4; - let updated = false; - // Encode the edit fields into the editData array and check if any changed - updated = - this.updateEditData(base + 0, rgbaBlendMode | (invert ? 1 << 8 : 0)) || - updated; - updated = - this.updateEditData(base + 1, sdfFirst | (sdfCount << 16)) || updated; - updated = this.updateEditFloatData(base + 2, softEdge) || updated; - updated = this.updateEditFloatData(base + 3, sdfSmooth) || updated; - return updated; - } - - private updateSdfData(offset: number, value: number): boolean { - // Update an SDF uint32 value and return true if it changed - const updated = this.sdfData[offset] !== value; - this.sdfData[offset] = value; - return updated; - } - - private updateSdfFloatData(offset: number, value: number): boolean { - // Update an SDF float32 value and return true if it changed - tempFloat32[0] = value; - const updated = this.sdfFloatData[offset] !== tempFloat32[0]; - if (updated) { - this.sdfFloatData[offset] = tempFloat32[0]; - } - return updated; - } - - private encodeSdf( - sdfIndex: number, - { - sdfType, - invert, - center, - quaternion, - scale, - sizes, - }: { - sdfType: number; - invert?: boolean; - center?: THREE.Vector3; - quaternion?: THREE.Quaternion; - scale?: THREE.Vector3; - sizes?: THREE.Vector4; - }, - values: THREE.Vector4[], - ): boolean { - // Encode the SDF fields into the sdfData array and check if any changed - const base = sdfIndex * (8 * 4); - const flags = sdfType | (invert ? 1 << 8 : 0); - let updated = false; - - updated = this.updateSdfFloatData(base + 0, center?.x ?? 0) || updated; - updated = this.updateSdfFloatData(base + 1, center?.y ?? 0) || updated; - updated = this.updateSdfFloatData(base + 2, center?.z ?? 0) || updated; - updated = this.updateSdfData(base + 3, flags) || updated; - - updated = this.updateSdfFloatData(base + 4, quaternion?.x ?? 0) || updated; - updated = this.updateSdfFloatData(base + 5, quaternion?.y ?? 0) || updated; - updated = this.updateSdfFloatData(base + 6, quaternion?.z ?? 0) || updated; - updated = this.updateSdfFloatData(base + 7, quaternion?.w ?? 0) || updated; - - updated = this.updateSdfFloatData(base + 8, scale?.x ?? 0) || updated; - updated = this.updateSdfFloatData(base + 9, scale?.y ?? 0) || updated; - updated = this.updateSdfFloatData(base + 10, scale?.z ?? 0) || updated; - updated = this.updateSdfData(base + 11, 0) || updated; - - updated = this.updateSdfFloatData(base + 12, sizes?.x ?? 0) || updated; - updated = this.updateSdfFloatData(base + 13, sizes?.y ?? 0) || updated; - updated = this.updateSdfFloatData(base + 14, sizes?.z ?? 0) || updated; - updated = this.updateSdfFloatData(base + 15, sizes?.w ?? 0) || updated; - - const nValues = Math.min(4, values.length); - for (let i = 0; i < nValues; ++i) { - const vBase = base + 16 + i * 4; - updated = this.updateSdfFloatData(vBase + 0, values[i].x) || updated; - updated = this.updateSdfFloatData(vBase + 1, values[i].y) || updated; - updated = this.updateSdfFloatData(vBase + 2, values[i].z) || updated; - updated = this.updateSdfFloatData(vBase + 3, values[i].w) || updated; - } - return updated; - } - - // Update the SDFs and edits from an array of SplatEdits and their - // associated SplatEditSdfs, updating it for the dyno shader program. - update(edits: { edit: SplatEdit; sdfs: SplatEditSdf[] }[]): { - updated: boolean; - dynoUpdated: boolean; - } { - const sdfCount = edits.reduce((total, { sdfs }) => total + sdfs.length, 0); - const dynoUpdated = this.ensureCapacity({ - maxEdits: edits.length, - maxSdfs: sdfCount, - }); - - const values = [new THREE.Vector4(), new THREE.Vector4()]; - const center = new THREE.Vector3(); - const quaternion = new THREE.Quaternion(); - const scale = new THREE.Vector3(); - const sizes = new THREE.Vector4(); - - let sdfIndex = 0; - let updated = dynoUpdated; - - if (edits.length !== this.dynoNumEdits.value) { - this.dynoNumEdits.value = edits.length; - this.numEdits = edits.length; - updated = true; - } - - for (const [editIndex, { edit, sdfs }] of edits.entries()) { - updated = - this.encodeEdit(editIndex, { - sdfFirst: sdfIndex, - sdfCount: sdfs.length, - invert: edit.invert, - rgbaBlendMode: rgbaBlendModeToNumber(edit.rgbaBlendMode), - softEdge: edit.softEdge, - sdfSmooth: edit.sdfSmooth, - }) || updated; - - let sdfUpdated = false; - for (const sdf of sdfs) { - sizes.set(sdf.scale.x, sdf.scale.y, sdf.scale.z, sdf.radius); - // Temporarily set the SDF scale to 1.0 to get the world-to-SDF - // transform without scaling. The SDF treats the scale separately. - sdf.scale.setScalar(1.0); - sdf.updateMatrixWorld(); - const worldToSdf = sdf.matrixWorld.clone().invert(); - worldToSdf.decompose(center, quaternion, scale); - - sdf.scale.set(sizes.x, sizes.y, sizes.z); - sdf.updateMatrixWorld(); - - values[0].set(sdf.color.r, sdf.color.g, sdf.color.b, sdf.opacity); - values[1].set(sdf.displace.x, sdf.displace.y, sdf.displace.z, 1.0); - - sdfUpdated = - this.encodeSdf( - sdfIndex, - { - sdfType: sdfTypeToNumber(sdf.type), - invert: sdf.invert, - center, - quaternion, - scale, - sizes, - }, - values, - ) || sdfUpdated; - - sdfIndex += 1; - } - this.numSdfs = sdfIndex; - if (sdfUpdated) { - this.sdfTexture.needsUpdate = true; - } - updated ||= sdfUpdated; - } - return { updated, dynoUpdated }; - } - - // Modify a Gsplat in a dyno shader program using the current edits and SDFs. - modify(gsplat: DynoVal): DynoVal { - return applyGsplatRgbaDisplaceEdits( - gsplat, - this.dynoSdfArray, - this.dynoNumEdits, - this.dynoEdits, - ); - } -} - -// Dyno types and components: - -// An SdfArray contains a collection of SDFs encoded in a texture. -// Each SDF has a type and geometric parameters, but also encodes -// 4 x vec4 values, which can all be blended across multiple SDFs. -// The SplatEdit system uses 7 of these 16 values to encode RGBA-XYZ edits, -// but more can be added, and these SDFs can be used for entirely different -// purposes as well. - -export const SdfArray = { type: "SdfArray" } as { type: "SdfArray" }; - -export const defineSdfArray = unindent(` - struct SdfArray { - int numSdfs; - usampler2D sdfTexture; - }; - - void unpackSdfArray( - usampler2D sdfTexture, int sdfIndex, out uint flags, - out vec3 center, out vec4 quaternion, out vec3 scale, out vec4 sizes, - int numValues, out vec4 values[4] - ) { - uvec4 temp = texelFetch(sdfTexture, ivec2(0, sdfIndex), 0); - flags = temp.w; - center = vec3(uintBitsToFloat(temp.x), uintBitsToFloat(temp.y), uintBitsToFloat(temp.z)); - - temp = texelFetch(sdfTexture, ivec2(1, sdfIndex), 0); - quaternion = vec4(uintBitsToFloat(temp.x), uintBitsToFloat(temp.y), uintBitsToFloat(temp.z), uintBitsToFloat(temp.w)); - - temp = texelFetch(sdfTexture, ivec2(2, sdfIndex), 0); - scale = vec3(uintBitsToFloat(temp.x), uintBitsToFloat(temp.y), uintBitsToFloat(temp.z)); - - temp = texelFetch(sdfTexture, ivec2(3, sdfIndex), 0); - sizes = vec4(uintBitsToFloat(temp.x), uintBitsToFloat(temp.y), uintBitsToFloat(temp.z), uintBitsToFloat(temp.w)); - - for (int i = 0; i < numValues; ++i) { - temp = texelFetch(sdfTexture, ivec2(4 + i, sdfIndex), 0); - values[i] = vec4(uintBitsToFloat(temp.x), uintBitsToFloat(temp.y), uintBitsToFloat(temp.z), uintBitsToFloat(temp.w)); - } - } - - const uint SDF_FLAG_TYPE = 0xFFu; - const uint SDF_FLAG_INVERT = 1u << 8u; - - const uint SDF_TYPE_ALL = 0u; - const uint SDF_TYPE_PLANE = 1u; - const uint SDF_TYPE_SPHERE = 2u; - const uint SDF_TYPE_BOX = 3u; - const uint SDF_TYPE_ELLIPSOID = 4u; - const uint SDF_TYPE_CYLINDER = 5u; - const uint SDF_TYPE_CAPSULE = 6u; - const uint SDF_TYPE_INFINITE_CONE = 7u; - - float evaluateSdfArray( - usampler2D sdfTexture, int numSdfs, int sdfFirst, int sdfCount, vec3 pos, - float smoothK, int numValues, out vec4 outValues[4] - ) { - float distanceAccum = (smoothK == 0.0) ? 1.0 / 0.0 : 0.0; - float maxExp = -1.0 / 0.0; - for (int i = 0; i < numValues; ++i) { - outValues[i] = vec4(0.0); - } - - uint flags; - vec3 center, scale; - vec4 quaternion, sizes; - vec4 values[4]; - - int sdfLast = min(sdfFirst + sdfCount, numSdfs); - for (int index = sdfFirst; index < sdfLast; ++index) { - unpackSdfArray(sdfTexture, index, flags, center, quaternion, scale, sizes, numValues, values); - uint sdfType = flags & SDF_FLAG_TYPE; - vec3 sdfPos = quatVec(quaternion, pos * scale) + center; - - float distance; - switch (sdfType) { - case SDF_TYPE_ALL: - distance = -1.0 / 0.0; - break; - case SDF_TYPE_PLANE: { - distance = sdfPos.z; - break; - } - case SDF_TYPE_SPHERE: { - distance = length(sdfPos) - sizes.w; - break; - } - case SDF_TYPE_BOX: { - vec3 q = abs(sdfPos) - sizes.xyz + sizes.w; - distance = length(max(q, 0.0)) + min(max(q.x, max(q.y, q.z)), 0.0) - sizes.w; - break; - } - case SDF_TYPE_ELLIPSOID: { - vec3 sizes = sizes.xyz; - float k0 = length(sdfPos / sizes); - float k1 = length(sdfPos / dot(sizes, sizes)); - distance = k0 * (k0 - 1.0) / k1; - break; - } - case SDF_TYPE_CYLINDER: { - vec2 d = abs(vec2(length(sdfPos.xz), sdfPos.y)) - sizes.wy; - distance = min(max(d.x, d.y), 0.0) + length(max(d, 0.0)); - break; - } - case SDF_TYPE_CAPSULE: { - sdfPos.y -= clamp(sdfPos.y, -0.5 * sizes.y, 0.5 * sizes.y); - distance = length(sdfPos) - sizes.w; - break; - } - case SDF_TYPE_INFINITE_CONE: { - float angle = 0.25 * PI * sizes.w; - vec2 c = vec2(sin(angle), cos(angle)); - vec2 q = vec2(length(sdfPos.xy), -sdfPos.z); - float d = length(q - c * max(dot(q, c), 0.0)); - distance = d * (((q.x * c.y - q.y * c.x) < 0.0) ? -1.0 : 1.0); - break; - } - } - - if ((flags & SDF_FLAG_INVERT) != 0u) { - distance = -distance; - } - - if (smoothK == 0.0) { - if (distance < distanceAccum) { - distanceAccum = distance; - for (int i = 0; i < numValues; ++i) { - outValues[i] = values[i]; - } - } - } else { - float scaledDistance = -distance / smoothK; - if (scaledDistance > maxExp) { - float scale = exp(maxExp - scaledDistance); - distanceAccum *= scale; - for (int i = 0; i < numValues; ++i) { - outValues[i] *= scale; - } - maxExp = scaledDistance; - } - - float weight = exp(scaledDistance - maxExp); - distanceAccum += weight; - for (int i = 0; i < numValues; ++i) { - outValues[i] += weight * values[i]; - } - } - } - - if (smoothK == 0.0) { - return distanceAccum; - } else { - // Very distant SDFs may result in 0 accumulation - if (distanceAccum == 0.0) { - return 1.0 / 0.0; - } - for (int i = 0; i < numValues; ++i) { - outValues[i] /= distanceAccum; - } - return (-log(distanceAccum) - maxExp) * smoothK; - } - } - - float modulateSdfArray( - usampler2D sdfTexture, int numSdfs, int sdfFirst, int sdfCount, vec3 pos, - float smoothK, int numValues, out vec4 values[4], - float softEdge, bool invert - ) { - float distance = evaluateSdfArray(sdfTexture, numSdfs, sdfFirst, sdfCount, pos, smoothK, numValues, values); - if (invert) { - distance = -distance; - } - - return (softEdge == 0.0) ? ((distance < 0.0) ? 1.0 : 0.0) - : clamp(-distance / softEdge + 0.5, 0.0, 1.0); - } -`); - -export const defineEdit = unindent(` - const uint EDIT_FLAG_BLEND = 0xFFu; - const uint EDIT_BLEND_MULTIPLY = 0u; - const uint EDIT_BLEND_SET_RGB = 1u; - const uint EDIT_BLEND_ADD_RGBA = 2u; - const uint EDIT_FLAG_INVERT = 0x100u; - - void decodeEdit( - uvec4 packedEdit, out int sdfFirst, out int sdfCount, - out bool invert, out uint rgbaBlendMode, out float softEdge, out float sdfSmooth - ) { - rgbaBlendMode = packedEdit.x & EDIT_FLAG_BLEND; - invert = (packedEdit.x & EDIT_FLAG_INVERT) != 0u; - - sdfFirst = int(packedEdit.y & 0xFFFFu); - sdfCount = int(packedEdit.y >> 16u); - - softEdge = uintBitsToFloat(packedEdit.z); - sdfSmooth = uintBitsToFloat(packedEdit.w); - } - - void applyRgbaDisplaceEdit( - usampler2D sdfTexture, int numSdfs, int sdfFirst, int sdfCount, inout vec3 pos, - float smoothK, float softEdge, bool invert, uint rgbaBlendMode, inout vec4 rgba - ) { - vec4 values[4]; - float modulate = modulateSdfArray(sdfTexture, numSdfs, sdfFirst, sdfCount, pos, smoothK, 2, values, softEdge, invert); - // On Android, moving values[0] is necessary to work around a compiler bug. - vec4 sdfRgba = values[0]; - vec4 sdfDisplaceScale = values[1]; - - vec4 target; - switch (rgbaBlendMode) { - case EDIT_BLEND_MULTIPLY: - target = rgba * sdfRgba; - break; - case EDIT_BLEND_SET_RGB: - target = vec4(sdfRgba.rgb, rgba.a * sdfRgba.a); - break; - case EDIT_BLEND_ADD_RGBA: - target = rgba + sdfRgba; - break; - default: - // Debug output if blend mode not set - target = vec4(fract(pos), 1.0); - } - rgba = mix(rgba, target, modulate); - pos += sdfDisplaceScale.xyz * modulate; - } - - void applyPackedRgbaDisplaceEdit(uvec4 packedEdit, usampler2D sdfTexture, int numSdfs, inout vec3 pos, inout vec4 rgba) { - int sdfFirst, sdfCount; - bool invert; - uint rgbaBlendMode; - float softEdge, sdfSmooth; - decodeEdit(packedEdit, sdfFirst, sdfCount, invert, rgbaBlendMode, softEdge, sdfSmooth); - applyRgbaDisplaceEdit(sdfTexture, numSdfs, sdfFirst, sdfCount, pos, sdfSmooth, softEdge, invert, rgbaBlendMode, rgba); - } -`); - -function applyGsplatRgbaDisplaceEdits( - gsplat: DynoVal, - sdfArray: DynoVal, - numEdits: DynoVal<"int">, - rgbaDisplaceEdits: DynoVal<"uvec4">, -): DynoVal { - const dyno = new Dyno< - { - gsplat: typeof Gsplat; - sdfArray: typeof SdfArray; - numEdits: "int"; - rgbaDisplaceEdits: "uvec4"; - }, - { gsplat: typeof Gsplat } - >({ - inTypes: { - gsplat: Gsplat, - sdfArray: SdfArray, - numEdits: "int", - rgbaDisplaceEdits: "uvec4", - }, - outTypes: { gsplat: Gsplat }, - globals: () => [defineSdfArray, defineEdit], - inputs: { gsplat, sdfArray, numEdits, rgbaDisplaceEdits }, - statements: ({ inputs, outputs }) => { - const { sdfArray, numEdits, rgbaDisplaceEdits } = inputs; - const { gsplat } = outputs; - return unindentLines(` - ${gsplat} = ${inputs.gsplat}; - if (isGsplatActive(${gsplat}.flags)) { - for (int editIndex = 0; editIndex < ${numEdits}; ++editIndex) { - applyPackedRgbaDisplaceEdit( - ${rgbaDisplaceEdits}[editIndex], ${sdfArray}.sdfTexture, ${sdfArray}.numSdfs, - ${gsplat}.center, ${gsplat}.rgba - ); - } - } - `); - }, - }); - return dyno.outputs.gsplat; -} - -const tempFloat32 = new Float32Array(1); diff --git a/src/SplatGenerator.ts b/src/SplatGenerator.ts deleted file mode 100644 index e566723..0000000 --- a/src/SplatGenerator.ts +++ /dev/null @@ -1,233 +0,0 @@ -import * as THREE from "three"; -import type { SplatEdit } from "./SplatEdit"; -import { - type Dyno, - DynoFloat, - type DynoVal, - DynoVec3, - DynoVec4, - Gsplat, - dynoBlock, - transformDir, - transformGsplat, - transformPos, -} from "./dyno"; - -// A GsplatGenerator is a dyno program that maps an index to a Gsplat's properties - -export type GsplatGenerator = Dyno<{ index: "int" }, { gsplat: typeof Gsplat }>; - -// A GsplatModifier is a dyno program that inputs a Gsplat, modifies, and outputs it - -export type GsplatModifier = Dyno< - { gsplat: typeof Gsplat }, - { gsplat: typeof Gsplat } ->; - -// A SplatModifier is a utility class to apply a GsplatModifier to -// a GsplatGenerator pipeline, caching the combined result for efficiency. - -export class SplatModifier { - modifier: GsplatModifier; - cache: Map; - - constructor(modifier: GsplatModifier) { - this.modifier = modifier; - this.cache = new Map(); - } - - apply(generator: GsplatGenerator): GsplatGenerator { - let modified = this.cache.get(generator); - if (!modified) { - modified = dynoBlock( - { index: "int" }, - { gsplat: Gsplat }, - ({ index }) => { - const { gsplat } = generator.apply({ index }); - return this.modifier.apply({ gsplat }); - }, - ); - this.cache.set(generator, modified); - } - return modified; - } -} - -// A SplatTransformer is a utility class to apply a transform to a Gsplat -// via a scale, rotation, and translation. Scale is a single float because -// anisotropic scaling of Gsplats is not supported. - -export class SplatTransformer { - scale: DynoFloat; - rotate: DynoVec4; - translate: DynoVec3; - - // Create the dyno uniforms that parameterize the transform, setting them - // to initial values that are different from any valid transform. - constructor() { - this.scale = new DynoFloat({ value: Number.NEGATIVE_INFINITY }); - this.rotate = new DynoVec4({ - value: new THREE.Quaternion( - Number.POSITIVE_INFINITY, - Number.POSITIVE_INFINITY, - Number.POSITIVE_INFINITY, - Number.POSITIVE_INFINITY, - ), - }); - this.translate = new DynoVec3({ - value: new THREE.Vector3( - Number.POSITIVE_INFINITY, - Number.POSITIVE_INFINITY, - Number.POSITIVE_INFINITY, - ), - }); - } - - // Apply the transform to a Vec3 position in a dyno program. - apply(position: DynoVal<"vec3">): DynoVal<"vec3"> { - return transformPos(position, { - scale: this.scale, - rotate: this.rotate, - translate: this.translate, - }); - } - - applyDir(dir: DynoVal<"vec3">): DynoVal<"vec3"> { - return transformDir(dir, { - rotate: this.rotate, - }); - } - - // Apply the transform to a Gsplat in a dyno program. - applyGsplat(gsplat: DynoVal): DynoVal { - return transformGsplat(gsplat, { - scale: this.scale, - rotate: this.rotate, - translate: this.translate, - }); - } - - // Update the uniforms to match the given transform matrix. - updateFromMatrix(transform: THREE.Matrix4) { - const scale = new THREE.Vector3(); - const quaternion = new THREE.Quaternion(); - const position = new THREE.Vector3(); - transform.decompose(position, quaternion, scale); - const newScale = (scale.x + scale.y + scale.z) / 3; - - let updated = false; - if (newScale !== this.scale.value) { - this.scale.value = newScale; - updated = true; - } - if (!position.equals(this.translate.value)) { - this.translate.value.copy(position); - updated = true; - } - if (!quaternion.equals(this.rotate.value)) { - this.rotate.value.copy(quaternion); - updated = true; - } - return updated; - } - - // Update this transform to match the object's to-world transform. - update(object: THREE.Object3D): boolean { - object.updateMatrixWorld(); - return this.updateFromMatrix(object.matrixWorld); - } -} - -// SplatGenerator is an Object3D that can be placed anywhere in the scene -// to generate Gsplats into the world for SparkRenderer. All Gsplats from -// SplatGenerators across the scene will be accumulated into a single -// SplatAccumulator, which are sorted and rendered together. -// -// Each SplatGenerator has two main properties: -// - numSplats: the number of Gsplats to generate -// - generator: a GsplatGenerator dyno program that maps a splat index -// to a Gsplat's properties -// Each of these properties can be changed at anytime, however changing -// numSplats means we no longer have a correspondence between Gsplats -// in successive frames, meaning we can't reuse the previous Gsplat sort -// order. Similarly, changing the generator requires re-generating the -// shader program, which will trigger a GPU shader compilation the first -// time (possibly a perceptible "hickup" in the framerate) but is cached -// subsequence times if the generator is the same as one that was used previously. -// -// A SplatGenerator also has a custom frameUpdate function that is called -// on each execution, allowing you to update uniforms or other parameters that -// affect the generation. If the Gsplats are changed, you must call -// updateVersion() (alternatively, set needsUpdate to true) to trigger a -// re-generation of the Gsplats for this SplatGenerator. - -export class SplatGenerator extends THREE.Object3D { - numSplats: number; - generator?: GsplatGenerator; - generatorError?: unknown; - frameUpdate?: ({ - object, - time, - deltaTime, - viewToWorld, - globalEdits, - }: { - object: SplatGenerator; - time: number; - deltaTime: number; - viewToWorld: THREE.Matrix4; - globalEdits: SplatEdit[]; - }) => void; - version: number; - - constructor({ - numSplats, - generator, - construct, - update, - }: { - numSplats?: number; - generator?: GsplatGenerator; - construct?: (object: SplatGenerator) => { - generator?: GsplatGenerator; - numSplats?: number; - frameUpdate?: (object: SplatGenerator) => void; - }; - update?: ({ - object, - time, - deltaTime, - viewToWorld, - globalEdits, - }: { - object: SplatGenerator; - time: number; - deltaTime: number; - viewToWorld: THREE.Matrix4; - globalEdits: SplatEdit[]; - }) => void; - }) { - super(); - - this.numSplats = numSplats ?? 0; - this.generator = generator; - this.frameUpdate = update; - this.version = 0; - - if (construct) { - const constructed = construct(this); - // If we returned something, update our properties - Object.assign(this, constructed); - } - } - - updateVersion() { - this.version += 1; - } - - set needsUpdate(value: boolean) { - if (value) { - this.updateVersion(); - } - } -} diff --git a/src/SplatGeometry.ts b/src/SplatGeometry.ts index 3e82e61..a550652 100644 --- a/src/SplatGeometry.ts +++ b/src/SplatGeometry.ts @@ -1,44 +1,105 @@ import * as THREE from "three"; +import type { SplatData } from "./Splat"; -// SplatGeometry is an internal class used by SparkRenderer to render a collection -// of Gsplats in a single draw call by extending THREE.InstancedBufferGeometry. -// Each Gsplat is drawn as two triangles, with the order of the Gsplats determined -// by the instance attribute "ordering". - +/** + * Dedicated geometry for rendering splats using instancing. + * Each splat is drawn as two triangles, with the order determined by the + * instance attribute "splatIndex". + */ export class SplatGeometry extends THREE.InstancedBufferGeometry { - ordering: Uint32Array; - attribute: THREE.InstancedBufferAttribute; + attribute?: SplatIndexAttribute; - constructor(ordering: Uint32Array, activeSplats: number) { + constructor() { super(); - this.ordering = ordering; - this.setAttribute("position", new THREE.BufferAttribute(QUAD_VERTICES, 3)); this.setIndex(new THREE.BufferAttribute(QUAD_INDICES, 1)); - // Hack to work around Three.js - // @ts-ignore - this._maxInstanceCount = ordering.length; - this.instanceCount = activeSplats; + this.instanceCount = 0; - this.attribute = new THREE.InstancedBufferAttribute(ordering, 1, false, 1); - this.attribute.setUsage(THREE.DynamicDrawUsage); - this.setAttribute("splatIndex", this.attribute); + this.boundingSphere = new THREE.Sphere(); + this.boundingBox = new THREE.Box3(); } - update(ordering: Uint32Array, activeSplats: number) { - this.ordering = ordering; - this.attribute.array = ordering; + update( + renderer: THREE.WebGLRenderer, + ordering: Uint32Array, + activeSplats: number, + ) { + if (!this.attribute) { + this.attribute = new SplatIndexAttribute(renderer, ordering); + this.setAttribute( + "splatIndex", + this.attribute as unknown as THREE.InstancedBufferAttribute, + ); + } + + this.attribute.update(renderer, ordering, activeSplats); this.instanceCount = activeSplats; - this.attribute.addUpdateRange(0, activeSplats); - this.attribute.needsUpdate = true; + } + + updateBounds(splatData: SplatData) { + if (!this.boundingSphere) { + this.boundingSphere = new THREE.Sphere(); + } + if (!this.boundingBox) { + this.boundingBox = new THREE.Box3(); + } + + // Empty the bounding shapes + this.boundingSphere.makeEmpty(); + this.boundingBox.makeEmpty(); + + // Note: since the sphere is at the origin, simplify the calculation + // by only computing the max squared radius of the splats. + let maxRadiusSquared = 0; + splatData.iterateCenters((i, x, y, z) => { + tempV3.set(x, y, z); + maxRadiusSquared = Math.max(maxRadiusSquared, tempV3.lengthSq()); + }); + const radius = Math.sqrt(maxRadiusSquared); + this.boundingSphere.radius = radius; + + // Determine the bounding box naively on the sphere + this.boundingBox.min.set(-radius, -radius, -radius); + this.boundingBox.max.set(radius, radius, radius); } } +const tempV3 = new THREE.Vector3(); + // Each instance draws to triangles covering a quad over coords (-1,-1,0)..(1,1,0) const QUAD_VERTICES = new Float32Array([ -1, -1, 0, 1, -1, 0, 1, 1, 0, -1, 1, 0, ]); const QUAD_INDICES = new Uint16Array([0, 1, 2, 0, 2, 3]); + +/** + * Dedicated GLBufferAttribute for the splat index to allow uploading the latest + * values from the onBeforeRender hook to avoid 1-frame latency on sort results. + */ +export class SplatIndexAttribute extends THREE.GLBufferAttribute { + public isInstancedBufferAttribute = true; + public isGLInstancedBufferAttribute = true; + public meshPerAttribute: number; + public data: Uint32Array; + + constructor(renderer: THREE.WebGLRenderer, array: Uint32Array) { + const gl = renderer.getContext(); + const buffer = gl.createBuffer(); + gl.bindBuffer(gl.ARRAY_BUFFER, buffer); + gl.bufferData(gl.ARRAY_BUFFER, array, gl.DYNAMIC_DRAW); + super(buffer, gl.UNSIGNED_INT, 1, 4, array.length); + this.meshPerAttribute = 1; + this.data = array; + } + + update(renderer: THREE.WebGLRenderer, array: Uint32Array, count: number) { + this.data = array; + + const gl = renderer.getContext(); + gl.bindBuffer(gl.ARRAY_BUFFER, this.buffer); + gl.bufferSubData(gl.ARRAY_BUFFER, 0, array, 0, count); + } +} diff --git a/src/SplatLoader.ts b/src/SplatLoader.ts index 9346bf6..65d7185 100644 --- a/src/SplatLoader.ts +++ b/src/SplatLoader.ts @@ -1,34 +1,41 @@ import { unzipSync } from "fflate"; -import { FileLoader, Loader, type LoadingManager } from "three"; +import * as THREE from "three"; +import { Splat } from "./Splat"; +import { withWorkerCall } from "./SplatWorker"; import { - DEFAULT_SPLAT_ENCODING, - PackedSplats, - type SplatEncoding, -} from "./PackedSplats"; -import { SplatMesh } from "./SplatMesh"; -import { PlyReader } from "./ply"; -import { withWorker } from "./splatWorker"; -import { decompressPartialGzip, getTextureSize } from "./utils"; - -// SplatLoader implements the THREE.Loader interface and supports loading a variety -// of different Gsplat file formats. Formats .PLY and .SPZ can be auto-detected -// from the file contents, while .SPLAT and .KSPLAT require either having the -// appropriate file extension as part of the path, or it can be explicitly set -// in the loader using the fileType property. - -export class SplatLoader extends Loader { - fileLoader: FileLoader; + DefaultSplatEncoding, + type SplatEncodingClass, + type UnpackResult, +} from "./encoding/encoder"; +import type { PcSogsJson, PcSogsV2Json } from "./formats/pcsogs"; +import { decompressPartialGzip } from "./utils"; + +export type SplatLoaderOptions = { + loadingManager: THREE.LoadingManager; + fileType: SplatFileType; + splatEncoding: SplatEncodingClass; +}; + +/** + * SplatLoader implements the THREE.Loader interface and supports loading a variety + * of different Gsplat file formats. Formats .PLY and .SPZ can be auto-detected + * from the file contents, while .SPLAT and .KSPLAT require either having the + * appropriate file extension as part of the path, or it can be explicitly set + * in the loader using the fileType property. + */ +export class SplatLoader extends THREE.Loader { fileType?: SplatFileType; - packedSplats?: PackedSplats; + splatEncoding: SplatEncodingClass; - constructor(manager?: LoadingManager) { - super(manager); - this.fileLoader = new FileLoader(manager); + constructor(options?: SplatLoaderOptions) { + super(options?.loadingManager); + this.fileType = options?.fileType; + this.splatEncoding = options?.splatEncoding ?? DefaultSplatEncoding; } load( url: string, - onLoad?: (decoded: PackedSplats) => void, + onLoad?: (decoded: Splat) => void, onProgress?: (event: ProgressEvent) => void, onError?: (error: unknown) => void, ) { @@ -114,22 +121,12 @@ export class SplatLoader extends Loader { await Promise.all(promises); if (onLoad) { - const splatEncoding = - this.packedSplats?.splatEncoding ?? DEFAULT_SPLAT_ENCODING; - const decoded = await unpackSplats({ - input, + const splat = await this.parseAsync(input, { extraFiles, fileType, - pathOrUrl: resolvedURL, - splatEncoding, + fileName: resolvedURL, }); - - if (this.packedSplats) { - this.packedSplats.initialize(decoded); - onLoad(this.packedSplats); - } else { - onLoad(new PackedSplats(decoded)); - } + onLoad(splat); } }) .catch((error) => { @@ -144,7 +141,7 @@ export class SplatLoader extends Loader { async loadAsync( url: string, onProgress?: (event: ProgressEvent) => void, - ): Promise { + ): Promise { return new Promise((resolve, reject) => { this.load( url, @@ -157,8 +154,23 @@ export class SplatLoader extends Loader { }); } - parse(packedSplats: PackedSplats): SplatMesh { - return new SplatMesh({ packedSplats }); + async parseAsync( + input: ArrayBuffer, + options?: { + extraFiles?: Record; + fileType?: SplatFileType; + fileName?: string; + }, + ): Promise { + const decoded = await unpackSplats( + input, + this.splatEncoding.encodingName, + options?.extraFiles, + options?.fileType, + options?.fileName, + ); + + return new Splat(this.splatEncoding.fromTransferable(decoded.unpacked)); } } @@ -288,65 +300,6 @@ export function getSplatFileTypeFromPath( return undefined; } -export type PcSogsJson = { - means: { - shape: number[]; - dtype: string; - mins: number[]; - maxs: number[]; - files: string[]; - }; - scales: { - shape: number[]; - dtype: string; - mins: number[]; - maxs: number[]; - files: string[]; - }; - quats: { shape: number[]; dtype: string; encoding?: string; files: string[] }; - sh0: { - shape: number[]; - dtype: string; - mins: number[]; - maxs: number[]; - files: string[]; - }; - shN?: { - shape: number[]; - dtype: string; - mins: number; - maxs: number; - quantization: number; - files: string[]; - }; -}; - -export type PcSogsV2Json = { - version: 2; - count: number; - antialias?: boolean; - means: { - mins: number[]; - maxs: number[]; - files: string[]; - }; - scales: { - codebook: number[]; - files: string[]; - }; - quats: { files: string[] }; - sh0: { - codebook: number[]; - files: string[]; - }; - shN?: { - count: number; - bands: number; - codebook: number[]; - files: string[]; - }; -}; - export function isPcSogs(input: ArrayBuffer | Uint8Array | string): boolean { // Returns true if the input seems to be a valid PC SOGS file return tryPcSogs(input) !== undefined; @@ -449,23 +402,22 @@ export function tryPcSogsZip( } } -export async function unpackSplats({ - input, - extraFiles, - fileType, - pathOrUrl, - splatEncoding, -}: { - input: Uint8Array | ArrayBuffer; - extraFiles?: Record; - fileType?: SplatFileType; - pathOrUrl?: string; - splatEncoding?: SplatEncoding; -}): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra?: Record; -}> { +const SPLAT_FILE_TYPE_TO_RPC = { + [SplatFileType.PLY]: "decodePly", + [SplatFileType.SPZ]: "decodeSpz", + [SplatFileType.SPLAT]: "decodeAntiSplat", + [SplatFileType.KSPLAT]: "decodeKsplat", + [SplatFileType.PCSOGS]: "decodePcSogs", + [SplatFileType.PCSOGSZIP]: "decodePcSogsZip", +} as const satisfies Partial>; + +export async function unpackSplats( + input: Uint8Array | ArrayBuffer, + encodingName: string, + extraFiles?: Record, + fileType?: SplatFileType, + pathOrUrl?: string, +): Promise { const fileBytes = input instanceof ArrayBuffer ? new Uint8Array(input) : input; let splatFileType = fileType; @@ -476,262 +428,15 @@ export async function unpackSplats({ } } - switch (splatFileType) { - case SplatFileType.PLY: { - const ply = new PlyReader({ fileBytes }); - await ply.parseHeader(); - const numSplats = ply.numSplats; - const maxSplats = getTextureSize(numSplats).maxSplats; - const args = { - fileBytes, - packedArray: new Uint32Array(maxSplats * 4), - splatEncoding, - }; - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "unpackPly", - args, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.SPZ: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodeSpz", - { - fileBytes, - splatEncoding, - }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.SPLAT: { - return await withWorker(async (worker) => { - const { packedArray, numSplats } = (await worker.call( - "decodeAntiSplat", - { - fileBytes, - splatEncoding, - }, - )) as { packedArray: Uint32Array; numSplats: number }; - return { packedArray, numSplats }; - }); - } - case SplatFileType.KSPLAT: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodeKsplat", - { fileBytes, splatEncoding }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.PCSOGS: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodePcSogs", - { fileBytes, extraFiles, splatEncoding }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - case SplatFileType.PCSOGSZIP: { - return await withWorker(async (worker) => { - const { packedArray, numSplats, extra } = (await worker.call( - "decodePcSogsZip", - { fileBytes, splatEncoding }, - )) as { - packedArray: Uint32Array; - numSplats: number; - extra: Record; - }; - return { packedArray, numSplats, extra }; - }); - } - default: { - throw new Error(`Unknown splat file type: ${splatFileType}`); - } - } -} - -export class SplatData { - numSplats: number; - maxSplats: number; - centers: Float32Array; - scales: Float32Array; - quaternions: Float32Array; - opacities: Float32Array; - colors: Float32Array; - sh1?: Float32Array; - sh2?: Float32Array; - sh3?: Float32Array; - - constructor({ maxSplats = 1 }: { maxSplats?: number } = {}) { - this.numSplats = 0; - this.maxSplats = getTextureSize(maxSplats).maxSplats; - this.centers = new Float32Array(this.maxSplats * 3); - this.scales = new Float32Array(this.maxSplats * 3); - this.quaternions = new Float32Array(this.maxSplats * 4); - this.opacities = new Float32Array(this.maxSplats); - this.colors = new Float32Array(this.maxSplats * 3); - } - - pushSplat(): number { - const index = this.numSplats; - this.ensureIndex(index); - this.numSplats += 1; - return index; - } - - unpushSplat(index: number) { - if (index === this.numSplats - 1) { - this.numSplats -= 1; - } else { - throw new Error("Cannot unpush splat from non-last position"); - } - } - - ensureCapacity(numSplats: number) { - if (numSplats > this.maxSplats) { - const targetSplats = Math.max(numSplats, this.maxSplats * 2); - const newCenters = new Float32Array(targetSplats * 3); - const newScales = new Float32Array(targetSplats * 3); - const newQuaternions = new Float32Array(targetSplats * 4); - const newOpacities = new Float32Array(targetSplats); - const newColors = new Float32Array(targetSplats * 3); - newCenters.set(this.centers); - newScales.set(this.scales); - newQuaternions.set(this.quaternions); - newOpacities.set(this.opacities); - newColors.set(this.colors); - this.centers = newCenters; - this.scales = newScales; - this.quaternions = newQuaternions; - this.opacities = newOpacities; - this.colors = newColors; - - if (this.sh1) { - const newSh1 = new Float32Array(targetSplats * 9); - newSh1.set(this.sh1); - this.sh1 = newSh1; - } - if (this.sh2) { - const newSh2 = new Float32Array(targetSplats * 15); - newSh2.set(this.sh2); - this.sh2 = newSh2; - } - if (this.sh3) { - const newSh3 = new Float32Array(targetSplats * 21); - newSh3.set(this.sh3); - this.sh3 = newSh3; - } - - this.maxSplats = targetSplats; - } - } - - ensureIndex(index: number) { - this.ensureCapacity(index + 1); - } - - setCenter(index: number, x: number, y: number, z: number) { - this.centers[index * 3] = x; - this.centers[index * 3 + 1] = y; - this.centers[index * 3 + 2] = z; + if (!splatFileType) { + throw new Error(`Unknown splat file type: ${splatFileType}`); } - setScale(index: number, scaleX: number, scaleY: number, scaleZ: number) { - this.scales[index * 3] = scaleX; - this.scales[index * 3 + 1] = scaleY; - this.scales[index * 3 + 2] = scaleZ; - } - - setQuaternion(index: number, x: number, y: number, z: number, w: number) { - this.quaternions[index * 4] = x; - this.quaternions[index * 4 + 1] = y; - this.quaternions[index * 4 + 2] = z; - this.quaternions[index * 4 + 3] = w; - } - - setOpacity(index: number, opacity: number) { - this.opacities[index] = opacity; - } - - setColor(index: number, r: number, g: number, b: number) { - this.colors[index * 3] = r; - this.colors[index * 3 + 1] = g; - this.colors[index * 3 + 2] = b; - } - - setSh1(index: number, sh1: Float32Array) { - if (!this.sh1) { - this.sh1 = new Float32Array(this.maxSplats * 9); - } - for (let j = 0; j < 9; ++j) { - this.sh1[index * 9 + j] = sh1[j]; - } - } - - setSh2(index: number, sh2: Float32Array) { - if (!this.sh2) { - this.sh2 = new Float32Array(this.maxSplats * 15); - } - for (let j = 0; j < 15; ++j) { - this.sh2[index * 15 + j] = sh2[j]; - } - } - - setSh3(index: number, sh3: Float32Array) { - if (!this.sh3) { - this.sh3 = new Float32Array(this.maxSplats * 21); - } - for (let j = 0; j < 21; ++j) { - this.sh3[index * 21 + j] = sh3[j]; - } - } -} - -export async function transcodeSpz( - input: TranscodeSpzInput, -): Promise<{ input: TranscodeSpzInput; fileBytes: Uint8Array }> { - return await withWorker(async (worker) => { - const result = (await worker.call("transcodeSpz", input)) as { - input: TranscodeSpzInput; - fileBytes: Uint8Array; - }; - return result; + const decodeRpc = SPLAT_FILE_TYPE_TO_RPC[splatFileType]; + return await withWorkerCall(decodeRpc, { + fileBytes, + extraFiles, + encoder: encodingName, + encoderOptions: {}, }); } - -export type FileInput = { - fileBytes: Uint8Array; - fileType?: SplatFileType; - pathOrUrl?: string; - transform?: { translate?: number[]; quaternion?: number[]; scale?: number }; -}; - -export type TranscodeSpzInput = { - inputs: FileInput[]; - maxSh?: number; - clipXyz?: { min: number[]; max: number[] }; - fractionalBits?: number; - opacityThreshold?: number; -}; diff --git a/src/SplatMesh.ts b/src/SplatMesh.ts deleted file mode 100644 index 3ab66cf..0000000 --- a/src/SplatMesh.ts +++ /dev/null @@ -1,974 +0,0 @@ -import * as THREE from "three"; - -import init_wasm, { raycast_splats } from "spark-internal-rs"; -import { - DEFAULT_SPLAT_ENCODING, - PackedSplats, - type SplatEncoding, -} from "./PackedSplats"; -import { type RgbaArray, readRgbaArray } from "./RgbaArray"; -import { SparkRenderer } from "./SparkRenderer"; -import { SplatEdit, SplatEditSdf, SplatEdits } from "./SplatEdit"; -import { - type GsplatModifier, - SplatGenerator, - SplatTransformer, -} from "./SplatGenerator"; -import type { SplatFileType } from "./SplatLoader"; -import type { SplatSkinning } from "./SplatSkinning"; -import { LN_SCALE_MAX, LN_SCALE_MIN } from "./defines"; -import { - DynoFloat, - DynoUsampler2DArray, - type DynoVal, - DynoVec4, - Gsplat, - add, - combineGsplat, - defineGsplat, - dyno, - dynoBlock, - dynoConst, - extendVec, - mul, - normalize, - readPackedSplat, - split, - splitGsplat, - sub, - unindent, - unindentLines, -} from "./dyno"; -import { getTextureSize } from "./utils"; - -export type SplatMeshOptions = { - // URL to fetch a Gaussian splat file from(supports .ply, .splat, .ksplat, - // .spz formats). (default: undefined) - url?: string; - // Raw bytes of a Gaussian splat file to decode directly instead of fetching - // from URL. (default: undefined) - fileBytes?: Uint8Array | ArrayBuffer; - // Override the file type detection for formats that can't be reliably - // auto-detected (.splat, .ksplat). (default: undefined auto-detects other - // formats from file contents) - fileType?: SplatFileType; - // File name to use for type detection. (default: undefined) - fileName?: string; - // Use an existing PackedSplats object as the source instead of loading from - // a file. Can be used to share a collection of Gsplats among multiple SplatMeshes - // (default: undefined creates a new empty PackedSplats or decoded from a - // data source above) - packedSplats?: PackedSplats; - // Reserve space for at least this many splats when constructing the mesh - // initially. (default: determined by file) - maxSplats?: number; - // Callback function to programmatically create splats at initialization - // in provided PackedSplats. (default: undefined) - constructSplats?: (splats: PackedSplats) => Promise | void; - // Callback function that is called when mesh initialization is complete. - // (default: undefined) - onLoad?: (mesh: SplatMesh) => Promise | void; - // Controls whether SplatEdits have any effect on this mesh. (default: true) - editable?: boolean; - // Callback function that is called every frame to update the mesh. - // Call mesh.updateVersion() if splats need to be regenerated due to some change. - // Calling updateVersion() is not necessary for object transformations, recoloring, - // or opacity adjustments as these are auto-detected. (default: undefined) - onFrame?: ({ - mesh, - time, - deltaTime, - }: { mesh: SplatMesh; time: number; deltaTime: number }) => void; - // Gsplat modifier to apply in object-space before any transformations. - // A GsplatModifier is a dyno shader-graph block that transforms an input - // gsplat: DynoVal to an output gsplat: DynoVal with gsplat.center - // coordinate in object-space. (default: undefined) - objectModifier?: GsplatModifier; - // Gsplat modifier to apply in world-space after transformations. - // (default: undefined) - worldModifier?: GsplatModifier; - // Override the default splat encoding ranges for the PackedSplats. - // (default: undefined) - splatEncoding?: SplatEncoding; -}; - -export type SplatMeshContext = { - transform: SplatTransformer; - viewToWorld: SplatTransformer; - worldToView: SplatTransformer; - viewToObject: SplatTransformer; - recolor: DynoVec4; - time: DynoFloat; - deltaTime: DynoFloat; -}; - -export class SplatMesh extends SplatGenerator { - // A Promise you can await to ensure fetching, parsing, - // and initialization has completed - initialized: Promise; - // A boolean indicating whether initialization is complete - isInitialized = false; - - // If you modify packedSplats you should set - // splatMesh.packedSplats.needsUpdate = true to signal to Three.js that it - // should re-upload the data to the underlying texture. Use this sparingly with - // objects with smaller Gsplat counts as it requires a CPU-GPU data transfer for - // each frame. Thousands to tens of thousands of Gsplats ir fine. (See hands.ts - // for an example of rendering "Gsplat hands" in WebXR using this technique.) - packedSplats: PackedSplats; - - // A THREE.Color that can be used to tint all splats in the mesh. - // (default: new THREE.Color(1, 1, 1)) - recolor: THREE.Color = new THREE.Color(1, 1, 1); - // Global opacity multiplier for all splats in the mesh. (default: 1) - opacity = 1; - - // A SplatMeshContext consisting of useful scene and object dyno uniforms that can - // be used to in the Gsplat processing pipeline, for example via objectModifier and - // worldModifier. (created on construction) - context: SplatMeshContext; - onFrame?: ({ - mesh, - time, - deltaTime, - }: { mesh: SplatMesh; time: number; deltaTime: number }) => void; - - objectModifier?: GsplatModifier; - worldModifier?: GsplatModifier; - // Set to true to have the viewToObject property in context be updated each frame. - // If the mesh has extra.sh1 (first order spherical harmonics directional lighting) - // this property will always be updated. (default: false) - enableViewToObject = false; - // Set to true to have context.viewToWorld updated each frame. (default: false) - enableViewToWorld = false; - // Set to true to have context.worldToView updated each frame. (default: false) - enableWorldToView = false; - - // Optional SplatSkinning instance for animating splats with dual-quaternion - // skeletal animation. (default: null) - skinning: SplatSkinning | null = null; - - // Optional list of SplatEdits to apply to the mesh. If null, any SplatEdit - // children in the scene graph will be added automatically. (default: null) - edits: SplatEdit[] | null = null; - editable: boolean; - // Compiled SplatEdits for applying SDF edits to splat RGBA + centers - private rgbaDisplaceEdits: SplatEdits | null = null; - // Optional RgbaArray to overwrite splat RGBA values with custom values. - // Useful for "baking" RGB and opacity edits into the SplatMesh. (default: null) - splatRgba: RgbaArray | null = null; - - // Maximum Spherical Harmonics level to use. Call updateGenerator() - // after changing. (default: 3) - maxSh = 3; - - constructor(options: SplatMeshOptions = {}) { - const transform = new SplatTransformer(); - const viewToWorld = new SplatTransformer(); - const worldToView = new SplatTransformer(); - const viewToObject = new SplatTransformer(); - const recolor = new DynoVec4({ - value: new THREE.Vector4( - Number.NEGATIVE_INFINITY, - Number.NEGATIVE_INFINITY, - Number.NEGATIVE_INFINITY, - Number.NEGATIVE_INFINITY, - ), - }); - const time = new DynoFloat({ value: 0 }); - const deltaTime = new DynoFloat({ value: 0 }); - const context = { - transform, - viewToWorld, - worldToView, - viewToObject, - recolor, - time, - deltaTime, - }; - - super({ - update: ({ time, deltaTime, viewToWorld, globalEdits }) => - this.update({ time, deltaTime, viewToWorld, globalEdits }), - }); - - this.packedSplats = options.packedSplats ?? new PackedSplats(); - this.packedSplats.splatEncoding = options.splatEncoding ?? { - ...DEFAULT_SPLAT_ENCODING, - }; - this.numSplats = this.packedSplats.numSplats; - this.editable = options.editable ?? true; - this.onFrame = options.onFrame; - - this.context = context; - this.objectModifier = options.objectModifier; - this.worldModifier = options.worldModifier; - - this.updateGenerator(); - - if ( - options.url || - options.fileBytes || - options.constructSplats || - (options.packedSplats && !options.packedSplats.isInitialized) - ) { - // We need to initialize asynchronously given the options - this.initialized = this.asyncInitialize(options).then(async () => { - this.updateGenerator(); - - this.isInitialized = true; - if (options.onLoad) { - const maybePromise = options.onLoad(this); - if (maybePromise instanceof Promise) { - await maybePromise; - } - } - return this; - }); - } else { - this.isInitialized = true; - this.initialized = Promise.resolve(this); - if (options.onLoad) { - const maybePromise = options.onLoad(this); - // If onLoad returns a promise, wait for it to complete - if (maybePromise instanceof Promise) { - this.initialized = maybePromise.then(() => this); - } - } - } - - this.add(createRendererDetectionMesh()); - } - - async asyncInitialize(options: SplatMeshOptions) { - const { - url, - fileBytes, - fileType, - fileName, - maxSplats, - constructSplats, - splatEncoding, - } = options; - if (url || fileBytes || constructSplats) { - const packedSplatsOptions = { - url, - fileBytes, - fileType, - fileName, - maxSplats, - construct: constructSplats, - splatEncoding, - }; - this.packedSplats.reinitialize(packedSplatsOptions); - } - if (this.packedSplats) { - await this.packedSplats.initialized; - this.numSplats = this.packedSplats.numSplats; - this.updateGenerator(); - } - } - - static staticInitialized = SplatMesh.staticInitialize(); - static isStaticInitialized = false; - - static dynoTime = new DynoFloat({ value: 0 }); - - static async staticInitialize() { - await init_wasm(); - SplatMesh.isStaticInitialized = true; - } - - // Creates a new Gsplat with the provided parameters (all values in "float" space, - // i.e. 0-1 for opacity and color) and adds it to the end of the packedSplats, - // increasing numSplats by 1. If necessary, reallocates the buffer with an exponential - // doubling strategy to fit the new data, so it's fairly efficient to just - // pushSplat(...) each Gsplat you want to create in a loop. - pushSplat( - center: THREE.Vector3, - scales: THREE.Vector3, - quaternion: THREE.Quaternion, - opacity: number, - color: THREE.Color, - ) { - this.packedSplats.pushSplat(center, scales, quaternion, opacity, color); - } - - // This method iterates over all Gsplats in this instance's packedSplats, - // invoking the provided callback with index: number in 0..=(this.numSplats-1) and - // center: THREE.Vector3, scales: THREE.Vector3, quaternion: THREE.Quaternion, - // opacity: number (0..1), and color: THREE.Color (rgb values in 0..1). - // Note that the objects passed in as center etc. are the same for every callback - // invocation: these objects are reused for efficiency. Changing these values has - // no effect as they are decoded/unpacked copies of the underlying data. To update - // the packedSplats, call .packedSplats.setSplat(index, center, scales, - // quaternion, opacity, color). - forEachSplat( - callback: ( - index: number, - center: THREE.Vector3, - scales: THREE.Vector3, - quaternion: THREE.Quaternion, - opacity: number, - color: THREE.Color, - ) => void, - ) { - this.packedSplats.forEachSplat(callback); - } - - // Call this when you are finished with the SplatMesh and want to free - // any buffers it holds (via packedSplats). - dispose() { - this.packedSplats.dispose(); - } - - // Returns axis-aligned bounding box of the SplatMesh. If centers_only is true, - // only the centers of the splats are used to compute the bounding box. - // IMPORTANT: This should only be called after the SplatMesh is initialized. - getBoundingBox(centers_only = true) { - if (!this.initialized) { - throw new Error( - "Cannot get bounding box before SplatMesh is initialized", - ); - } - const minVec = new THREE.Vector3( - Number.POSITIVE_INFINITY, - Number.POSITIVE_INFINITY, - Number.POSITIVE_INFINITY, - ); - const maxVec = new THREE.Vector3( - Number.NEGATIVE_INFINITY, - Number.NEGATIVE_INFINITY, - Number.NEGATIVE_INFINITY, - ); - const corners = new THREE.Vector3(); - const signs = [-1, 1]; - this.packedSplats.forEachSplat( - (_index, center, scales, quaternion, _opacity, _color) => { - if (centers_only) { - minVec.min(center); - maxVec.max(center); - } else { - // Get the 8 corners of the AABB in local space - for (const x of signs) { - for (const y of signs) { - for (const z of signs) { - corners.set(x * scales.x, y * scales.y, z * scales.z); - // Transform corner by rotation and position - corners.applyQuaternion(quaternion); - corners.add(center); - minVec.min(corners); - maxVec.max(corners); - } - } - } - } - }, - ); - const box = new THREE.Box3(minVec, maxVec); - return box; - } - - constructGenerator(context: SplatMeshContext) { - const { transform, viewToObject, recolor } = context; - const generator = dynoBlock( - { index: "int" }, - { gsplat: Gsplat }, - ({ index }) => { - if (!index) { - throw new Error("index is undefined"); - } - // Read a Gsplat from the PackedSplats template - let gsplat = readPackedSplat(this.packedSplats.dyno, index); - - if (this.maxSh >= 1) { - // Inject lighting from SH1..SH3 - const { sh1Texture, sh2Texture, sh3Texture } = - this.ensureShTextures(); - if (sh1Texture) { - //Calculate view direction in object space - const viewCenterInObject = viewToObject.translate; - const { center } = splitGsplat(gsplat).outputs; - const viewDir = normalize(sub(center, viewCenterInObject)); - - function rescaleSh( - sNorm: DynoVal<"vec3">, - minMax: DynoVal<"vec2">, - ) { - const { x: min, y: max } = split(minMax).outputs; - const mid = mul(add(min, max), dynoConst("float", 0.5)); - const scale = mul(sub(max, min), dynoConst("float", 0.5)); - return add(mid, mul(sNorm, scale)); - } - - // Evaluate Spherical Harmonics - const sh1Snorm = evaluateSH1(gsplat, sh1Texture, viewDir); - let rgb = rescaleSh(sh1Snorm, this.packedSplats.dynoSh1MinMax); - if (this.maxSh >= 2 && sh2Texture) { - const sh2Snorm = evaluateSH2(gsplat, sh2Texture, viewDir); - rgb = add( - rgb, - rescaleSh(sh2Snorm, this.packedSplats.dynoSh2MinMax), - ); - } - if (this.maxSh >= 3 && sh3Texture) { - const sh3Snorm = evaluateSH3(gsplat, sh3Texture, viewDir); - rgb = add( - rgb, - rescaleSh(sh3Snorm, this.packedSplats.dynoSh3MinMax), - ); - } - - // Flash off for 0.3 / 1.0 sec for debugging - // const fractTime = fract(SplatMesh.dynoTime); - // const lessThan05 = lessThan(fractTime, dynoConst("float", 0.3)); - // rgb = select(lessThan05, dynoConst("vec3", new THREE.Vector3()), rgb); - - // Add SH lighting to RGBA - let { rgba } = splitGsplat(gsplat).outputs; - rgba = add(rgba, extendVec(rgb, dynoConst("float", 0.0))); - gsplat = combineGsplat({ gsplat, rgba }); - } - } - - if (this.splatRgba) { - // Overwrite RGBA with baked RGBA values - const rgba = readRgbaArray(this.splatRgba.dyno, index); - gsplat = combineGsplat({ gsplat, rgba }); - } - - if (this.skinning) { - // Transform according to bones + skinning weights - gsplat = this.skinning.modify(gsplat); - } - - if (this.objectModifier) { - // Inject object-space Gsplat modifier dyno - gsplat = this.objectModifier.apply({ gsplat }).gsplat; - } - - // Transform from object to world-space - gsplat = transform.applyGsplat(gsplat); - - // Apply any global recoloring and opacity - const recolorRgba = mul(recolor, splitGsplat(gsplat).outputs.rgba); - gsplat = combineGsplat({ gsplat, rgba: recolorRgba }); - - if (this.rgbaDisplaceEdits) { - // Apply RGBA edit layer SDFs - gsplat = this.rgbaDisplaceEdits.modify(gsplat); - } - if (this.worldModifier) { - // Inject world-space Gsplat modifier dyno - gsplat = this.worldModifier.apply({ gsplat }).gsplat; - } - - // We're done! Output resulting Gsplat - return { gsplat }; - }, - ); - this.generator = generator; - } - - // Call this whenever something changes in the Gsplat processing pipeline, - // for example changing maxSh or updating objectModifier or worldModifier. - // Compiled generators are cached for efficiency and re-use when the same - // pipeline structure emerges after successive changes. - updateGenerator() { - this.constructGenerator(this.context); - } - - // This is called automatically by SparkRenderer and you should not have to - // call it. It updates parameters for the generated pipeline and calls - // updateGenerator() if the pipeline needs to change. - update({ - time, - viewToWorld, - deltaTime, - globalEdits, - }: { - time: number; - viewToWorld: THREE.Matrix4; - deltaTime: number; - globalEdits: SplatEdit[]; - }) { - this.numSplats = this.packedSplats.numSplats; - this.context.time.value = time; - this.context.deltaTime.value = deltaTime; - SplatMesh.dynoTime.value = time; - - const { transform, viewToObject, recolor } = this.context; - let updated = transform.update(this); - - if ( - this.context.viewToWorld.updateFromMatrix(viewToWorld) && - this.enableViewToWorld - ) { - updated = true; - } - const worldToView = viewToWorld.clone().invert(); - if ( - this.context.worldToView.updateFromMatrix(worldToView) && - this.enableWorldToView - ) { - updated = true; - } - - const objectToWorld = new THREE.Matrix4().compose( - transform.translate.value, - transform.rotate.value, - new THREE.Vector3().setScalar(transform.scale.value), - ); - const worldToObject = objectToWorld.invert(); - const viewToObjectMatrix = worldToObject.multiply(viewToWorld); - if ( - viewToObject.updateFromMatrix(viewToObjectMatrix) && - (this.enableViewToObject || this.packedSplats.extra.sh1) - ) { - // Only trigger update if we have view-dependent spherical harmonics - updated = true; - } - - const newRecolor = new THREE.Vector4( - this.recolor.r, - this.recolor.g, - this.recolor.b, - this.opacity, - ); - if (!newRecolor.equals(recolor.value)) { - recolor.value.copy(newRecolor); - updated = true; - } - - const edits = this.editable ? (this.edits ?? []).concat(globalEdits) : []; - if (this.editable && !this.edits) { - // If we haven't set any explicit edits, add any child SplatEdits - this.traverseVisible((node) => { - if (node instanceof SplatEdit) { - edits.push(node); - } - }); - } - - edits.sort((a, b) => a.ordering - b.ordering); - const editsSdfs = edits.map((edit) => { - if (edit.sdfs != null) { - return { edit, sdfs: edit.sdfs }; - } - const sdfs: SplatEditSdf[] = []; - edit.traverseVisible((node) => { - if (node instanceof SplatEditSdf) { - sdfs.push(node); - } - }); - return { edit, sdfs }; - }); - - if (editsSdfs.length > 0 && !this.rgbaDisplaceEdits) { - const edits = editsSdfs.length; - const sdfs = editsSdfs.reduce( - (total, edit) => total + edit.sdfs.length, - 0, - ); - this.rgbaDisplaceEdits = new SplatEdits({ - maxEdits: edits, - maxSdfs: sdfs, - }); - this.updateGenerator(); - } - if (this.rgbaDisplaceEdits) { - const editResult = this.rgbaDisplaceEdits.update(editsSdfs); - updated ||= editResult.updated; - if (editResult.dynoUpdated) { - this.updateGenerator(); - } - } - - if (updated) { - this.updateVersion(); - } - - this.onFrame?.({ mesh: this, time, deltaTime }); - } - - // This method conforms to the standard THREE.Raycaster API, performing object-ray - // intersections using this method to populate the provided intersects[] array - // with each intersection point. - raycast( - raycaster: THREE.Raycaster, - intersects: { - distance: number; - point: THREE.Vector3; - object: THREE.Object3D; - }[], - ) { - if (!this.packedSplats.packedArray || !this.packedSplats.numSplats) { - return; - } - - const { near, far, ray } = raycaster; - const worldToMesh = this.matrixWorld.clone().invert(); - const worldToMeshRot = new THREE.Matrix3().setFromMatrix4(worldToMesh); - const origin = ray.origin.clone().applyMatrix4(worldToMesh); - const direction = ray.direction.clone().applyMatrix3(worldToMeshRot); - const scales = new THREE.Vector3(); - worldToMesh.decompose(new THREE.Vector3(), new THREE.Quaternion(), scales); - const scale = (scales.x * scales.y * scales.z) ** (1.0 / 3.0); - - const RAYCAST_ELLIPSOID = true; - const distances = raycast_splats( - origin.x, - origin.y, - origin.z, - direction.x, - direction.y, - direction.z, - near, - far, - this.packedSplats.numSplats, - this.packedSplats.packedArray, - RAYCAST_ELLIPSOID, - this.packedSplats.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, - this.packedSplats.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, - ); - - for (const distance of distances) { - const point = ray.direction - .clone() - .multiplyScalar(distance) - .add(ray.origin); - intersects.push({ - distance, - point, - object: this, - }); - } - } - - private ensureShTextures(): { - sh1Texture?: DynoUsampler2DArray<"sh1", THREE.DataArrayTexture>; - sh2Texture?: DynoUsampler2DArray<"sh2", THREE.DataArrayTexture>; - sh3Texture?: DynoUsampler2DArray<"sh3", THREE.DataArrayTexture>; - } { - // Ensure we have textures for SH1..SH3 if we have data - if (!this.packedSplats.extra.sh1) { - return {}; - } - - let sh1Texture = this.packedSplats.extra.sh1Texture as - | DynoUsampler2DArray<"sh1", THREE.DataArrayTexture> - | undefined; - if (!sh1Texture) { - let sh1 = this.packedSplats.extra.sh1 as Uint32Array; - const { width, height, depth, maxSplats } = getTextureSize( - sh1.length / 2, - ); - if (sh1.length < maxSplats * 2) { - const newSh1 = new Uint32Array(maxSplats * 2); - newSh1.set(sh1); - this.packedSplats.extra.sh1 = newSh1; - sh1 = newSh1; - } - - const texture = new THREE.DataArrayTexture(sh1, width, height, depth); - texture.format = THREE.RGIntegerFormat; - texture.type = THREE.UnsignedIntType; - texture.internalFormat = "RG32UI"; - texture.needsUpdate = true; - - sh1Texture = new DynoUsampler2DArray({ - value: texture, - key: "sh1", - }); - this.packedSplats.extra.sh1Texture = sh1Texture; - } - - if (!this.packedSplats.extra.sh2) { - return { sh1Texture }; - } - - let sh2Texture = this.packedSplats.extra.sh2Texture as - | DynoUsampler2DArray<"sh2", THREE.DataArrayTexture> - | undefined; - if (!sh2Texture) { - let sh2 = this.packedSplats.extra.sh2 as Uint32Array; - const { width, height, depth, maxSplats } = getTextureSize( - sh2.length / 4, - ); - if (sh2.length < maxSplats * 4) { - const newSh2 = new Uint32Array(maxSplats * 4); - newSh2.set(sh2); - this.packedSplats.extra.sh2 = newSh2; - sh2 = newSh2; - } - - const texture = new THREE.DataArrayTexture(sh2, width, height, depth); - texture.format = THREE.RGBAIntegerFormat; - texture.type = THREE.UnsignedIntType; - texture.internalFormat = "RGBA32UI"; - texture.needsUpdate = true; - - sh2Texture = new DynoUsampler2DArray({ - value: texture, - key: "sh2", - }); - this.packedSplats.extra.sh2Texture = sh2Texture; - } - - if (!this.packedSplats.extra.sh3) { - return { sh1Texture, sh2Texture }; - } - - let sh3Texture = this.packedSplats.extra.sh3Texture as - | DynoUsampler2DArray<"sh3", THREE.DataArrayTexture> - | undefined; - if (!sh3Texture) { - let sh3 = this.packedSplats.extra.sh3 as Uint32Array; - const { width, height, depth, maxSplats } = getTextureSize( - sh3.length / 4, - ); - if (sh3.length < maxSplats * 4) { - const newSh3 = new Uint32Array(maxSplats * 4); - newSh3.set(sh3); - this.packedSplats.extra.sh3 = newSh3; - sh3 = newSh3; - } - - const texture = new THREE.DataArrayTexture(sh3, width, height, depth); - texture.format = THREE.RGBAIntegerFormat; - texture.type = THREE.UnsignedIntType; - texture.internalFormat = "RGBA32UI"; - texture.needsUpdate = true; - - sh3Texture = new DynoUsampler2DArray({ - value: texture, - key: "sh3", - }); - this.packedSplats.extra.sh3Texture = sh3Texture; - } - - return { sh1Texture, sh2Texture, sh3Texture }; - } -} - -const defineEvaluateSH1 = unindent(` - vec3 evaluateSH1(Gsplat gsplat, usampler2DArray sh1, vec3 viewDir) { - // Extract sint7 values packed into 2 x uint32 - uvec2 packed = texelFetch(sh1, splatTexCoord(gsplat.index), 0).rg; - vec3 sh1_0 = vec3(ivec3( - int(packed.x << 25u) >> 25, - int(packed.x << 18u) >> 25, - int(packed.x << 11u) >> 25 - )) / 63.0; - vec3 sh1_1 = vec3(ivec3( - int(packed.x << 4u) >> 25, - int((packed.x >> 3u) | (packed.y << 29u)) >> 25, - int(packed.y << 22u) >> 25 - )) / 63.0; - vec3 sh1_2 = vec3(ivec3( - int(packed.y << 15u) >> 25, - int(packed.y << 8u) >> 25, - int(packed.y << 1u) >> 25 - )) / 63.0; - - return sh1_0 * (-0.4886025 * viewDir.y) - + sh1_1 * (0.4886025 * viewDir.z) - + sh1_2 * (-0.4886025 * viewDir.x); - } -`); - -const defineEvaluateSH2 = unindent(` - vec3 evaluateSH2(Gsplat gsplat, usampler2DArray sh2, vec3 viewDir) { - // Extract sint8 values packed into 4 x uint32 - uvec4 packed = texelFetch(sh2, splatTexCoord(gsplat.index), 0); - vec3 sh2_0 = vec3(ivec3( - int(packed.x << 24u) >> 24, - int(packed.x << 16u) >> 24, - int(packed.x << 8u) >> 24 - )) / 127.0; - vec3 sh2_1 = vec3(ivec3( - int(packed.x) >> 24, - int(packed.y << 24u) >> 24, - int(packed.y << 16u) >> 24 - )) / 127.0; - vec3 sh2_2 = vec3(ivec3( - int(packed.y << 8u) >> 24, - int(packed.y) >> 24, - int(packed.z << 24u) >> 24 - )) / 127.0; - vec3 sh2_3 = vec3(ivec3( - int(packed.z << 16u) >> 24, - int(packed.z << 8u) >> 24, - int(packed.z) >> 24 - )) / 127.0; - vec3 sh2_4 = vec3(ivec3( - int(packed.w << 24u) >> 24, - int(packed.w << 16u) >> 24, - int(packed.w << 8u) >> 24 - )) / 127.0; - - return sh2_0 * (1.0925484 * viewDir.x * viewDir.y) - + sh2_1 * (-1.0925484 * viewDir.y * viewDir.z) - + sh2_2 * (0.3153915 * (2.0 * viewDir.z * viewDir.z - viewDir.x * viewDir.x - viewDir.y * viewDir.y)) - + sh2_3 * (-1.0925484 * viewDir.x * viewDir.z) - + sh2_4 * (0.5462742 * (viewDir.x * viewDir.x - viewDir.y * viewDir.y)); - } -`); - -const defineEvaluateSH3 = unindent(` - vec3 evaluateSH3(Gsplat gsplat, usampler2DArray sh3, vec3 viewDir) { - // Extract sint6 values packed into 4 x uint32 - uvec4 packed = texelFetch(sh3, splatTexCoord(gsplat.index), 0); - vec3 sh3_0 = vec3(ivec3( - int(packed.x << 26u) >> 26, - int(packed.x << 20u) >> 26, - int(packed.x << 14u) >> 26 - )) / 31.0; - vec3 sh3_1 = vec3(ivec3( - int(packed.x << 8u) >> 26, - int(packed.x << 2u) >> 26, - int((packed.x >> 4u) | (packed.y << 28u)) >> 26 - )) / 31.0; - vec3 sh3_2 = vec3(ivec3( - int(packed.y << 22u) >> 26, - int(packed.y << 16u) >> 26, - int(packed.y << 10u) >> 26 - )) / 31.0; - vec3 sh3_3 = vec3(ivec3( - int(packed.y << 4u) >> 26, - int((packed.y >> 2u) | (packed.z << 30u)) >> 26, - int(packed.z << 24u) >> 26 - )) / 31.0; - vec3 sh3_4 = vec3(ivec3( - int(packed.z << 18u) >> 26, - int(packed.z << 12u) >> 26, - int(packed.z << 6u) >> 26 - )) / 31.0; - vec3 sh3_5 = vec3(ivec3( - int(packed.z) >> 26, - int(packed.w << 26u) >> 26, - int(packed.w << 20u) >> 26 - )) / 31.0; - vec3 sh3_6 = vec3(ivec3( - int(packed.w << 14u) >> 26, - int(packed.w << 8u) >> 26, - int(packed.w << 2u) >> 26 - )) / 31.0; - - float xx = viewDir.x * viewDir.x; - float yy = viewDir.y * viewDir.y; - float zz = viewDir.z * viewDir.z; - float xy = viewDir.x * viewDir.y; - float yz = viewDir.y * viewDir.z; - float zx = viewDir.z * viewDir.x; - - return sh3_0 * (-0.5900436 * viewDir.y * (3.0 * xx - yy)) - + sh3_1 * (2.8906114 * xy * viewDir.z) + - + sh3_2 * (-0.4570458 * viewDir.y * (4.0 * zz - xx - yy)) - + sh3_3 * (0.3731763 * viewDir.z * (2.0 * zz - 3.0 * xx - 3.0 * yy)) - + sh3_4 * (-0.4570458 * viewDir.x * (4.0 * zz - xx - yy)) - + sh3_5 * (1.4453057 * viewDir.z * (xx - yy)) - + sh3_6 * (-0.5900436 * viewDir.x * (xx - 3.0 * yy)); - } -`); - -export function evaluateSH1( - gsplat: DynoVal, - sh1: DynoUsampler2DArray<"sh1", THREE.DataArrayTexture>, - viewDir: DynoVal<"vec3">, -): DynoVal<"vec3"> { - return dyno({ - inTypes: { gsplat: Gsplat, sh1: "usampler2DArray", viewDir: "vec3" }, - outTypes: { rgb: "vec3" }, - inputs: { gsplat, sh1, viewDir }, - globals: () => [defineGsplat, defineEvaluateSH1], - statements: ({ inputs, outputs }) => { - const statements = unindentLines(` - if (isGsplatActive(${inputs.gsplat}.flags)) { - ${outputs.rgb} = evaluateSH1(${inputs.gsplat}, ${inputs.sh1}, ${inputs.viewDir}); - } else { - ${outputs.rgb} = vec3(0.0); - } - `); - return statements; - }, - }).outputs.rgb; -} - -export function evaluateSH2( - gsplat: DynoVal, - sh2: DynoVal<"usampler2DArray">, - viewDir: DynoVal<"vec3">, -): DynoVal<"vec3"> { - return dyno({ - inTypes: { gsplat: Gsplat, sh2: "usampler2DArray", viewDir: "vec3" }, - outTypes: { rgb: "vec3" }, - inputs: { gsplat, sh2, viewDir }, - globals: () => [defineGsplat, defineEvaluateSH2], - statements: ({ inputs, outputs }) => - unindentLines(` - if (isGsplatActive(${inputs.gsplat}.flags)) { - ${outputs.rgb} = evaluateSH2(${inputs.gsplat}, ${inputs.sh2}, ${inputs.viewDir}); - } else { - ${outputs.rgb} = vec3(0.0); - } - `), - }).outputs.rgb; -} - -export function evaluateSH3( - gsplat: DynoVal, - sh3: DynoVal<"usampler2DArray">, - viewDir: DynoVal<"vec3">, -): DynoVal<"vec3"> { - return dyno({ - inTypes: { gsplat: Gsplat, sh3: "usampler2DArray", viewDir: "vec3" }, - outTypes: { rgb: "vec3" }, - inputs: { gsplat, sh3, viewDir }, - globals: () => [defineGsplat, defineEvaluateSH3], - statements: ({ inputs, outputs }) => - unindentLines(` - if (isGsplatActive(${inputs.gsplat}.flags)) { - ${outputs.rgb} = evaluateSH3(${inputs.gsplat}, ${inputs.sh3}, ${inputs.viewDir}); - } else { - ${outputs.rgb} = vec3(0.0); - } - `), - }).outputs.rgb; -} - -const EMPTY_GEOMETRY = new THREE.BufferGeometry(); -const EMPTY_MATERIAL = new THREE.ShaderMaterial(); - -// Creates an empty mesh to hook into Three.js rendering. -// This is used to detect if a SparkRenderer is present in the scene. -// If not, one will be injected automatically. -function createRendererDetectionMesh(): THREE.Mesh { - const mesh = new THREE.Mesh(EMPTY_GEOMETRY, EMPTY_MATERIAL); - mesh.frustumCulled = false; - mesh.onBeforeRender = function (renderer, scene) { - if (!scene.isScene) { - // The SplatMesh is part of render call that doesn't have a Scene at its root - // Don't auto-inject a renderer. - this.removeFromParent(); - return; - } - - // Check if the scene has a SparkRenderer instance - let hasSparkRenderer = false; - scene.traverse((c) => { - if (c instanceof SparkRenderer) { - hasSparkRenderer = true; - } - }); - - if (!hasSparkRenderer) { - // No spark renderer present in the scene, inject one. - scene.add(new SparkRenderer({ renderer })); - } - - // Remove mesh to stop checking - this.removeFromParent(); - }; - return mesh; -} diff --git a/src/SplatSkinning.ts b/src/SplatSkinning.ts deleted file mode 100644 index f5e90f1..0000000 --- a/src/SplatSkinning.ts +++ /dev/null @@ -1,298 +0,0 @@ -import * as THREE from "three"; - -// SplatSkinning is an experimental class that implements dual-quaternion -// skeletal animation for Gsplats. A skeletal animation system consists -// of a set of bones, each with a "rest" pose that consists of a position -// and orientation, and a weighting of up to 4 bones for each Gsplat. -// By moving and rotating the bones you can animate all the Gsplats like -// your would for a normal 3D animated mesh. -// Note that the dual-quaternion formulation assumes that mass/volume -// is conserved through these transformations, which helps avoid common -// issues with linear blend skinning such as joint collapse or bulging. -// However, it is not as good a fit for animations that involve explicit -// deformations, such as cartoon animations. - -import type { SplatMesh } from "./SplatMesh"; -import { - Dyno, - DynoUniform, - type DynoVal, - Gsplat, - unindent, - unindentLines, -} from "./dyno"; -import { getTextureSize } from "./utils"; - -export type SplatSkinningOptions = { - // Specifies the SplatMesh that will be animated. - mesh: SplatMesh; - // Overrides the number of Gsplats in the mesh that will be animated. - // (default: mesh.numSplats) - numSplats?: number; - // Set the number of bones used to animate the SplatMesh, with a maximum - // of 256 (in order to compactly encode the bone index). (default: 256) - numBones?: number; -}; - -export class SplatSkinning { - mesh: SplatMesh; - numSplats: number; - - // Store the skinning weights for each Gsplat, composed of a 4-vector - // of bone indices and weight - skinData: Uint16Array; - skinTexture: THREE.DataArrayTexture; - - numBones: number; - boneData: Float32Array; - boneTexture: THREE.DataTexture; - - uniform: DynoUniform; - - constructor(options: SplatSkinningOptions) { - this.mesh = options.mesh; - this.numSplats = options.numSplats ?? this.mesh.numSplats; - - const { width, height, depth, maxSplats } = getTextureSize(this.numSplats); - this.skinData = new Uint16Array(maxSplats * 4); - this.skinTexture = new THREE.DataArrayTexture( - this.skinData, - width, - height, - depth, - ); - this.skinTexture.format = THREE.RGBAIntegerFormat; - this.skinTexture.type = THREE.UnsignedShortType; - this.skinTexture.internalFormat = "RGBA16UI"; - this.skinTexture.needsUpdate = true; - - this.numBones = options.numBones ?? 256; - this.boneData = new Float32Array(this.numBones * 16); - this.boneTexture = new THREE.DataTexture( - this.boneData, - 4, - this.numBones, - THREE.RGBAFormat, - THREE.FloatType, - ); - this.boneTexture.internalFormat = "RGBA32F"; - this.boneTexture.needsUpdate = true; - - this.uniform = new DynoUniform({ - key: "skinning", - type: GsplatSkinning, - globals: () => [defineGsplatSkinning], - value: { - numSplats: this.numSplats, - numBones: this.numBones, - skinTexture: this.skinTexture, - boneTexture: this.boneTexture, - }, - }); - } - - // Apply the skeletal animation to a Gsplat in a dyno program. - modify(gsplat: DynoVal): DynoVal { - return applyGsplatSkinning(gsplat, this.uniform); - } - - // Set the "rest" pose for a bone with position and quaternion orientation. - setRestQuatPos( - boneIndex: number, - quat: THREE.Quaternion, - pos: THREE.Vector3, - ) { - const i16 = boneIndex * 16; - this.boneData[i16 + 0] = quat.x; - this.boneData[i16 + 1] = quat.y; - this.boneData[i16 + 2] = quat.z; - this.boneData[i16 + 3] = quat.w; - this.boneData[i16 + 4] = pos.x; - this.boneData[i16 + 5] = pos.y; - this.boneData[i16 + 6] = pos.z; - this.boneData[i16 + 7] = 0; - this.boneData[i16 + 8] = 0; - this.boneData[i16 + 9] = 0; - this.boneData[i16 + 10] = 0; - this.boneData[i16 + 11] = 1; - this.boneData[i16 + 12] = 0; - this.boneData[i16 + 13] = 0; - this.boneData[i16 + 14] = 0; - this.boneData[i16 + 15] = 0; - } - - // Set the "current" position and orientation of a bone. - setBoneQuatPos( - boneIndex: number, - quat: THREE.Quaternion, - pos: THREE.Vector3, - ) { - const i16 = boneIndex * 16; - const origQuat = new THREE.Quaternion( - this.boneData[i16 + 0], - this.boneData[i16 + 1], - this.boneData[i16 + 2], - this.boneData[i16 + 3], - ); - const origPos = new THREE.Vector3( - this.boneData[i16 + 4], - this.boneData[i16 + 5], - this.boneData[i16 + 6], - ); - - const relQuat = origQuat.clone().invert(); - const relPos = pos.clone().sub(origPos); - relPos.applyQuaternion(relQuat); - relQuat.multiply(quat); - const dual = new THREE.Quaternion( - relPos.x, - relPos.y, - relPos.z, - 0.0, - ).multiply(origQuat); - - this.boneData[i16 + 8] = relQuat.x; - this.boneData[i16 + 9] = relQuat.y; - this.boneData[i16 + 10] = relQuat.z; - this.boneData[i16 + 11] = relQuat.w; - this.boneData[i16 + 12] = 0.5 * dual.x; - this.boneData[i16 + 13] = 0.5 * dual.y; - this.boneData[i16 + 14] = 0.5 * dual.z; - this.boneData[i16 + 15] = 0.5 * dual.w; - } - - // Set up to 4 bone indices and weights for a Gsplat. For fewer than 4 bones, - // you can set the remaining weights to 0 (and index=0). - setSplatBones( - splatIndex: number, - boneIndices: THREE.Vector4, - weights: THREE.Vector4, - ) { - const i4 = splatIndex * 4; - this.skinData[i4 + 0] = - Math.min(255, Math.max(0, Math.round(weights.x * 255.0))) + - (boneIndices.x << 8); - this.skinData[i4 + 1] = - Math.min(255, Math.max(0, Math.round(weights.y * 255.0))) + - (boneIndices.y << 8); - this.skinData[i4 + 2] = - Math.min(255, Math.max(0, Math.round(weights.z * 255.0))) + - (boneIndices.z << 8); - this.skinData[i4 + 3] = - Math.min(255, Math.max(0, Math.round(weights.w * 255.0))) + - (boneIndices.w << 8); - } - - // Call this to indicate that the bones have changed and the Gsplats need to be - // re-generated with updated skinning. - updateBones() { - this.boneTexture.needsUpdate = true; - this.mesh.needsUpdate = true; - } -} - -// dyno program definitions for SplatSkinning - -export const GsplatSkinning = { type: "GsplatSkinning" } as { - type: "GsplatSkinning"; -}; - -export const defineGsplatSkinning = unindent(` - struct GsplatSkinning { - int numSplats; - int numBones; - usampler2DArray skinTexture; - sampler2D boneTexture; - }; -`); - -export const defineApplyGsplatSkinning = unindent(` - void applyGsplatSkinning( - int numSplats, int numBones, - usampler2DArray skinTexture, sampler2D boneTexture, - int splatIndex, inout vec3 center, inout vec4 quaternion - ) { - if ((splatIndex < 0) || (splatIndex >= numSplats)) { - return; - } - - uvec4 skinData = texelFetch(skinTexture, splatTexCoord(splatIndex), 0); - - float weights[4]; - weights[0] = float(skinData.x & 0xffu) / 255.0; - weights[1] = float(skinData.y & 0xffu) / 255.0; - weights[2] = float(skinData.z & 0xffu) / 255.0; - weights[3] = float(skinData.w & 0xffu) / 255.0; - - uint boneIndices[4]; - boneIndices[0] = (skinData.x >> 8u) & 0xffu; - boneIndices[1] = (skinData.y >> 8u) & 0xffu; - boneIndices[2] = (skinData.z >> 8u) & 0xffu; - boneIndices[3] = (skinData.w >> 8u) & 0xffu; - - vec4 quat = vec4(0.0); - vec4 dual = vec4(0.0); - for (int i = 0; i < 4; i++) { - if (weights[i] > 0.0) { - int boneIndex = int(boneIndices[i]); - vec4 boneQuat = vec4(0.0, 0.0, 0.0, 1.0); - vec4 boneDual = vec4(0.0); - if (boneIndex < numBones) { - boneQuat = texelFetch(boneTexture, ivec2(2, boneIndex), 0); - boneDual = texelFetch(boneTexture, ivec2(3, boneIndex), 0); - } - - if ((i > 0) && (dot(quat, boneQuat) < 0.0)) { - // Flip sign if next blend is pointing in the opposite direction - boneQuat = -boneQuat; - boneDual = -boneDual; - } - quat += weights[i] * boneQuat; - dual += weights[i] * boneDual; - } - } - - // Normalize dual quaternion - float norm = length(quat); - quat /= norm; - dual /= norm; - vec3 translate = vec3( - 2.0 * (-dual.w * quat.x + dual.x * quat.w - dual.y * quat.z + dual.z * quat.y), - 2.0 * (-dual.w * quat.y + dual.x * quat.z + dual.y * quat.w - dual.z * quat.x), - 2.0 * (-dual.w * quat.z - dual.x * quat.y + dual.y * quat.x + dual.z * quat.w) - ); - - center = quatVec(quat, center) + translate; - quaternion = quatQuat(quat, quaternion); - } -`); - -function applyGsplatSkinning( - gsplat: DynoVal, - skinning: DynoVal, -): DynoVal { - const dyno = new Dyno< - { gsplat: typeof Gsplat; skinning: typeof GsplatSkinning }, - { gsplat: typeof Gsplat } - >({ - inTypes: { gsplat: Gsplat, skinning: GsplatSkinning }, - outTypes: { gsplat: Gsplat }, - globals: () => [defineGsplatSkinning, defineApplyGsplatSkinning], - inputs: { gsplat, skinning }, - statements: ({ inputs, outputs }) => { - const { skinning } = inputs; - const { gsplat } = outputs; - return unindentLines(` - ${gsplat} = ${inputs.gsplat}; - if (isGsplatActive(${gsplat}.flags)) { - applyGsplatSkinning( - ${skinning}.numSplats, ${skinning}.numBones, - ${skinning}.skinTexture, ${skinning}.boneTexture, - ${gsplat}.index, ${gsplat}.center, ${gsplat}.quaternion - ); - } - `); - }, - }); - return dyno.outputs.gsplat; -} diff --git a/src/SplatSorter.ts b/src/SplatSorter.ts new file mode 100644 index 0000000..3893dd2 --- /dev/null +++ b/src/SplatSorter.ts @@ -0,0 +1,425 @@ +import * as THREE from "three"; +import { FullScreenQuad } from "three/addons/postprocessing/Pass.js"; + +import { BatchedSplat } from "./BatchedSplat"; +import type { Splat, SplatData } from "./Splat"; +import { type SplatWorker, allocWorker, withWorkerCall } from "./SplatWorker"; +import { SPLAT_TEX_HEIGHT, SPLAT_TEX_WIDTH } from "./defines"; +import { getShaders } from "./shaders"; +import { getTextureSize } from "./utils"; + +/** + * Specific order of splats. The ordering might not include + * all splats, indicated by the amount of active splats. + */ +export type SplatOrdering = { + /** + * Array of splat indices. + */ + ordering: Uint32Array; + /** + * Number of active splats in this ordering. + */ + activeSplats: number; +}; + +export interface SplatSorter { + sort( + camera: THREE.Camera, + splat: Splat, + renderer: THREE.WebGLRenderer, + ordering: Uint32Array, + ): Promise; +} + +export type ReadbackSorterOptions = { + /** + * Whether to sort splats radially (geometric distance) from the viewpoint (true) + * or by Z-depth (false). Most scenes are trained with the Z-depth sort metric + * and will render more accurately at certain viewpoints. However, radial sorting + * is more stable under viewpoint rotations. + * @default false + */ + sortRadial?: boolean; + /** + * Constant added to Z-depth to bias values into the positive range for + * sortRadial: false, but also used for culling Gsplats "well behind" + * the viewpoint origin + * @default 1.0 + */ + depthBias?: number; + /** + * Set this to true if rendering a 360 to disable "behind the viewpoint" + * culling during sorting. This is set automatically when rendering 360 envMaps + * using the SparkRenderer.renderEnvMap() utility function. + * @default false + */ + sort360?: boolean; + /** + * Set this to true to sort with float32 precision with two-pass sort. + * @default true + */ + sort32?: boolean; +}; + +export class ReadbackSplatSorter implements SplatSorter { + sortRadial: boolean; + depthBias: number; + sort360: boolean; + sort32: boolean; + + private capacity = 0; + private target?: THREE.WebGLArrayRenderTarget; + private readonly readbackBufferPool: ReadbackBufferPool; + + private readonly material: THREE.RawShaderMaterial; + + constructor(options: ReadbackSorterOptions = {}) { + this.sortRadial = options.sortRadial ?? false; + this.depthBias = options.depthBias ?? 1.0; + this.sort360 = options.sort360 ?? false; + this.sort32 = options.sort32 ?? false; + + this.readbackBufferPool = new ReadbackBufferPool(); + this.material = ReadbackSplatSorter.createMaterial(); + } + + async sort( + camera: THREE.Camera, + splat: Splat, + renderer: THREE.WebGLRenderer, + ordering: Uint32Array, + ): Promise { + // Read the depth for each splat + const splatData = splat.splatData; + const maxSplats = splatData.maxSplats; + const numSplats = splatData.numSplats; + + // Render + const count = this.sort32 ? numSplats : numSplats / 2; + const readbackBuffer = await this.ensureCapacity(count); + const readback = this.sort32 + ? readbackBuffer.readback32 + : readbackBuffer.readback16; + + const renderState = this.saveRenderState(renderer); + + const material = this.material; + splatData.setupMaterial(material); + material.uniforms.sortRadial.value = this.sortRadial; + material.uniforms.sortDepthBias.value = this.depthBias; + material.uniforms.sort360.value = this.sort360; + material.uniforms.splatModelViewMatrix.value.multiplyMatrices( + camera.matrixWorldInverse, + splat.matrixWorld, + ); + material.defines.SORT32 = this.sort32; + + this.render(renderer, count, material); + const promise = this.read(renderer, count, readback); + + this.resetRenderState(renderer, renderState); + + await promise; + + // Perform sorting + const rpcName = this.sort32 ? "sort32Splats" : "sortDoubleSplats"; + const result = await withWorkerCall(rpcName, { + maxSplats, + numSplats, + readback: readback as Uint16Array, // FIXME: type depends on RPC method + ordering, + }); + + // Restore transferred array readback buffers + if (result.readback instanceof Uint16Array) { + readbackBuffer.readback16 = result.readback; + } else { + readbackBuffer.readback32 = result.readback; + } + readbackBuffer.buffer = result.readback.buffer; + this.readbackBufferPool.free(readbackBuffer); + + return { ordering: result.ordering, activeSplats: result.activeSplats }; + } + + dispose() { + if (this.target) { + this.target.dispose(); + this.target = undefined; + } + } + + // Ensure our render target is large enough for the readback of capacity indices. + private async ensureCapacity(capacity: number): Promise { + const { width, height, depth, maxSplats } = getTextureSize(capacity); + if (!this.target || maxSplats > this.capacity) { + this.dispose(); + this.capacity = maxSplats; + + // The only portable readback format for WebGL2 is RGBA8 + this.target = new THREE.WebGLArrayRenderTarget(width, height, depth, { + depthBuffer: false, + stencilBuffer: false, + generateMipmaps: false, + magFilter: THREE.NearestFilter, + minFilter: THREE.NearestFilter, + }); + this.target.texture.format = THREE.RGBAFormat; + this.target.texture.type = THREE.UnsignedByteType; + this.target.texture.internalFormat = "RGBA8"; + this.target.scissorTest = true; + } + + const byteLength = this.target.width * this.target.height * 4; + const readbackBuffer = await this.readbackBufferPool.alloc(byteLength); + return readbackBuffer; + } + + private saveRenderState(renderer: THREE.WebGLRenderer) { + return { + currentRenderTarget: renderer.getRenderTarget(), + xrEnabled: renderer.xr.enabled, + autoClear: renderer.autoClear, + }; + } + + private resetRenderState( + renderer: THREE.WebGLRenderer, + state: { + currentRenderTarget: THREE.WebGLRenderTarget | null; + xrEnabled: boolean; + autoClear: boolean; + }, + ) { + renderer.setRenderTarget(state.currentRenderTarget); + renderer.xr.enabled = state.xrEnabled; + renderer.autoClear = state.autoClear; + } + + private render( + renderer: THREE.WebGLRenderer, + count: number, + material: THREE.RawShaderMaterial, + ) { + if (!this.target) { + throw new Error("No target"); + } + + ReadbackSplatSorter.fullScreenQuad.material = material; + + // Run the program in "layer" chunks, in horizontal row ranges, + // that cover the total count of indices. + const layerSize = SPLAT_TEX_WIDTH * SPLAT_TEX_HEIGHT; + material.uniforms.targetBase.value = 0; + material.uniforms.targetCount.value = count; + let baseIndex = 0; + + // Keep generating layers until completed count items + while (baseIndex < count) { + const layer = Math.floor(baseIndex / layerSize); + const layerBase = layer * layerSize; + const layerYEnd = Math.min( + SPLAT_TEX_HEIGHT, + Math.ceil((count - layerBase) / SPLAT_TEX_WIDTH), + ); + material.uniforms.targetLayer.value = layer; + + // Render the desired portion of the layer + this.target.scissor.set(0, 0, SPLAT_TEX_WIDTH, layerYEnd); + renderer.setRenderTarget(this.target, layer); + renderer.xr.enabled = false; + renderer.autoClear = false; + ReadbackSplatSorter.fullScreenQuad.render(renderer); + + baseIndex += SPLAT_TEX_WIDTH * layerYEnd; + } + } + + private async read( + renderer: THREE.WebGLRenderer, + count: number, + readback: B, + ): Promise { + if (!renderer) { + throw new Error("No renderer"); + } + if (!this.target) { + throw new Error("No target"); + } + + const roundedCount = Math.ceil(count / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; + if (readback.byteLength < roundedCount * 4) { + throw new Error( + `Readback buffer too small: ${readback.byteLength} < ${roundedCount * 4}`, + ); + } + const readbackUint8 = new Uint8Array( + readback instanceof ArrayBuffer ? readback : readback.buffer, + ); + + // We can only read back one 2D array layer of pixels at a time, + // so loop through them, initiate the readback, and collect the + // completion promises. + + const layerSize = SPLAT_TEX_WIDTH * SPLAT_TEX_HEIGHT; + let baseIndex = 0; + const promises = []; + + while (baseIndex < count) { + const layer = Math.floor(baseIndex / layerSize); + const layerBase = layer * layerSize; + const layerYEnd = Math.min( + SPLAT_TEX_HEIGHT, + Math.ceil((count - layerBase) / SPLAT_TEX_WIDTH), + ); + + renderer.setRenderTarget(this.target, layer); + + // Compute the subarray that this layer of readback corresponds to + const readbackSize = SPLAT_TEX_WIDTH * layerYEnd * 4; + const subReadback = readbackUint8.subarray( + layerBase * 4, + layerBase * 4 + readbackSize, + ); + const promise = renderer?.readRenderTargetPixelsAsync( + this.target, + 0, + 0, + SPLAT_TEX_WIDTH, + layerYEnd, + subReadback, + ); + promises.push(promise); + + baseIndex += SPLAT_TEX_WIDTH * layerYEnd; + } + return Promise.all(promises).then(() => readback); + } + + private static createMaterial() { + const shaders = getShaders(); + const material = new THREE.RawShaderMaterial({ + name: "SplatDistanceShader", + glslVersion: "300 es", + uniforms: { + targetBase: { value: 0 }, + targetCount: { value: 0 }, + targetLayer: { value: 0 }, + // Note: this modelViewMatrix is named differently to avoid Three.js + // populating it with the MVP of the fullscreen quad. + splatModelViewMatrix: { value: new THREE.Matrix4() }, + sortRadial: { value: false }, + sortDepthBias: { value: 1.0 }, + sort360: { value: false }, + }, + vertexShader: shaders.identityVertex, + fragmentShader: shaders.splatDistanceFragment, + }); + + return material; + } + + static fullScreenQuad = new FullScreenQuad( + new THREE.RawShaderMaterial({ visible: false }), + ); +} + +type ReadbackBuffer = { + buffer: ArrayBuffer; + readback16: Uint16Array; + readback32: Uint32Array; +}; + +export class ReadbackBufferPool { + private items: Array = []; + + async alloc(byteLength: number): Promise { + const item = this.allocInternal(); + if (item.buffer.byteLength < byteLength) { + item.buffer = item.buffer.transfer(byteLength); + item.readback16 = new Uint16Array(item.buffer); + item.readback32 = new Uint32Array(item.buffer); + } + return item; + } + + free(item: ReadbackBuffer) { + this.items.push(item); + } + + private allocInternal() { + const item = this.items.pop(); + if (item) { + return item; + } + const buffer = new ArrayBuffer(); + return { + buffer, + readback16: new Uint16Array(buffer), + readback32: new Uint32Array(buffer), + }; + } +} + +const tempV3 = new THREE.Vector3(); +const tempMatrix = new THREE.Matrix4(); + +/** + * CPU based sorting solution that supports rigid transforms of splats. + */ +export class CpuSplatSorter implements SplatSorter { + /** + * Private splat worker, kept around as local worker memory is used to retain splat center data. + */ + private worker?: SplatWorker; + private workerPromise: Promise; + + private centersUploaded = false; + + constructor() { + this.workerPromise = allocWorker().then((worker) => { + this.worker = worker; + }); + } + + async sort( + camera: THREE.Camera, + splat: Splat, + renderer: THREE.WebGLRenderer, + ordering: Uint32Array, + ): Promise { + await this.workerPromise; + if (!this.worker) { + throw new Error("Unreachable"); + } + + let splatCenters: Float32Array | undefined = undefined; + if (!this.centersUploaded) { + const centers = new Float32Array(splat.splatData.numSplats * 3); + splat.splatData.iterateCenters((i, x, y, z) => { + centers[i * 3 + 0] = x; + centers[i * 3 + 1] = y; + centers[i * 3 + 2] = z; + }); + splatCenters = centers; + this.centersUploaded = true; + } + + tempMatrix.copy(camera.matrixWorld); + const viewOrigin = camera.getWorldPosition(tempV3).toArray(); + const viewDir = tempV3 + .set(0, 0, -1) + .transformDirection(tempMatrix) + .toArray(); + + const result = this.worker.call("sortSplatsCpu", { + centers: splatCenters, + transforms: splat.getTransformRanges(), + viewOrigin, + viewDir, + ordering, + }); + return result; + } +} diff --git a/src/SplatUtils.ts b/src/SplatUtils.ts new file mode 100644 index 0000000..f079489 --- /dev/null +++ b/src/SplatUtils.ts @@ -0,0 +1,116 @@ +import * as THREE from "three"; +import { type IterableSplatData, Splat, type SplatData } from "./Splat"; +import { DefaultSplatEncoding, type SplatEncoder } from "./encoding/encoder"; + +const tempCenter = new THREE.Vector3(); +const tempScales = new THREE.Vector3(); +const tempQuat = new THREE.Quaternion(); + +/** + * Combines multiple Splat objects into a single Splat object. The individual + * world transforms are applied to the individual splats. Each splat must + * have the same number of spherical harmonics and the underlying SplatData + * must be iterable. + * @param splats The splats to combine + * @param options Additional options + * @returns The combined splat + */ +export function mergeSplats( + splats: Array, + options?: { + splatEncoder?: SplatEncoder | (() => SplatEncoder); + }, +): Splat | null { + const numSh = splats[0].splatData.numSh; + const splatEncoderFactory = + options?.splatEncoder ?? DefaultSplatEncoding.createSplatEncoder; + const splatEncoder = + typeof splatEncoderFactory === "function" + ? splatEncoderFactory() + : splatEncoderFactory; + + // Sum the total amount of combined splats. + const numSplats = splats.reduce( + (acc, splat) => acc + splat.splatData.numSplats, + 0, + ); + splatEncoder.allocate(numSplats, numSh); + + let newSplatIndex = 0; + for (let i = 0; i < splats.length; ++i) { + const splatData = splats[i].splatData; + if (splatData.numSh !== numSh) { + console.error( + `SplatUtils: .mergeSplats() failed with splat at index ${i}. All splats must have the same amount of spherical harmonics.`, + ); + return null; + } + + if (!isIterableSplatData(splatData)) { + console.error( + `SplatUtils: .mergeSplats() failed with splat at index ${i}. All splats must have iterable splat data.`, + ); + return null; + } + + // Ensure matrix world is up to date + splats[i].updateMatrixWorld(); + const splatScale = splats[i].getWorldScale(new THREE.Vector3()); + const splatRotation = splats[i].getWorldQuaternion(new THREE.Quaternion()); + + splatData.iterateSplats( + ( + _, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + sh, + ) => { + // Apply splat transform + tempCenter.set(x, y, z).applyMatrix4(splats[i].matrixWorld); + tempScales.set(scaleX, scaleY, scaleZ).multiplyScalar(splatScale.x); // Assume uniform scaling + tempQuat.set(quatX, quatY, quatZ, quatW).premultiply(splatRotation); + + splatEncoder.setSplat( + newSplatIndex++, + tempCenter.x, + tempCenter.y, + tempCenter.z, + tempScales.x, + tempScales.y, + tempScales.z, + tempQuat.x, + tempQuat.y, + tempQuat.z, + tempQuat.w, + opacity, + r, + g, + b, + ); + if (sh) { + splatEncoder.setSplatSh(newSplatIndex, sh); + } + }, + ); + } + + return new Splat(splatEncoder.close()); +} + +export function isIterableSplatData( + splatData: SplatData, +): splatData is IterableSplatData { + return "iterateSplats" in splatData; +} diff --git a/src/SplatWorker.ts b/src/SplatWorker.ts new file mode 100644 index 0000000..e4786be --- /dev/null +++ b/src/SplatWorker.ts @@ -0,0 +1,155 @@ +import { getArrayBuffers } from "./utils.js"; +import type { RpcMethods } from "./worker/worker.js"; +import BundledWorker from "./worker/worker.js?worker&inline"; + +/** + * SplatWorker is an internal class that manages a WebWorker for executing + * longer running CPU tasks such as Gsplat file decoding and sorting. + * Although a SplatWorker can be created and used directly, the utility + * function withWorker() is recommended to allocate from a managed + * pool of SplatWorkers. + */ +export class SplatWorker { + private worker: Worker; + private messages: Record< + number, + { resolve: (value: unknown) => void; reject: (reason?: unknown) => void } + > = {}; + private messageIdNext = 0; + + constructor() { + this.worker = new BundledWorker(); + this.worker.onmessage = (event) => this.onMessage(event); + } + + private makeMessageId(): number { + return ++this.messageIdNext; + } + + private makeMessagePromiseId(): { id: number; promise: Promise } { + const id = this.makeMessageId(); + const promise = new Promise((resolve, reject) => { + this.messages[id] = { resolve, reject }; + }); + return { id, promise }; + } + + private onMessage(event: MessageEvent) { + const { id, result, error } = event.data; + const handler = this.messages[id]; + if (handler) { + delete this.messages[id]; + if (error) { + handler.reject(error); + } else { + handler.resolve(result); + } + } + } + + /** + * Invoke an RPC on the worker with the given name and arguments. + * The normal usage of a worker is to run one activity at a time, + * but this function allows for concurrent calls, tagging each request + * with a unique message Id and awaiting a response to that same Id. + * The method will automatically transfer any ArrayBuffers in the + * arguments to the worker. If you'd like to transfer a copy of a + * buffer then you must clone it before passing to this function. + * + * @param name Name of the RPC call + * @param args + */ + async call( + name: Method, + args: RpcMethods[Method]["args"], + ): Promise { + const { id, promise } = this.makeMessagePromiseId(); + this.worker.postMessage( + { name, args, id }, + { transfer: getArrayBuffers(args) }, + ); + return promise as Promise; + } +} + +let maxWorkers = 4; + +let numWorkers = 0; +const freeWorkers: SplatWorker[] = []; +const workerQueue: ((worker: SplatWorker) => void)[] = []; + +/** + * Set the maximum number of workers to allocate for the pool. + * @param count Number of workers (default: 4) + */ +export function setWorkerPool(count: number) { + maxWorkers = count; +} + +/** + * Allocate a worker from the pool. If none are available and we are below the + * maximum, create a new one. Otherwise, add the request to a queue and wait + * for it to be fulfilled. + * @returns + */ +export async function allocWorker(): Promise { + const worker = freeWorkers.shift(); + if (worker) { + return worker; + } + + if (numWorkers < maxWorkers) { + const worker = new SplatWorker(); + numWorkers += 1; + return worker; + } + + return new Promise((resolve) => { + workerQueue.push(resolve); + }); +} + +/** + * Return a worker to the pool. Pass the worker to any pending waiter. + * @param worker The worker to return + */ +function freeWorker(worker: SplatWorker) { + if (numWorkers > maxWorkers) { + // Worker no longer needed + numWorkers -= 1; + return; + } + + const waiter = workerQueue.shift(); + if (waiter) { + waiter(worker); + return; + } + + freeWorkers.push(worker); +} + +/** + * Allocate a worker from the pool and invoke the callback with the worker. + * In case the worker is used for a single RPC, consider using the withWorkerCall + * shorthand. + * @param callback The callback to call + * @returns Promise that resolves when the callback completes + */ +export async function withWorker( + callback: (worker: SplatWorker) => Promise, +): Promise { + const worker = await allocWorker(); + try { + return await callback(worker); + } finally { + freeWorker(worker); + } +} + +export async function withWorkerCall( + name: Method, + args: RpcMethods[Method]["args"], +): Promise { + return await withWorker((worker) => worker.call(name, args)); +} diff --git a/src/antisplat.ts b/src/antisplat.ts deleted file mode 100644 index 4943d56..0000000 --- a/src/antisplat.ts +++ /dev/null @@ -1,125 +0,0 @@ -import type { SplatEncoding } from "./PackedSplats"; -import { computeMaxSplats, setPackedSplat } from "./utils"; - -export function decodeAntiSplat( - fileBytes: Uint8Array, - initNumSplats: (numSplats: number) => void, - splatCallback: ( - index: number, - x: number, - y: number, - z: number, - scaleX: number, - scaleY: number, - scaleZ: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - opacity: number, - r: number, - g: number, - b: number, - ) => void, -) { - const numSplats = Math.floor(fileBytes.length / 32); // 32 bytes per splat - if (numSplats * 32 !== fileBytes.length) { - throw new Error("Invalid .splat file size"); - } - initNumSplats(numSplats); - - const f32 = new Float32Array(fileBytes.buffer); - for (let i = 0; i < numSplats; ++i) { - const i32 = i * 32; - const i8 = i * 8; - const x = f32[i8 + 0]; - const y = f32[i8 + 1]; - const z = f32[i8 + 2]; - const scaleX = f32[i8 + 3]; - const scaleY = f32[i8 + 4]; - const scaleZ = f32[i8 + 5]; - const r = fileBytes[i32 + 24] / 255; - const g = fileBytes[i32 + 25] / 255; - const b = fileBytes[i32 + 26] / 255; - const opacity = fileBytes[i32 + 27] / 255; - const quatW = (fileBytes[i32 + 28] - 128) / 128; - const quatX = (fileBytes[i32 + 29] - 128) / 128; - const quatY = (fileBytes[i32 + 30] - 128) / 128; - const quatZ = (fileBytes[i32 + 31] - 128) / 128; - splatCallback( - i, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ); - } -} - -export function unpackAntiSplat( - fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): { - packedArray: Uint32Array; - numSplats: number; -} { - let numSplats = 0; - let maxSplats = 0; - let packedArray = new Uint32Array(0); - decodeAntiSplat( - fileBytes, - (cbNumSplats) => { - numSplats = cbNumSplats; - maxSplats = computeMaxSplats(numSplats); - packedArray = new Uint32Array(maxSplats * 4); - }, - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - setPackedSplat( - packedArray, - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - splatEncoding, - ); - }, - ); - return { packedArray, numSplats }; -} diff --git a/src/defines.ts b/src/defines.ts index feaa44e..204a91f 100644 --- a/src/defines.ts +++ b/src/defines.ts @@ -12,6 +12,21 @@ export const SCALE_MAX = Math.exp(LN_SCALE_MAX); export const LN_SCALE_ZERO = -30.0; export const SCALE_ZERO = Math.exp(LN_SCALE_ZERO); +export const SH_C0 = 0.28209479177387814; + +export const NUM_COEFF_TO_SH_DEGREE: Record = { + 0: 0, + 9: 1, + 24: 2, + 45: 3, +}; +export const SH_DEGREE_TO_NUM_COEFF: Record = { + 0: 0, + 1: 9, + 2: 24, + 3: 45, +}; + // Gsplats are stored in textures that are 2^11 x 2^11 x up to 2^11 // Most WebGL2 implementations support 2D textures up to 2^12 x 2^12 (max 16M Gsplats) // 2D array textures and 3D textures up to 2^11 x 2^11 x 2^11 (max 8G Gsplats), @@ -41,3 +56,9 @@ export const WASM_SPLAT_SORT = true; // in the plyReader. export const USE_COMPILED_PARSER_FUNCTION = true; + +export type TransformRange = { + start: number; + end: number; + matrix: number[]; +}; diff --git a/src/dyno.ts b/src/dyno.ts deleted file mode 100644 index 13d499b..0000000 --- a/src/dyno.ts +++ /dev/null @@ -1,16 +0,0 @@ -export * from "./dyno/types"; -export * from "./dyno/base"; -export * from "./dyno/value"; -export * from "./dyno/output"; -export * from "./dyno/uniforms"; -export * from "./dyno/program"; -export * from "./dyno/math"; -export * from "./dyno/logic"; -export * from "./dyno/util"; -export * from "./dyno/splats"; -export * from "./dyno/transform"; -export * from "./dyno/control"; -export * from "./dyno/convert"; -export * from "./dyno/texture"; -export * from "./dyno/trig"; -export * from "./dyno/vecmat"; diff --git a/src/dyno/base.ts b/src/dyno/base.ts deleted file mode 100644 index 9e48e27..0000000 --- a/src/dyno/base.ts +++ /dev/null @@ -1,575 +0,0 @@ -import type { IUniform } from "three"; -import type { DynoType } from "./types"; -import { - DynoLiteral, - DynoOutput, - type DynoVal, - DynoValue, - type HasDynoOut, - valType, -} from "./value"; - -const DEFAULT_INDENT = " "; - -export class Compilation { - globals: Set = new Set(); - statements: string[] = []; - uniforms: Record = {}; - declares: Set = new Set(); - updaters: (() => void)[] = []; - sequence = 0; - indent: string = DEFAULT_INDENT; - - constructor({ indent }: { indent?: string } = {}) { - this.indent = indent ?? DEFAULT_INDENT; - } - - nextSequence() { - return this.sequence++; - } -} - -export type IOTypes = Record; -type GenerateContext = { - inputs: { [K in keyof InTypes]?: string }; - outputs: { [K in keyof OutTypes]?: string }; - compile: Compilation; -}; - -export class Dyno { - inTypes: InTypes; - outTypes: OutTypes; - - inputs: { [K in keyof InTypes]?: DynoVal }; - update?: () => void; - globals?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => string[]; - statements?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => string[]; - generate: ({ - inputs, - outputs, - compile, - }: GenerateContext) => { - globals?: string[]; - statements?: string[]; - uniforms?: Record; - }; - - constructor({ - inTypes, - outTypes, - inputs, - update, - globals, - statements, - generate, - }: { - inTypes?: InTypes; - outTypes?: OutTypes; - inputs?: { [K in keyof InTypes]?: DynoVal }; - update?: () => void; - globals?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => string[]; - statements?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => string[]; - generate?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => { - globals?: string[]; - statements?: string[]; - uniforms?: Record; - }; - }) { - this.inTypes = inTypes ?? ({} as InTypes); - this.outTypes = outTypes ?? ({} as OutTypes); - this.inputs = inputs ?? {}; - this.update = update; - - this.globals = globals; - this.statements = statements; - this.generate = - generate ?? - (({ inputs, outputs, compile }) => { - return { - globals: this.globals?.({ inputs, outputs, compile }), - statements: this.statements?.({ inputs, outputs, compile }), - }; - }); - } - - get outputs(): { [K in keyof OutTypes]: DynoVal } { - const outputs = {} as { [K in keyof OutTypes]: DynoVal }; - for (const key in this.outTypes) { - outputs[key] = new DynoOutput(this, key); - } - return outputs; - } - - apply(inputs: { [K in keyof InTypes]?: DynoVal }): { - [K in keyof OutTypes]: DynoVal; - } { - Object.assign(this.inputs, inputs); - return this.outputs; - } - - compile({ - inputs, - outputs, - compile, - }: { - inputs: { [K in keyof InTypes]?: string }; - outputs: { [K in keyof OutTypes]?: string }; - compile: Compilation; - }): string[] { - const result = [ - `// ${this.constructor.name}(${Object.values(inputs).join(", ")}) => (${Object.values(outputs).join(", ")})`, - ]; - - const declares: (keyof OutTypes)[] = []; - for (const key in outputs) { - const name = outputs[key]; - if (name && !compile.declares.has(name)) { - compile.declares.add(name); - declares.push(key); - } - } - - const { globals, statements, uniforms } = this.generate({ - inputs, - outputs, - compile, - }); - for (const global of globals ?? []) { - compile.globals.add(global); - } - for (const key in uniforms) { - compile.uniforms[key] = uniforms[key]; - } - if (this.update) { - compile.updaters.push(this.update); - } - - for (const key of declares) { - const name = outputs[key]; - if (name) { - if (!compile.uniforms[name]) { - result.push(`${dynoDeclare(name, this.outTypes[key])};`); - } - } - } - - if (statements?.length) { - result.push("{"); - result.push(...statements.map((line) => compile.indent + line)); - result.push("}"); - } - return result; - } -} - -export type DynoBlockType = ( - inputs: { [K in keyof InTypes]?: DynoVal }, - outputs: { [K in keyof OutTypes]?: DynoVal }, - { roots }: { roots: Dyno[] }, -) => { [K in keyof OutTypes]?: DynoVal } | undefined; - -export class DynoBlock< - InTypes extends IOTypes, - OutTypes extends IOTypes, -> extends Dyno { - construct: DynoBlockType; - - constructor({ - inTypes, - outTypes, - inputs, - update, - globals, - construct, - }: { - inTypes?: InTypes; - outTypes?: OutTypes; - inputs?: { [K in keyof InTypes]?: DynoVal }; - update?: () => void; - globals?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => string[]; - construct: DynoBlockType; - }) { - super({ - inTypes, - outTypes, - inputs, - update, - globals, - generate: (args) => this.generateBlock(args), - }); - this.construct = construct; - } - - generateBlock({ - inputs, - outputs, - compile, - }: { - inputs: { [K in keyof InTypes]?: string }; - outputs: { [K in keyof OutTypes]?: string }; - compile: Compilation; - }) { - const blockInputs: { [K in keyof InTypes]?: DynoVal } = {}; - const blockOutputs: { [K in keyof OutTypes]?: DynoVal } = {}; - - for (const key in inputs) { - if (inputs[key] != null) { - blockInputs[key] = new DynoLiteral(this.inTypes[key], inputs[key]); - } - } - for (const key in outputs) { - if (outputs[key] != null) { - blockOutputs[key] = new DynoValue(this.outTypes[key]); - } - } - - const options = { roots: [] }; - const returned = this.construct(blockInputs, blockOutputs, options); - - for (const global of this.globals?.({ inputs, outputs, compile }) ?? []) { - compile.globals.add(global); - } - - const ordering: Dyno[] = []; - const nodeOuts = new Map< - Dyno, - { sequence: number; outNames: Map; newOuts: Set } - >(); - - function visit( - node: Dyno, - outKey?: string, - outName?: string, - ) { - let outs = nodeOuts.get(node); - if (!outs) { - // First time visiting this node - outs = { - sequence: compile.nextSequence(), - outNames: new Map(), - newOuts: new Set(), - }; - nodeOuts.set(node, outs); - - for (const key in node.inputs) { - let input = node.inputs[key]; - while (input) { - if (input instanceof DynoValue) { - if (input instanceof DynoOutput) { - visit(input.dyno, input.key); - } - break; - } - // Must be as HasDynoOut - input = input.dynoOut(); - } - } - ordering.push(node); - } - if (outKey) { - if (!outName) { - outs.newOuts.add(outKey); - } - outs.outNames.set(outKey, outName ?? `${outKey}_${outs.sequence}`); - } - } - - for (const root of options.roots) { - visit(root); - } - - for (const key in blockOutputs) { - let value = returned?.[key] ?? blockOutputs[key]; - while (value) { - if (value instanceof DynoValue) { - if (value instanceof DynoOutput) { - visit(value.dyno, value.key, outputs[key]); - } - break; - } - // Must be as HasDynoOut - value = value.dynoOut(); - } - blockOutputs[key] = value; - } - - const steps = []; - - for (const dyno of ordering) { - // compile.statements.push(`// ${dyno.constructor.name}(${Object.values(inputs).join(", ")}) => (${Object.values(outputs).join(", ")})`); - - const inputs: Record = {}; - const outputs: Record = {}; - - for (const key in dyno.inputs) { - let value = dyno.inputs[key]; - while (value) { - if (value instanceof DynoValue) { - if (value instanceof DynoLiteral) { - inputs[key] = value.getLiteral(); - } else if (value instanceof DynoOutput) { - const source = nodeOuts.get(value.dyno)?.outNames.get(value.key); - if (!source) { - throw new Error( - `Source not found for ${value.dyno.constructor.name}.${value.key}`, - ); - } - inputs[key] = source; - } - break; - } - // Must be as HasDynOut - value = value.dynoOut(); - } - } - - const outs = nodeOuts.get(dyno) ?? { outNames: new Map() }; - for (const [key, name] of outs.outNames.entries()) { - outputs[key] = name; - } - - const newSteps = dyno.compile({ inputs, outputs, compile }); - steps.push(newSteps); - } - - const literalOutputs = []; - for (const key in outputs) { - if (blockOutputs[key] instanceof DynoLiteral) { - literalOutputs.push( - `${outputs[key]} = ${blockOutputs[key].getLiteral()};`, - ); - } - } - if (literalOutputs.length > 0) { - steps.push(literalOutputs); - } - - const statements = steps.flatMap((step, index) => { - // Add a blank line between steps - return index === 0 ? step : ["", ...step]; - }); - return { statements }; - } -} - -export function dynoBlock< - InTypes extends Record, - OutTypes extends Record, ->( - inTypes: InTypes, - outTypes: OutTypes, - construct: DynoBlockType, - { update, globals }: { update?: () => void; globals?: () => string[] } = {}, -) { - return new DynoBlock({ inTypes, outTypes, construct, update, globals }); -} - -export function dyno< - InTypes extends Record, - OutTypes extends Record, ->({ - inTypes, - outTypes, - inputs, - update, - globals, - statements, - generate, -}: { - inTypes: InTypes; - outTypes: OutTypes; - inputs?: { [K in keyof InTypes]?: DynoVal }; - update?: () => void; - globals?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => string[]; - statements?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => string[]; - generate?: ({ - inputs, - outputs, - compile, - }: GenerateContext) => { - globals?: string[]; - statements?: string[]; - uniforms?: Record; - }; -}) { - return new Dyno({ - inTypes, - outTypes, - inputs, - update, - globals, - statements, - generate, - }); -} - -export function dynoDeclare(name: string, type: DynoType, count?: number) { - const typeStr = typeof type === "string" ? type : type.type; - if (!typeStr) { - throw new Error(`Invalid DynoType: ${String(type)}`); - } - return `${typeStr} ${name}${count != null ? `[${count}]` : ""}`; -} - -export function unindentLines(s: string): string[] { - let seenNonEmpty = false; - const lines = s - .split("\n") - .map((line) => { - const trimmedLine = line.trimEnd(); - if (seenNonEmpty) { - return trimmedLine; - } - if (trimmedLine.length > 0) { - seenNonEmpty = true; - return trimmedLine; - } - return null; - }) - .filter((line) => line != null); - while (lines.length > 0 && lines[lines.length - 1].length === 0) { - lines.pop(); - } - if (lines.length === 0) { - return []; - } - - const indent = lines[0].match(/^\s*/)?.[0]; - if (!indent) { - return lines; // No indent, return as is - } - // Remove indent from the beginning of each line - const regex = new RegExp(`^${indent}`); - return lines.map((line) => line.replace(regex, "")); -} - -export function unindent(s: string): string { - return unindentLines(s).join("\n"); -} - -export class UnaryOp< - A extends DynoType, - OutType extends DynoType, - OutKey extends string, - > - extends Dyno<{ a: A }, { [key in OutKey]: OutType }> - implements HasDynoOut -{ - constructor({ - a, - outKey, - outTypeFunc, - }: { a: DynoVal; outKey: OutKey; outTypeFunc: (aType: A) => OutType }) { - const inTypes = { a: valType(a) }; - const outType = outTypeFunc(valType(a)); - const outTypes = { [outKey]: outType } as { [key in OutKey]: OutType }; - super({ inTypes, outTypes, inputs: { a } }); - this.outKey = outKey; - } - - outKey: OutKey; - dynoOut(): DynoValue { - return new DynoOutput(this, this.outKey); - } -} - -export class BinaryOp< - A extends DynoType, - B extends DynoType, - OutType extends DynoType, - OutKey extends string, - > - extends Dyno<{ a: A; b: B }, { [key in OutKey]: OutType }> - implements HasDynoOut -{ - constructor({ - a, - b, - outKey, - outTypeFunc, - }: { - a: DynoVal; - b: DynoVal; - outKey: OutKey; - outTypeFunc: (aType: A, bType: B) => OutType; - }) { - const inTypes = { a: valType(a), b: valType(b) }; - const outType = outTypeFunc(valType(a), valType(b)); - const outTypes = { [outKey]: outType } as { [key in OutKey]: OutType }; - super({ inTypes, outTypes, inputs: { a, b } }); - this.outKey = outKey; - } - - outKey: OutKey; - dynoOut(): DynoValue { - return new DynoOutput(this, this.outKey); - } -} - -export class TrinaryOp< - A extends DynoType, - B extends DynoType, - C extends DynoType, - OutType extends DynoType, - OutKey extends string, - > - extends Dyno<{ a: A; b: B; c: C }, { [key in OutKey]: OutType }> - implements HasDynoOut -{ - constructor({ - a, - b, - c, - outKey, - outTypeFunc, - }: { - a: DynoVal; - b: DynoVal; - c: DynoVal; - outKey: OutKey; - outTypeFunc: (aType: A, bType: B, cType: C) => OutType; - }) { - const inTypes = { a: valType(a), b: valType(b), c: valType(c) }; - const outType = outTypeFunc(valType(a), valType(b), valType(c)); - const outTypes = { [outKey]: outType } as { [key in OutKey]: OutType }; - super({ inTypes, outTypes, inputs: { a, b, c } }); - this.outKey = outKey; - } - - outKey: OutKey; - dynoOut(): DynoValue { - return new DynoOutput(this, this.outKey); - } -} diff --git a/src/dyno/control.ts b/src/dyno/control.ts deleted file mode 100644 index 7835f6b..0000000 --- a/src/dyno/control.ts +++ /dev/null @@ -1,22 +0,0 @@ -// TODO: -// if, switch, for, comment, -// arrayIndex, arrayLength, - -export const dynoIf = () => { - throw new Error("Not implemented"); -}; -export const dynoSwitch = () => { - throw new Error("Not implemented"); -}; -export const dynoFor = () => { - throw new Error("Not implemented"); -}; -export const comment = () => { - throw new Error("Not implemented"); -}; -export const arrayIndex = () => { - throw new Error("Not implemented"); -}; -export const arrayLength = () => { - throw new Error("Not implemented"); -}; diff --git a/src/dyno/convert.ts b/src/dyno/convert.ts deleted file mode 100644 index 9837431..0000000 --- a/src/dyno/convert.ts +++ /dev/null @@ -1,451 +0,0 @@ -import { UnaryOp } from "./base"; -import { type SimpleTypes, typeLiteral } from "./types"; -import type { DynoVal } from "./value"; - -export const bool = ( - value: DynoVal, -): DynoVal<"bool"> => new Bool({ value }); -export const int = ( - value: DynoVal, -): DynoVal<"int"> => new Int({ value }); -export const uint = ( - value: DynoVal, -): DynoVal<"uint"> => new Uint({ value }); -export const float = ( - value: DynoVal, -): DynoVal<"float"> => new Float({ value }); - -export const bvec2 = ( - value: DynoVal, -): DynoVal<"bvec2"> => new BVec2({ value }); -export const bvec3 = ( - value: DynoVal, -): DynoVal<"bvec3"> => new BVec3({ value }); -export const bvec4 = ( - value: DynoVal, -): DynoVal<"bvec4"> => new BVec4({ value }); - -export const ivec2 = ( - value: DynoVal, -): DynoVal<"ivec2"> => new IVec2({ value }); -export const ivec3 = ( - value: DynoVal, -): DynoVal<"ivec3"> => new IVec3({ value }); -export const ivec4 = ( - value: DynoVal, -): DynoVal<"ivec4"> => new IVec4({ value }); - -export const uvec2 = ( - value: DynoVal, -): DynoVal<"uvec2"> => new UVec2({ value }); -export const uvec3 = ( - value: DynoVal, -): DynoVal<"uvec3"> => new UVec3({ value }); -export const uvec4 = ( - value: DynoVal, -): DynoVal<"uvec4"> => new UVec4({ value }); - -export const vec2 = < - T extends "float" | "bvec2" | "ivec2" | "uvec2" | "vec2" | "vec3" | "vec4", ->( - value: DynoVal, -): DynoVal<"vec2"> => new Vec2({ value }); -export const vec3 = < - T extends "float" | "bvec3" | "ivec3" | "uvec3" | "vec3" | "vec4", ->( - value: DynoVal, -): DynoVal<"vec3"> => new Vec3({ value }); -export const vec4 = ( - value: DynoVal, -): DynoVal<"vec4"> => new Vec4({ value }); - -export const mat2 = ( - value: DynoVal, -): DynoVal<"mat2"> => new Mat2({ value }); -export const mat3 = ( - value: DynoVal, -): DynoVal<"mat3"> => new Mat3({ value }); -export const mat4 = ( - value: DynoVal, -): DynoVal<"mat4"> => new Mat4({ value }); - -export const floatBitsToInt = (value: DynoVal<"float">): DynoVal<"int"> => - new FloatBitsToInt({ value }); -export const floatBitsToUint = (value: DynoVal<"float">): DynoVal<"uint"> => - new FloatBitsToUint({ value }); -export const intBitsToFloat = (value: DynoVal<"int">): DynoVal<"float"> => - new IntBitsToFloat({ value }); -export const uintBitsToFloat = (value: DynoVal<"uint">): DynoVal<"float"> => - new UintBitsToFloat({ value }); - -export const packSnorm2x16 = (value: DynoVal<"vec2">): DynoVal<"uint"> => - new PackSnorm2x16({ value }); -export const unpackSnorm2x16 = (value: DynoVal<"uint">): DynoVal<"vec2"> => - new UnpackSnorm2x16({ value }); -export const packUnorm2x16 = (value: DynoVal<"vec2">): DynoVal<"uint"> => - new PackUnorm2x16({ value }); -export const unpackUnorm2x16 = (value: DynoVal<"uint">): DynoVal<"vec2"> => - new UnpackUnorm2x16({ value }); - -export const packHalf2x16 = (value: DynoVal<"vec2">): DynoVal<"uint"> => - new PackHalf2x16({ value }); -export const unpackHalf2x16 = (value: DynoVal<"uint">): DynoVal<"vec2"> => - new UnpackHalf2x16({ value }); - -export const uintToRgba8 = (value: DynoVal<"uint">): DynoVal<"vec4"> => - new UintToRgba8({ value }); - -export class SimpleCast< - Allowed extends SimpleTypes, - OutType extends SimpleTypes, - OutKey extends string, -> extends UnaryOp { - constructor({ - value, - outType, - outKey, - }: { value: DynoVal; outType: OutType; outKey: OutKey }) { - super({ a: value, outTypeFunc: () => outType, outKey }); - this.statements = ({ inputs, outputs }) => [ - `${outputs[outKey]} = ${typeLiteral(outType)}(${inputs.a});`, - ]; - } -} - -export class Bool extends SimpleCast< - "bool" | "int" | "uint" | "float", - "bool", - "bool" -> { - constructor({ - value, - }: { value: DynoVal<"bool" | "int" | "uint" | "float"> }) { - super({ value, outType: "bool", outKey: "bool" }); - } -} - -export class Int extends SimpleCast< - "bool" | "int" | "uint" | "float", - "int", - "int" -> { - constructor({ - value, - }: { value: DynoVal<"bool" | "int" | "uint" | "float"> }) { - super({ value, outType: "int", outKey: "int" }); - } -} - -export class Uint extends SimpleCast< - "bool" | "int" | "uint" | "float", - "uint", - "uint" -> { - constructor({ - value, - }: { value: DynoVal<"bool" | "int" | "uint" | "float"> }) { - super({ value, outType: "uint", outKey: "uint" }); - } -} - -export class Float extends SimpleCast< - "bool" | "int" | "uint" | "float", - "float", - "float" -> { - constructor({ - value, - }: { value: DynoVal<"bool" | "int" | "uint" | "float"> }) { - super({ value, outType: "float", outKey: "float" }); - } -} - -export class BVec2 extends SimpleCast< - "bool" | "bvec2" | "ivec2" | "uvec2" | "vec2", - "bvec2", - "bvec2" -> { - constructor({ - value, - }: { value: DynoVal<"bool" | "bvec2" | "ivec2" | "uvec2" | "vec2"> }) { - super({ value, outType: "bvec2", outKey: "bvec2" }); - } -} - -export class BVec3 extends SimpleCast< - "bool" | "bvec3" | "ivec3" | "uvec3" | "vec3", - "bvec3", - "bvec3" -> { - constructor({ - value, - }: { value: DynoVal<"bool" | "bvec3" | "ivec3" | "uvec3" | "vec3"> }) { - super({ value, outType: "bvec3", outKey: "bvec3" }); - } -} - -export class BVec4 extends SimpleCast< - "bool" | "bvec4" | "ivec4" | "uvec4" | "vec4", - "bvec4", - "bvec4" -> { - constructor({ - value, - }: { value: DynoVal<"bool" | "bvec4" | "ivec4" | "uvec4" | "vec4"> }) { - super({ value, outType: "bvec4", outKey: "bvec4" }); - } -} - -export class IVec2 extends SimpleCast< - "int" | "bvec2" | "ivec2" | "uvec2" | "vec2", - "ivec2", - "ivec2" -> { - constructor({ - value, - }: { value: DynoVal<"int" | "bvec2" | "ivec2" | "uvec2" | "vec2"> }) { - super({ value, outType: "ivec2", outKey: "ivec2" }); - } -} - -export class IVec3 extends SimpleCast< - "int" | "bvec3" | "ivec3" | "uvec3" | "vec3", - "ivec3", - "ivec3" -> { - constructor({ - value, - }: { value: DynoVal<"int" | "bvec3" | "ivec3" | "uvec3" | "vec3"> }) { - super({ value, outType: "ivec3", outKey: "ivec3" }); - } -} - -export class IVec4 extends SimpleCast< - "int" | "bvec4" | "ivec4" | "uvec4" | "vec4", - "ivec4", - "ivec4" -> { - constructor({ - value, - }: { value: DynoVal<"int" | "bvec4" | "ivec4" | "uvec4" | "vec4"> }) { - super({ value, outType: "ivec4", outKey: "ivec4" }); - } -} - -export class UVec2 extends SimpleCast< - "uint" | "ivec2" | "bvec2" | "uvec2" | "vec2", - "uvec2", - "uvec2" -> { - constructor({ - value, - }: { value: DynoVal<"uint" | "ivec2" | "bvec2" | "uvec2" | "vec2"> }) { - super({ value, outType: "uvec2", outKey: "uvec2" }); - } -} - -export class UVec3 extends SimpleCast< - "uint" | "ivec3" | "bvec3" | "uvec3" | "vec3", - "uvec3", - "uvec3" -> { - constructor({ - value, - }: { value: DynoVal<"uint" | "ivec3" | "bvec3" | "uvec3" | "vec3"> }) { - super({ value, outType: "uvec3", outKey: "uvec3" }); - } -} - -export class UVec4 extends SimpleCast< - "uint" | "ivec4" | "bvec4" | "uvec4" | "vec4", - "uvec4", - "uvec4" -> { - constructor({ - value, - }: { value: DynoVal<"uint" | "ivec4" | "bvec4" | "uvec4" | "vec4"> }) { - super({ value, outType: "uvec4", outKey: "uvec4" }); - } -} - -export class Vec2 extends SimpleCast< - "float" | "bvec2" | "ivec2" | "uvec2" | "vec2" | "vec3" | "vec4", - "vec2", - "vec2" -> { - constructor({ - value, - }: { - value: DynoVal< - "float" | "bvec2" | "ivec2" | "uvec2" | "vec2" | "vec3" | "vec4" - >; - }) { - super({ value, outType: "vec2", outKey: "vec2" }); - } -} - -export class Vec3 extends SimpleCast< - "float" | "bvec3" | "ivec3" | "uvec3" | "vec3" | "vec2" | "vec4", - "vec3", - "vec3" -> { - constructor({ - value, - }: { - value: DynoVal< - "float" | "bvec3" | "ivec3" | "uvec3" | "vec3" | "vec2" | "vec4" - >; - }) { - super({ value, outType: "vec3", outKey: "vec3" }); - } -} - -export class Vec4 extends SimpleCast< - "float" | "bvec4" | "ivec4" | "uvec4" | "vec4", - "vec4", - "vec4" -> { - constructor({ - value, - }: { value: DynoVal<"float" | "bvec4" | "ivec4" | "uvec4" | "vec4"> }) { - super({ value, outType: "vec4", outKey: "vec4" }); - } -} - -export class Mat2 extends SimpleCast< - "float" | "mat2" | "mat3" | "mat4", - "mat2", - "mat2" -> { - constructor({ - value, - }: { value: DynoVal<"float" | "mat2" | "mat3" | "mat4"> }) { - super({ value, outType: "mat2", outKey: "mat2" }); - } -} - -export class Mat3 extends SimpleCast< - "float" | "mat2" | "mat3" | "mat4", - "mat3", - "mat3" -> { - constructor({ - value, - }: { value: DynoVal<"float" | "mat2" | "mat3" | "mat4"> }) { - super({ value, outType: "mat3", outKey: "mat3" }); - } -} - -export class Mat4 extends SimpleCast< - "float" | "mat2" | "mat3" | "mat4", - "mat4", - "mat4" -> { - constructor({ - value, - }: { value: DynoVal<"float" | "mat2" | "mat3" | "mat4"> }) { - super({ value, outType: "mat4", outKey: "mat4" }); - } -} - -export class FloatBitsToInt extends UnaryOp<"float", "int", "int"> { - constructor({ value }: { value: DynoVal<"float"> }) { - super({ a: value, outKey: "int", outTypeFunc: () => "int" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.int} = floatBitsToInt(${inputs.a});`]; - }; - } -} - -export class FloatBitsToUint extends UnaryOp<"float", "uint", "uint"> { - constructor({ value }: { value: DynoVal<"float"> }) { - super({ a: value, outKey: "uint", outTypeFunc: () => "uint" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.uint} = floatBitsToUint(${inputs.a});`]; - }; - } -} - -export class IntBitsToFloat extends UnaryOp<"int", "float", "float"> { - constructor({ value }: { value: DynoVal<"int"> }) { - super({ a: value, outKey: "float", outTypeFunc: () => "float" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.float} = intBitsToFloat(${inputs.a});`]; - }; - } -} - -export class UintBitsToFloat extends UnaryOp<"uint", "float", "float"> { - constructor({ value }: { value: DynoVal<"uint"> }) { - super({ a: value, outKey: "float", outTypeFunc: () => "float" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.float} = uintBitsToFloat(${inputs.a});`]; - }; - } -} - -export class PackSnorm2x16 extends UnaryOp<"vec2", "uint", "uint"> { - constructor({ value }: { value: DynoVal<"vec2"> }) { - super({ a: value, outKey: "uint", outTypeFunc: () => "uint" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.uint} = packSnorm2x16(${inputs.a});`]; - }; - } -} - -export class UnpackSnorm2x16 extends UnaryOp<"uint", "vec2", "vec2"> { - constructor({ value }: { value: DynoVal<"uint"> }) { - super({ a: value, outKey: "vec2", outTypeFunc: () => "vec2" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.vec2} = unpackSnorm2x16(${inputs.a});`]; - }; - } -} - -export class PackUnorm2x16 extends UnaryOp<"vec2", "uint", "uint"> { - constructor({ value }: { value: DynoVal<"vec2"> }) { - super({ a: value, outKey: "uint", outTypeFunc: () => "uint" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.uint} = packUnorm2x16(${inputs.a});`]; - }; - } -} - -export class UnpackUnorm2x16 extends UnaryOp<"uint", "vec2", "vec2"> { - constructor({ value }: { value: DynoVal<"uint"> }) { - super({ a: value, outKey: "vec2", outTypeFunc: () => "vec2" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.vec2} = unpackUnorm2x16(${inputs.a});`]; - }; - } -} - -export class PackHalf2x16 extends UnaryOp<"vec2", "uint", "uint"> { - constructor({ value }: { value: DynoVal<"vec2"> }) { - super({ a: value, outKey: "uint", outTypeFunc: () => "uint" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.uint} = packHalf2x16(${inputs.a});`]; - }; - } -} - -export class UnpackHalf2x16 extends UnaryOp<"uint", "vec2", "vec2"> { - constructor({ value }: { value: DynoVal<"uint"> }) { - super({ a: value, outKey: "vec2", outTypeFunc: () => "vec2" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.vec2} = unpackHalf2x16(${inputs.a});`]; - }; - } -} - -export class UintToRgba8 extends UnaryOp<"uint", "vec4", "rgba8"> { - constructor({ value }: { value: DynoVal<"uint"> }) { - super({ a: value, outKey: "rgba8", outTypeFunc: () => "vec4" }); - this.statements = ({ inputs, outputs }) => { - return [ - `uvec4 uRgba = uvec4(${inputs.a} & 0xffu, (${inputs.a} >> 8u) & 0xffu, (${inputs.a} >> 16u) & 0xffu, (${inputs.a} >> 24u) & 0xffu);`, - `${outputs.rgba8} = vec4(uRgba) / 255.0;`, - ]; - }; - } -} diff --git a/src/dyno/logic.ts b/src/dyno/logic.ts deleted file mode 100644 index c191fc6..0000000 --- a/src/dyno/logic.ts +++ /dev/null @@ -1,434 +0,0 @@ -import { BinaryOp, TrinaryOp, UnaryOp } from "./base"; -import { - type AllIntTypes, - type BoolTypes, - type IntTypes, - type ScalarTypes, - type SimpleTypes, - type UintTypes, - type ValueTypes, - isBoolType, - isIntType, - isScalarType, - isUintType, - isVector2Type, - isVector3Type, -} from "./types"; -import { type DynoVal, valType } from "./value"; - -export const and = ( - a: DynoVal, - b: DynoVal, -): DynoVal => new And({ a, b }); -export const or = ( - a: DynoVal, - b: DynoVal, -): DynoVal => new Or({ a, b }); -export const xor = ( - a: DynoVal, - b: DynoVal, -): DynoVal => new Xor({ a, b }); -export const not = ( - a: DynoVal, -): DynoVal => new Not({ a }); - -export const lessThan = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new LessThan({ a, b }); -export const lessThanEqual = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new LessThanEqual({ a, b }); -export const greaterThan = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new GreaterThan({ a, b }); -export const greaterThanEqual = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new GreaterThanEqual({ a, b }); -export const equal = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Equal({ a, b }); -export const notEqual = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new NotEqual({ a, b }); - -export const any = ( - a: DynoVal, -): DynoVal<"bool"> => new Any({ a }); -export const all = ( - a: DynoVal, -): DynoVal<"bool"> => new All({ a }); -export const select = ( - cond: DynoVal<"bool">, - t: DynoVal, - f: DynoVal, -): DynoVal => new Select({ cond, t, f }); - -export const compXor = ( - a: DynoVal, -): DynoVal> => new CompXor({ a }); - -export class And extends BinaryOp< - T, - T, - T, - "and" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outTypeFunc: (aType: T, bType: T) => aType, outKey: "and" }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.and === "bool") { - return [`${outputs.and} = ${inputs.a} && ${inputs.b};`]; - } - return [`${outputs.and} = ${inputs.a} & ${inputs.b};`]; - }; - } -} - -export class Or extends BinaryOp< - T, - T, - T, - "or" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outTypeFunc: (aType: T, bType: T) => aType, outKey: "or" }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.or === "bool") { - return [`${outputs.or} = ${inputs.a} || ${inputs.b};`]; - } - return [`${outputs.or} = ${inputs.a} | ${inputs.b};`]; - }; - } -} - -export class Xor extends BinaryOp< - T, - T, - T, - "xor" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outTypeFunc: (aType: T, bType: T) => aType, outKey: "xor" }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.xor === "bool") { - return [`${outputs.xor} = ${inputs.a} ^^ ${inputs.b};`]; - } - return [`${outputs.xor} = ${inputs.a} ^ ${inputs.b};`]; - }; - } -} - -export class Not extends UnaryOp< - T, - T, - "not" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outTypeFunc: (aType: T) => aType, outKey: "not" }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.not === "bool") { - return [`${outputs.not} = !${inputs.a};`]; - } - return [`${outputs.not} = not(${inputs.a});`]; - }; - } -} - -export class LessThan extends BinaryOp< - T, - T, - CompareOutput, - "lessThan" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ - a, - b, - outTypeFunc: (aType: T, bType: T) => compareOutputType(aType, "lessThan"), - outKey: "lessThan", - }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.lessThan === "bool") { - return [`${outputs.lessThan} = ${inputs.a} < ${inputs.b};`]; - } - return [`${outputs.lessThan} = lessThan(${inputs.a}, ${inputs.b});`]; - }; - } -} - -export class LessThanEqual extends BinaryOp< - T, - T, - CompareOutput, - "lessThanEqual" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ - a, - b, - outTypeFunc: (aType: T, bType: T) => - compareOutputType(aType, "lessThanEqual"), - outKey: "lessThanEqual", - }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.lessThanEqual === "bool") { - return [`${outputs.lessThanEqual} = ${inputs.a} <= ${inputs.b};`]; - } - return [ - `${outputs.lessThanEqual} = lessThanEqual(${inputs.a}, ${inputs.b});`, - ]; - }; - } -} - -export class GreaterThan extends BinaryOp< - T, - T, - CompareOutput, - "greaterThan" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ - a, - b, - outTypeFunc: (aType: T, bType: T) => - compareOutputType(aType, "greaterThan"), - outKey: "greaterThan", - }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.greaterThan === "bool") { - return [`${outputs.greaterThan} = ${inputs.a} > ${inputs.b};`]; - } - return [ - `${outputs.greaterThan} = greaterThan(${inputs.a}, ${inputs.b});`, - ]; - }; - } -} - -export class GreaterThanEqual extends BinaryOp< - T, - T, - CompareOutput, - "greaterThanEqual" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ - a, - b, - outTypeFunc: (aType: T, bType: T) => - compareOutputType(aType, "greaterThanEqual"), - outKey: "greaterThanEqual", - }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.greaterThanEqual === "bool") { - return [`${outputs.greaterThanEqual} = ${inputs.a} >= ${inputs.b};`]; - } - return [ - `${outputs.greaterThanEqual} = greaterThanEqual(${inputs.a}, ${inputs.b});`, - ]; - }; - } -} - -export class Equal extends BinaryOp< - T, - T, - EqualOutput, - "equal" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outTypeFunc: equalOutputType, outKey: "equal" }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.equal === "bool") { - return [`${outputs.equal} = ${inputs.a} == ${inputs.b};`]; - } - return [`${outputs.equal} = equal(${inputs.a}, ${inputs.b});`]; - }; - } -} - -export class NotEqual extends BinaryOp< - T, - T, - NotEqualOutput, - "notEqual" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outTypeFunc: notEqualOutputType, outKey: "notEqual" }); - this.statements = ({ inputs, outputs }) => { - if (this.outTypes.notEqual === "bool") { - return [`${outputs.notEqual} = ${inputs.a} != ${inputs.b};`]; - } - return [`${outputs.notEqual} = notEqual(${inputs.a}, ${inputs.b});`]; - }; - } -} - -export class Any extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outTypeFunc: (aType: T) => "bool", outKey: "any" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.any} = any(${inputs.a});`]; - }; - } -} - -export class All extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outTypeFunc: (aType: T) => "bool", outKey: "all" }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.all} = all(${inputs.a});`]; - }; - } -} - -export class Select extends TrinaryOp< - "bool", - T, - T, - T, - "select" -> { - constructor({ - cond, - t, - f, - }: { cond: DynoVal<"bool">; t: DynoVal; f: DynoVal }) { - super({ - a: cond, - b: t, - c: f, - outKey: "select", - outTypeFunc: (aType: "bool", bType: T, cType: T) => bType, - }); - this.statements = ({ inputs, outputs }) => { - const { a: cond, b: t, c: f } = inputs; - return [`${outputs.select} = (${cond}) ? (${t}) : (${f});`]; - }; - } -} - -type CompareOutput = T extends ScalarTypes - ? "bool" - : T extends "ivec2" | "uvec2" | "vec2" - ? "bvec2" - : T extends "ivec3" | "uvec3" | "vec3" - ? "bvec3" - : T extends "ivec4" | "uvec4" | "vec4" - ? "bvec4" - : never; - -function compareOutputType( - type: T, - operator: string, -): CompareOutput { - if (isScalarType(type)) { - return "bool" as CompareOutput; - } - if (type === "ivec2" || type === "uvec2" || type === "vec2") { - return "bvec2" as CompareOutput; - } - if (type === "ivec3" || type === "uvec3" || type === "vec3") { - return "bvec3" as CompareOutput; - } - if (type === "ivec4" || type === "uvec4" || type === "vec4") { - return "bvec4" as CompareOutput; - } - throw new Error(`Invalid ${operator} type: ${type}`); -} - -type EqualOutput = A extends ScalarTypes - ? "bool" - : A extends BoolTypes - ? A - : A extends "ivec2" | "uvec2" | "vec2" - ? "bvec2" - : A extends "ivec3" | "uvec3" | "vec3" - ? "bvec3" - : A extends "ivec4" | "uvec4" | "vec4" - ? "bvec4" - : never; - -function equalOutputType( - type: A, - operator = "equal", -): EqualOutput { - if (isScalarType(type)) { - return "bool" as EqualOutput; - } - if (isBoolType(type)) { - return type as EqualOutput; - } - if (type === "ivec2" || type === "uvec2" || type === "vec2") { - return "bvec2" as EqualOutput; - } - if (type === "ivec3" || type === "uvec3" || type === "vec3") { - return "bvec3" as EqualOutput; - } - if (type === "ivec4" || type === "uvec4" || type === "vec4") { - return "bvec4" as EqualOutput; - } - throw new Error(`Invalid ${operator} type: ${type}`); -} - -type NotEqualOutput = EqualOutput; - -function notEqualOutputType( - type: A, -): NotEqualOutput { - return equalOutputType(type, "notEqual"); -} - -type CompXorOutput = A extends BoolTypes - ? "bool" - : A extends IntTypes - ? "int" - : A extends UintTypes - ? "uint" - : never; - -function compXorOutputType( - type: A, -): CompXorOutput { - if (isBoolType(type)) { - return "bool" as CompXorOutput; - } - if (isIntType(type)) { - return "int" as CompXorOutput; - } - if (isUintType(type)) { - return "uint" as CompXorOutput; - } - throw new Error(`Invalid compXor type: ${type}`); -} - -export class CompXor extends UnaryOp< - T, - CompXorOutput, - "compXor" -> { - constructor({ a }: { a: DynoVal }) { - const outType = compXorOutputType(valType(a)); - super({ a, outTypeFunc: (aType: T) => outType, outKey: "compXor" }); - this.statements = ({ inputs, outputs }) => { - if (isScalarType(this.outTypes.compXor)) { - return [`${outputs.compXor} = ${inputs.a};`]; - } - const components = isVector2Type(outType) - ? ["x", "y"] - : isVector3Type(outType) - ? ["x", "y", "z"] - : ["x", "y", "z", "w"]; - const operands = components.map((c) => `${inputs.a}.${c}`); - const operator = isBoolType(outType) ? "^^" : "^"; - return [`${outputs.compXor} = ${operands.join(` ${operator} `)};`]; - }; - } -} diff --git a/src/dyno/math.ts b/src/dyno/math.ts deleted file mode 100644 index d5e6031..0000000 --- a/src/dyno/math.ts +++ /dev/null @@ -1,534 +0,0 @@ -import { BinaryOp, Dyno, TrinaryOp, UnaryOp } from "./base"; -import { - type AddOutput, - type ClampOutput, - type DivOutput, - type IModOutput, - type IsInfOutput, - type IsNanOutput, - type MaxOutput, - type MinOutput, - type MixOutput, - type ModOutput, - type MulOutput, - type SmoothstepOutput, - type StepOutput, - type SubOutput, - absOutputType, - addOutputType, - ceilOutputType, - clampOutputType, - divOutputType, - exp2OutputType, - expOutputType, - floorOutputType, - fractOutputType, - imodOutputType, - inversesqrtOutputType, - isInfOutputType, - isNanOutputType, - log2OutputType, - logOutputType, - maxOutputType, - minOutputType, - mixOutputType, - modOutputType, - modfOutputType, - mulOutputType, - negOutputType, - powOutputType, - roundOutputType, - signOutputType, - smoothstepOutputType, - sqrOutputType, - sqrtOutputType, - stepOutputType, - subOutputType, - truncOutputType, -} from "./mathTypes"; -import type { - AllIntTypes, - AllSignedTypes, - AllValueTypes, - BoolTypes, - FloatTypes, - SignedTypes, - ValueTypes, -} from "./types"; -import { type DynoVal, valType } from "./value"; - -export const add = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Add({ a, b }); -export const sub = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Sub({ a, b }); -export const mul = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Mul({ a, b }); -export const div = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Div({ a, b }); -export const imod = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new IMod({ a, b }); -export const mod = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Mod({ a, b }); -export const modf = (a: DynoVal) => - new Modf({ a }).outputs; - -export const neg = (a: DynoVal): DynoVal => - new Neg({ a }); -export const abs = (a: DynoVal): DynoVal => - new Abs({ a }); -export const sign = (a: DynoVal): DynoVal => - new Sign({ a }); -export const floor = (a: DynoVal): DynoVal => - new Floor({ a }); -export const ceil = (a: DynoVal): DynoVal => - new Ceil({ a }); -export const trunc = (a: DynoVal): DynoVal => - new Trunc({ a }); -export const round = (a: DynoVal): DynoVal => - new Round({ a }); -export const fract = (a: DynoVal): DynoVal => - new Fract({ a }); - -export const pow = ( - a: DynoVal, - b: DynoVal, -): DynoVal => new Pow({ a, b }); -export const exp = (a: DynoVal): DynoVal => - new Exp({ a }); -export const exp2 = (a: DynoVal): DynoVal => - new Exp2({ a }); -export const log = (a: DynoVal): DynoVal => - new Log({ a }); -export const log2 = (a: DynoVal): DynoVal => - new Log2({ a }); -export const sqr = (a: DynoVal): DynoVal => - new Sqr({ a }); -export const sqrt = (a: DynoVal): DynoVal => - new Sqrt({ a }); -export const inversesqrt = (a: DynoVal): DynoVal => - new InverseSqrt({ a }); - -export const min = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Min({ a, b }); -export const max = ( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Max({ a, b }); -export const clamp = ( - a: DynoVal, - min: DynoVal, - max: DynoVal, -): DynoVal> => new Clamp({ a, min, max }); -export const mix = ( - a: DynoVal, - b: DynoVal, - t: DynoVal, -): DynoVal> => new Mix({ a, b, t }); -export const step = ( - edge: DynoVal, - x: DynoVal, -): DynoVal> => new Step({ edge, x }); -export const smoothstep = ( - edge0: DynoVal, - edge1: DynoVal, - x: DynoVal, -): DynoVal> => - new Smoothstep({ edge0, edge1, x }); - -export const isNan = ( - a: DynoVal, -): DynoVal> => new IsNan({ a }); -export const isInf = ( - a: DynoVal, -): DynoVal> => new IsInf({ a }); - -export class Add< - A extends AllValueTypes, - B extends AllValueTypes, -> extends BinaryOp, "sum"> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "sum", outTypeFunc: addOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.sum} = ${inputs.a} + ${inputs.b};`]; - }; - } -} - -export class Sub< - A extends AllValueTypes, - B extends AllValueTypes, -> extends BinaryOp, "difference"> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "difference", outTypeFunc: subOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.difference} = ${inputs.a} - ${inputs.b};`]; - }; - } -} - -export class Mul< - A extends AllValueTypes, - B extends AllValueTypes, -> extends BinaryOp, "product"> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "product", outTypeFunc: mulOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.product} = ${inputs.a} * ${inputs.b};`]; - }; - } -} - -export class Div< - A extends AllValueTypes, - B extends AllValueTypes, -> extends BinaryOp, "quotient"> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "quotient", outTypeFunc: divOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.quotient} = ${inputs.a} / ${inputs.b};`]; - }; - } -} - -export class IMod< - A extends AllIntTypes, - B extends AllIntTypes, -> extends BinaryOp, "remainder"> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "remainder", outTypeFunc: imodOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.remainder} = ${inputs.a} % ${inputs.b};`]; - }; - } -} - -export class Mod extends BinaryOp< - A, - B, - ModOutput, - "remainder" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "remainder", outTypeFunc: modOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.remainder} = mod(${inputs.a}, ${inputs.b});`]; - }; - } -} - -export class Modf extends Dyno< - { a: A }, - { fract: A; integer: A } -> { - constructor({ a }: { a: DynoVal }) { - const inTypes = { a: valType(a) }; - const outType = modfOutputType(inTypes.a); - const outTypes = { - fract: outType, - integer: outType, - }; - super({ inTypes, outTypes, inputs: { a } }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.fract} = modf(${inputs.a}, ${outputs.integer});`]; - }; - } -} - -export class Neg extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "neg", outTypeFunc: negOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.neg} = -${inputs.a};`]; - }; - } -} - -export class Abs extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "abs", outTypeFunc: absOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.abs} = abs(${inputs.a});`]; - }; - } -} - -export class Sign extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "sign", outTypeFunc: signOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.sign} = sign(${inputs.a});`]; - }; - } -} - -export class Floor extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "floor", outTypeFunc: floorOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.floor} = floor(${inputs.a});`]; - }; - } -} - -export class Ceil extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "ceil", outTypeFunc: ceilOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.ceil} = ceil(${inputs.a});`]; - }; - } -} - -export class Trunc extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "trunc", outTypeFunc: truncOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.trunc} = trunc(${inputs.a});`]; - }; - } -} - -export class Round extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "round", outTypeFunc: roundOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.round} = round(${inputs.a});`]; - }; - } -} - -export class Fract extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "fract", outTypeFunc: fractOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.fract} = fract(${inputs.a});`]; - }; - } -} - -export class Pow extends BinaryOp { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "power", outTypeFunc: powOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.power} = pow(${inputs.a}, ${inputs.b});`]; - }; - } -} - -export class Exp extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "exp", outTypeFunc: expOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.exp} = exp(${inputs.a});`]; - }; - } -} - -export class Exp2 extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "exp2", outTypeFunc: exp2OutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.exp2} = exp2(${inputs.a});`]; - }; - } -} - -export class Log extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "log", outTypeFunc: logOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.log} = log(${inputs.a});`]; - }; - } -} - -export class Log2 extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "log2", outTypeFunc: log2OutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.log2} = log2(${inputs.a});`]; - }; - } -} - -export class Sqr extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "sqr", outTypeFunc: sqrOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.sqr} = ${inputs.a} * ${inputs.a};`]; - }; - } -} - -export class Sqrt extends UnaryOp { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "sqrt", outTypeFunc: sqrtOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.sqrt} = sqrt(${inputs.a});`]; - }; - } -} - -export class InverseSqrt extends UnaryOp< - A, - A, - "inversesqrt" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "inversesqrt", outTypeFunc: inversesqrtOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.inversesqrt} = inversesqrt(${inputs.a});`]; - }; - } -} - -export class Min extends BinaryOp< - A, - B, - MinOutput, - "min" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "min", outTypeFunc: minOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.min} = min(${inputs.a}, ${inputs.b});`]; - }; - } -} - -export class Max extends BinaryOp< - A, - B, - MaxOutput, - "max" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "max", outTypeFunc: maxOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.max} = max(${inputs.a}, ${inputs.b});`]; - }; - } -} - -export class Clamp< - A extends ValueTypes, - MinMax extends ValueTypes, -> extends TrinaryOp, "clamp"> { - constructor({ - a, - min, - max, - }: { a: DynoVal; min: DynoVal; max: DynoVal }) { - super({ - a, - b: min, - c: max, - outKey: "clamp", - outTypeFunc: clampOutputType, - }); - this.statements = ({ inputs, outputs }) => { - const { a, b: min, c: max } = inputs; - return [`${outputs.clamp} = clamp(${a}, ${min}, ${max});`]; - }; - } -} - -export class Mix< - A extends FloatTypes, - T extends FloatTypes | BoolTypes, -> extends TrinaryOp, "mix"> { - constructor({ a, b, t }: { a: DynoVal; b: DynoVal; t: DynoVal }) { - super({ a, b, c: t, outKey: "mix", outTypeFunc: mixOutputType }); - this.statements = ({ inputs, outputs }) => { - const { a, b, c: t } = inputs; - return [`${outputs.mix} = mix(${a}, ${b}, ${t});`]; - }; - } -} - -export class Step< - Edge extends FloatTypes, - X extends FloatTypes, -> extends BinaryOp, "step"> { - constructor({ edge, x }: { edge: DynoVal; x: DynoVal }) { - super({ - a: edge, - b: x, - outKey: "step", - outTypeFunc: stepOutputType, - }); - this.statements = ({ inputs, outputs }) => { - const { a: edge, b: x } = inputs; - return [`${outputs.step} = step(${edge}, ${x});`]; - }; - } -} - -export class Smoothstep< - X extends FloatTypes, - Edge extends X | "float", -> extends TrinaryOp< - Edge, - Edge, - X, - SmoothstepOutput, - "smoothstep" -> { - constructor({ - edge0, - edge1, - x, - }: { edge0: DynoVal; edge1: DynoVal; x: DynoVal }) { - super({ - a: edge0, - b: edge1, - c: x, - outKey: "smoothstep", - outTypeFunc: smoothstepOutputType, - }); - this.statements = ({ inputs, outputs }) => { - const { a: edge0, b: edge1, c: x } = inputs; - return [`${outputs.smoothstep} = smoothstep(${edge0}, ${edge1}, ${x});`]; - }; - } -} - -export class IsNan extends UnaryOp< - A, - IsNanOutput, - "isNan" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "isNan", outTypeFunc: isNanOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.isNan} = isNan(${inputs.a});`]; - }; - } -} - -export class IsInf extends UnaryOp< - A, - IsInfOutput, - "isInf" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "isInf", outTypeFunc: isInfOutputType }); - this.statements = ({ inputs, outputs }) => { - return [`${outputs.isInf} = isInf(${inputs.a});`]; - }; - } -} diff --git a/src/dyno/mathTypes.ts b/src/dyno/mathTypes.ts deleted file mode 100644 index 13bcada..0000000 --- a/src/dyno/mathTypes.ts +++ /dev/null @@ -1,717 +0,0 @@ -import { - type AllFloatTypes, - type AllIntTypes, - type AllSignedTypes, - type AllValueTypes, - type BaseType, - type BoolTypes, - type FloatTypes, - type IntTypes, - type SignedTypes, - type UintTypes, - type ValueTypes, - isAllFloatType, - isFloatType, - isIntType, - isMat2, - isMat3, - isMat4, - isUintType, -} from "./types"; - -export type AddOutput< - A extends AllValueTypes, - B extends AllValueTypes, -> = BaseType & - (A extends B - ? A - : A extends "int" - ? B extends IntTypes - ? B - : never - : B extends "int" - ? A extends IntTypes - ? A - : never - : A extends "uint" - ? B extends UintTypes - ? B - : never - : B extends "uint" - ? A extends UintTypes - ? A - : never - : A extends "float" - ? B extends AllFloatTypes - ? B - : never - : B extends "float" - ? A extends AllFloatTypes - ? A - : never - : never); - -export type SubOutput< - A extends AllValueTypes, - B extends AllValueTypes, -> = AddOutput; - -export type MulOutput< - A extends AllValueTypes, - B extends AllValueTypes, -> = BaseType & - (A extends "int" - ? B extends IntTypes - ? B - : never - : B extends "int" - ? A extends IntTypes - ? A - : never - : A extends "uint" - ? B extends UintTypes - ? B - : never - : B extends "uint" - ? A extends UintTypes - ? A - : never - : A extends "float" - ? B extends AllFloatTypes - ? B - : never - : B extends "float" - ? A extends AllFloatTypes - ? A - : never - : A extends IntTypes - ? B extends A - ? A - : never - : B extends IntTypes - ? A extends B - ? A - : never - : A extends UintTypes - ? B extends A - ? A - : never - : B extends UintTypes - ? A extends B - ? A - : never - : // Vector * Matrix/Vector - A extends "vec2" - ? B extends "vec2" | "mat2" | "mat2x2" - ? "vec2" - : B extends "mat3x2" - ? "vec3" - : B extends "mat4x2" - ? "vec4" - : never - : A extends "vec3" - ? B extends "mat2x3" - ? "vec2" - : B extends "vec3" | "mat3" | "mat3x3" - ? "vec3" - : B extends "mat4x3" - ? "vec4" - : never - : A extends "vec4" - ? B extends "mat2x4" - ? "vec2" - : B extends "mat3x4" - ? "vec3" - : B extends "vec4" | "mat4" | "mat4x4" - ? "vec4" - : never - : // Matrix * Vector - B extends "vec2" - ? A extends "mat2" | "mat2x2" - ? "vec2" - : A extends "mat2x3" - ? "vec3" - : A extends "mat2x4" - ? "vec4" - : never - : B extends "vec3" - ? A extends "mat3x2" - ? "vec2" - : A extends "mat3" | "mat3x3" - ? "vec3" - : A extends "mat3x4" - ? "vec4" - : never - : B extends "vec4" - ? A extends "mat4x2" - ? "vec2" - : A extends "mat4x3" - ? "vec3" - : A extends "mat4" | "mat4x4" - ? "vec4" - : never - : // Matrix * Matrix: mat{Acols}x{Arows} * mat{Bcols}x{Brows} => mat{Bcols}x{Arows} - A extends "mat2" | "mat2x2" // Acols = 2 => Brows = 2 - ? B extends "mat2" | "mat2x2" - ? "mat2" - : B extends "mat3x2" - ? "mat3x2" - : B extends "mat4x2" - ? "mat4x2" - : never - : A extends "mat2x3" // Acols = 2 => Brows = 2 - ? B extends "mat2" | "mat2x2" - ? "mat2x3" - : B extends "mat3x2" - ? "mat3" - : B extends "mat4x2" - ? "mat4x3" - : never - : A extends "mat2x4" // Acols = 2 => Brows = 2 - ? B extends "mat2" | "mat2x2" - ? "mat2x4" - : B extends "mat3x2" - ? "mat3x4" - : B extends "mat4x2" - ? "mat4" - : never - : A extends "mat3x2" // Acols = 3 => Brows = 3 - ? B extends "mat2x3" - ? "mat2" - : B extends "mat3" | "mat3x3" - ? "mat3x2" - : B extends "mat4x3" - ? "mat4x2" - : never - : A extends "mat3" | "mat3x3" // Acols = 3 => Brows = 3 - ? B extends "mat2x3" - ? "mat2x3" - : B extends "mat3" | "mat3x3" - ? "mat3" - : B extends "mat4x3" - ? "mat4x3" - : never - : A extends "mat3x4" // Acols = 3 => Brows = 3 - ? B extends "mat2x3" - ? "mat2x4" - : B extends "mat3" | "mat3x3" - ? "mat3x4" - : B extends "mat4x3" - ? "mat4" - : never - : A extends "mat4x2" // Acols = 4 => Brows = 4 - ? B extends "mat2x4" - ? "mat2" - : B extends "mat3x4" - ? "mat3x2" - : B extends - | "mat4" - | "mat4x4" - ? "mat4x2" - : never - : A extends "mat4x3" // Acols = 4 => Brows = 4 - ? B extends "mat2x4" - ? "mat2x3" - : B extends "mat3x4" - ? "mat3" - : B extends - | "mat4" - | "mat4x4" - ? "mat4x3" - : never - : A extends "mat4" | "mat4x4" // Acols = 4 => Brows = 4 - ? B extends "mat2x4" - ? "mat2x4" - : B extends "mat3x4" - ? "mat3x4" - : B extends - | "mat4" - | "mat4x4" - ? "mat4" - : never - : never); - -export type DivOutput< - A extends AllValueTypes, - B extends AllValueTypes, -> = AddOutput; - -export type IModOutput< - A extends AllIntTypes, - B extends AllIntTypes, -> = BaseType & - (A extends B - ? A - : A extends "int" - ? B extends IntTypes - ? B - : never - : B extends "int" - ? A extends IntTypes - ? A - : never - : A extends "uint" - ? B extends UintTypes - ? B - : never - : B extends "uint" - ? A extends UintTypes - ? A - : never - : never); - -export type ModOutput = BaseType & - (A extends B ? A : B extends "float" ? A : never); - -export type PowOutput = BaseType & - (A extends B ? A : never); - -export type MinOutput = BaseType & - (A extends B - ? A - : B extends "float" - ? A extends FloatTypes - ? A - : never - : B extends "int" - ? A extends IntTypes - ? A - : never - : B extends "uint" - ? A extends UintTypes - ? A - : never - : never); -export type MaxOutput = MinOutput< - A, - B ->; -export type ClampOutput = BaseType & - (B extends "float" - ? A extends FloatTypes - ? A - : never - : B extends "int" - ? A extends IntTypes - ? A - : never - : B extends "uint" - ? A extends UintTypes - ? A - : never - : never); -export type MixOutput< - A extends FloatTypes, - T extends FloatTypes | BoolTypes, -> = BaseType & - (T extends A - ? A - : T extends "float" - ? A - : T extends "bool" - ? A extends "float" - ? A - : never - : T extends "bvec2" - ? A extends "vec2" - ? A - : never - : T extends "bvec3" - ? A extends "vec3" - ? A - : never - : T extends "bvec4" - ? A extends "vec4" - ? A - : never - : never); -export type StepOutput = BaseType & - (A extends B ? B : A extends "float" ? B : never); -export type SmoothstepOutput< - A extends FloatTypes, - B extends FloatTypes, - C extends FloatTypes, -> = BaseType & - (A extends B ? (A extends C ? C : A extends "float" ? C : never) : never); - -export type IsNanOutput = BaseType & - (A extends "float" - ? "bool" - : A extends "vec2" - ? "bvec2" - : A extends "vec3" - ? "bvec3" - : A extends "vec4" - ? "bvec4" - : never); -export type IsInfOutput = IsNanOutput; - -// // Run-time type helper functions - -export function addOutputType( - a: A, - b: B, - operation = "add", -): AddOutput { - const error = () => { - throw new Error(`Invalid ${operation} types: ${a}, ${b}`); - }; - // @ts-ignore - if (a === b) return a as AddOutput; - if (a === "int") { - if (isIntType(b)) return b as AddOutput; - error(); - } - if (b === "int") { - if (isIntType(a)) return a as AddOutput; - error(); - } - if (a === "uint") { - if (isUintType(b)) return b as AddOutput; - error(); - } - if (b === "uint") { - if (isUintType(a)) return a as AddOutput; - error(); - } - if (a === "float") { - if (isAllFloatType(b)) return b as AddOutput; - error(); - } - if (b === "float") { - if (isAllFloatType(a)) return a as AddOutput; - error(); - } - throw new Error(`Invalid ${operation} types: ${a}, ${b}`); -} - -export function subOutputType( - a: A, - b: B, -): SubOutput { - return addOutputType(a, b, "sub"); -} - -export function mulOutputType( - a: A, - b: B, -): MulOutput { - const error = () => { - throw new Error(`Invalid mul types: ${a}, ${b}`); - }; - const result = (value: unknown) => value as MulOutput; - if (a === "int") { - if (isIntType(b)) return result(b); - error(); - } - if (b === "int") { - if (isIntType(a)) return result(a); - error(); - } - if (a === "uint") { - if (isUintType(b)) return result(b); - error(); - } - if (b === "uint") { - if (isUintType(a)) return result(a); - error(); - } - if (a === "float") { - if (isAllFloatType(b)) return result(b); - error(); - } - if (b === "float") { - if (isAllFloatType(a)) return result(a); - error(); - } - if (isIntType(a) || isUintType(a) || isIntType(b) || isUintType(b)) { - // @ts-ignore - if (a === b) return result(a); - error(); - } - // Vector * Matrix/Vector - if (a === "vec2") { - if (b === "vec2" || isMat2(b)) return result("vec2"); - if (b === "mat3x2") return result("vec3"); - if (b === "mat4x2") return result("vec4"); - error(); - } - if (a === "vec3") { - if (b === "mat2x3") return result("vec2"); - if (b === "vec3" || isMat3(b)) return result("vec3"); - if (b === "mat4x3") return result("vec4"); - error(); - } - if (a === "vec4") { - if (b === "mat2x4") return result("vec2"); - if (b === "mat3x4") return result("vec3"); - if (b === "vec4" || isMat4(b)) return result("vec4"); - error(); - } - // Matrix * Vector - if (b === "vec2") { - if (isMat2(a)) return result("vec2"); - if (a === "mat2x3") return result("vec3"); - if (a === "mat2x4") return result("vec4"); - error(); - } - if (b === "vec3") { - if (a === "mat3x2") return result("vec2"); - if (isMat3(a)) return result("vec3"); - if (a === "mat3x4") return result("vec4"); - error(); - } - if (b === "vec4") { - if (a === "mat4x2") return result("vec2"); - if (a === "mat4x3") return result("vec3"); - if (isMat4(a)) return result("vec4"); - error(); - } - // Matrix * Matrix: mat{Acols}x{Arows} * mat{Bcols}x{Brows} => mat{Bcols}x{Arows} - if (isMat2(a)) { - if (isMat2(b)) return result("mat2"); - if (b === "mat3x2") return result("mat3x2"); - if (b === "mat4x2") return result("mat4x2"); - error(); - } - if (a === "mat2x3") { - if (isMat2(b)) return result("mat2x3"); - if (b === "mat3x2") return result("mat3"); - if (b === "mat4x2") return result("mat4x3"); - error(); - } - if (a === "mat2x4") { - if (isMat2(b)) return result("mat2x4"); - if (b === "mat3x2") return result("mat3x4"); - if (b === "mat4x2") return result("mat4"); - error(); - } - if (a === "mat3x2") { - if (b === "mat2x3") return result("mat2"); - if (isMat3(b)) return result("mat3x2"); - if (b === "mat4x3") return result("mat4x2"); - error(); - } - if (isMat3(a)) { - if (b === "mat2x3") return result("mat2x3"); - if (isMat3(b)) return result("mat3"); - if (b === "mat4x3") return result("mat4x3"); - error(); - } - if (a === "mat3x4") { - if (b === "mat2x3") return result("mat2x4"); - if (isMat3(b)) return result("mat3x4"); - if (b === "mat4x3") return result("mat4"); - error(); - } - if (a === "mat4x2") { - if (b === "mat2x4") return result("mat2"); - if (b === "mat3x4") return result("mat3x2"); - if (isMat4(b)) return result("mat4x2"); - error(); - } - if (a === "mat4x3") { - if (b === "mat2x4") return result("mat2x3"); - if (b === "mat3x4") return result("mat3"); - if (isMat4(b)) return result("mat4x3"); - error(); - } - if (isMat4(a)) { - if (b === "mat2x4") return result("mat2x4"); - if (b === "mat3x4") return result("mat3x4"); - if (isMat4(b)) return result("mat4"); - error(); - } - throw new Error(`Invalid mul types: ${a}, ${b}`); -} - -export function divOutputType( - a: A, - b: B, -): DivOutput { - return addOutputType(a, b, "div"); -} - -export function imodOutputType( - a: A, - b: B, -): IModOutput { - // @ts-ignore - if (a === b) return a as IModOutput; - if (a === "int") { - if (isIntType(b)) return b as IModOutput; - } else if (b === "int") { - if (isIntType(a)) return a as IModOutput; - } else if (a === "uint") { - if (isUintType(b)) return b as IModOutput; - } else if (b === "uint") { - if (isUintType(a)) return a as IModOutput; - } - throw new Error(`Invalid imod types: ${a}, ${b}`); -} - -export function modOutputType( - a: A, - b: B, -): ModOutput { - // @ts-ignore - if (a === b || b === "float") return a as ModOutput; - throw new Error(`Invalid mod types: ${a}, ${b}`); -} - -export function modfOutputType(a: A): A { - return a; -} - -export function negOutputType(a: A): A { - return a; -} - -export function absOutputType(a: A): A { - return a; -} - -export function signOutputType(a: A): A { - return a; -} - -export function floorOutputType(a: A): A { - return a; -} - -export function ceilOutputType(a: A): A { - return a; -} - -export function truncOutputType(a: A): A { - return a; -} - -export function roundOutputType(a: A): A { - return a; -} - -export function fractOutputType(a: A): A { - return a; -} - -export function powOutputType(a: A): A { - return a; -} - -export function expOutputType(a: A): A { - return a; -} - -export function exp2OutputType(a: A): A { - return a; -} - -export function logOutputType(a: A): A { - return a; -} - -export function log2OutputType(a: A): A { - return a; -} - -export function sqrOutputType(a: A): A { - return a; -} - -export function sqrtOutputType(a: A): A { - return a; -} - -export function inversesqrtOutputType(a: A): A { - return a; -} - -export function minOutputType( - a: A, - b: B, - operation = "min", -): MinOutput { - // @ts-ignore - if (a === b) return a as MinOutput; - if (b === "float") { - if (isFloatType(a)) return a as MinOutput; - } else if (b === "int") { - if (isIntType(a)) return a as MinOutput; - } else if (b === "uint") { - if (isUintType(a)) return a as MinOutput; - } - throw new Error(`Invalid ${operation} types: ${a}, ${b}`); -} - -export function maxOutputType( - a: A, - b: B, -): MaxOutput { - return minOutputType(a, b, "max"); -} - -export function clampOutputType( - a: A, - b: B, - _c: B, -): ClampOutput { - if (b === "float") { - if (isFloatType(a)) return a as ClampOutput; - } else if (b === "int") { - if (isIntType(a)) return a as ClampOutput; - } else if (b === "uint") { - if (isUintType(a)) return a as ClampOutput; - } - throw new Error(`Invalid clamp types: ${a}, ${b}`); -} - -export function mixOutputType< - A extends FloatTypes, - C extends FloatTypes | BoolTypes, ->(a: A, b: A, c: C): MixOutput { - // @ts-ignore - if (c === a) return a as MixOutput; - if (c === "float") return a as MixOutput; - if (c === "bool" && a === "float") return a as MixOutput; - if (c === "bvec2" && a === "vec2") return a as MixOutput; - if (c === "bvec3" && a === "vec3") return a as MixOutput; - if (c === "bvec4" && a === "vec4") return a as MixOutput; - throw new Error(`Invalid mix types: ${a}, ${b}, ${c}`); -} - -export function stepOutputType( - a: A, - b: B, -): StepOutput { - // @ts-ignore - if (a === b || b === "float") return b as StepOutput; - throw new Error(`Invalid step types: ${a}, ${b}`); -} - -export function smoothstepOutputType< - A extends FloatTypes, - B extends FloatTypes, - C extends FloatTypes, ->(a: A, b: B, c: C): SmoothstepOutput { - // @ts-ignore - if (a === b) { - if (a === c || a === "float") return c as SmoothstepOutput; - } - throw new Error(`Invalid smoothstep types: ${a}, ${b}, ${c}`); -} - -export function isNanOutputType( - a: A, - operation = "isNan", -): IsNanOutput { - if (a === "float") return "bool" as IsNanOutput; - if (a === "vec2") return "bvec2" as IsNanOutput; - if (a === "vec3") return "bvec3" as IsNanOutput; - if (a === "vec4") return "bvec4" as IsNanOutput; - throw new Error(`Invalid ${operation} types: ${a}`); -} - -export function isInfOutputType(a: A): IsInfOutput { - return isNanOutputType(a, "isInf"); -} diff --git a/src/dyno/output.ts b/src/dyno/output.ts deleted file mode 100644 index 379a8d9..0000000 --- a/src/dyno/output.ts +++ /dev/null @@ -1,78 +0,0 @@ -import * as THREE from "three"; -import { Dyno, unindentLines } from "./base"; -import { Gsplat, defineGsplat } from "./splats"; -import { - DynoOutput, - type DynoVal, - type DynoValue, - type HasDynoOut, -} from "./value"; - -export const outputPackedSplat = ( - gsplat: DynoVal, - rgbMinMaxLnScaleMinMax: DynoVal<"vec4">, -) => new OutputPackedSplat({ gsplat, rgbMinMaxLnScaleMinMax }); -export const outputRgba8 = (rgba8: DynoVal<"vec4">) => - new OutputRgba8({ rgba8 }); - -export class OutputPackedSplat - extends Dyno< - { gsplat: typeof Gsplat; rgbMinMaxLnScaleMinMax: "vec4" }, - { output: "uvec4" } - > - implements HasDynoOut<"uvec4"> -{ - constructor({ - gsplat, - rgbMinMaxLnScaleMinMax, - }: { - gsplat?: DynoVal; - rgbMinMaxLnScaleMinMax?: DynoVal<"vec4">; - }) { - super({ - inTypes: { gsplat: Gsplat, rgbMinMaxLnScaleMinMax: "vec4" }, - inputs: { gsplat, rgbMinMaxLnScaleMinMax }, - globals: () => [defineGsplat], - statements: ({ inputs, outputs }) => { - const { output } = outputs; - if (!output) { - return []; - } - const { gsplat, rgbMinMaxLnScaleMinMax } = inputs; - if (gsplat) { - return unindentLines(` - if (isGsplatActive(${gsplat}.flags)) { - ${output} = packSplatEncoding(${gsplat}.center, ${gsplat}.scales, ${gsplat}.quaternion, ${gsplat}.rgba, ${rgbMinMaxLnScaleMinMax}); - } else { - ${output} = uvec4(0u, 0u, 0u, 0u); - } - `); - } - return [`${output} = uvec4(0u, 0u, 0u, 0u);`]; - }, - }); - } - - dynoOut(): DynoValue<"uvec4"> { - return new DynoOutput(this, "output"); - } -} - -export class OutputRgba8 - extends Dyno<{ rgba8: "vec4" }, { rgba8: "vec4" }> - implements HasDynoOut<"vec4"> -{ - constructor({ rgba8 }: { rgba8?: DynoVal<"vec4"> }) { - super({ - inTypes: { rgba8: "vec4" }, - inputs: { rgba8 }, - statements: ({ inputs, outputs }) => [ - `target = ${inputs.rgba8 ?? "vec4(0.0, 0.0, 0.0, 0.0)"};`, - ], - }); - } - - dynoOut(): DynoValue<"vec4"> { - return new DynoOutput(this, "rgba8"); - } -} diff --git a/src/dyno/program.ts b/src/dyno/program.ts deleted file mode 100644 index 97c2897..0000000 --- a/src/dyno/program.ts +++ /dev/null @@ -1,117 +0,0 @@ -import * as THREE from "three"; - -import { IDENT_VERTEX_SHADER } from "../utils"; -import { Compilation, type Dyno, type IOTypes } from "./base"; - -export class DynoProgram { - graph: Dyno; - template: DynoProgramTemplate; - inputs: Record; - outputs: Record; - shader: string; - uniforms: Record; - updaters: (() => void)[]; - - constructor({ - graph, - inputs, - outputs, - template, - }: { - graph: Dyno; - inputs?: Record; - outputs?: Record; - template: DynoProgramTemplate; - }) { - this.graph = graph; - this.template = template; - this.inputs = inputs ?? {}; - this.outputs = outputs ?? {}; - - const compile = new Compilation({ indent: this.template.indent }); - for (const key in this.outputs) { - if (this.outputs[key]) { - compile.declares.add(this.outputs[key]); - } - } - const statements = graph.compile({ - inputs: this.inputs, - outputs: this.outputs, - compile, - }); - - this.shader = template.generate({ globals: compile.globals, statements }); - this.uniforms = compile.uniforms; - this.updaters = compile.updaters; - // console.log("*** COMPILED SHADER", this.shader); - // console.log("*** UNIFORMS", this.uniforms); - } - - prepareMaterial(): THREE.RawShaderMaterial { - return getMaterial(this); - } - - update() { - for (const updater of this.updaters) { - updater(); - } - } -} - -export class DynoProgramTemplate { - before: string; - between: string; - after: string; - indent: string; - - constructor(template: string) { - const globals = template.match(/^([ \t]*)\{\{\s*GLOBALS\s*\}\}/m); - const statements = template.match(/^([ \t]*)\{\{\s*STATEMENTS\s*\}\}/m); - if (!globals || !statements) { - throw new Error( - "Template must contain {{ GLOBALS }} and {{ STATEMENTS }}", - ); - } - - this.before = template.substring(0, globals.index); - this.between = template.substring( - (globals.index as number) + globals[0].length, - statements.index, - ); - this.after = template.substring( - (statements.index as number) + statements[0].length, - ); - this.indent = statements[1]; - } - - generate({ - globals, - statements, - }: { globals: Set; statements: string[] }): string { - return ( - this.before + - Array.from(globals).join("\n\n") + - this.between + - statements.map((s) => this.indent + s).join("\n") + - this.after - ); - } -} - -const programMaterial = new Map(); - -function getMaterial(program: DynoProgram): THREE.RawShaderMaterial { - let material = programMaterial.get(program); - if (material) { - return material; - } - - material = new THREE.RawShaderMaterial({ - glslVersion: THREE.GLSL3, - vertexShader: IDENT_VERTEX_SHADER, - fragmentShader: program.shader, - uniforms: program.uniforms, - }); - programMaterial.set(program, material); - return material; -} diff --git a/src/dyno/splats.ts b/src/dyno/splats.ts deleted file mode 100644 index 309a601..0000000 --- a/src/dyno/splats.ts +++ /dev/null @@ -1,594 +0,0 @@ -import { Dyno, UnaryOp, unindent, unindentLines } from "./base"; -import { - DynoOutput, - type DynoVal, - type DynoValue, - type HasDynoOut, -} from "./value"; - -export const Gsplat = { type: "Gsplat" } as { type: "Gsplat" }; -export const TPackedSplats = { type: "PackedSplats" } as { - type: "PackedSplats"; -}; - -export const numPackedSplats = ( - packedSplats: DynoVal, -): DynoVal<"int"> => new NumPackedSplats({ packedSplats }); -export const readPackedSplat = ( - packedSplats: DynoVal, - index: DynoVal<"int">, -): DynoVal => new ReadPackedSplat({ packedSplats, index }); -export const readPackedSplatRange = ( - packedSplats: DynoVal, - index: DynoVal<"int">, - base: DynoVal<"int">, - count: DynoVal<"int">, -): DynoVal => - new ReadPackedSplatRange({ packedSplats, index, base, count }); -export const splitGsplat = (gsplat: DynoVal) => - new SplitGsplat({ gsplat }); -export const combineGsplat = ({ - gsplat, - flags, - index, - center, - scales, - quaternion, - rgba, - rgb, - opacity, - x, - y, - z, - r, - g, - b, -}: { - gsplat?: DynoVal; - flags?: DynoVal<"uint">; - index?: DynoVal<"int">; - center?: DynoVal<"vec3">; - scales?: DynoVal<"vec3">; - quaternion?: DynoVal<"vec4">; - rgba?: DynoVal<"vec4">; - rgb?: DynoVal<"vec3">; - opacity?: DynoVal<"float">; - x?: DynoVal<"float">; - y?: DynoVal<"float">; - z?: DynoVal<"float">; - r?: DynoVal<"float">; - g?: DynoVal<"float">; - b?: DynoVal<"float">; -}): DynoVal => { - return new CombineGsplat({ - gsplat, - flags, - index, - center, - scales, - quaternion, - rgba, - rgb, - opacity, - x, - y, - z, - r, - g, - b, - }); -}; -export const gsplatNormal = (gsplat: DynoVal): DynoVal<"vec3"> => - new GsplatNormal({ gsplat }); - -export const transformGsplat = ( - gsplat: DynoVal, - { - scale, - rotate, - translate, - recolor, - }: { - scale?: DynoVal<"float">; - rotate?: DynoVal<"vec4">; - translate?: DynoVal<"vec3">; - recolor?: DynoVal<"vec4">; - }, -): DynoVal => { - return new TransformGsplat({ gsplat, scale, rotate, translate, recolor }); -}; - -export const defineGsplat = unindent(` - struct Gsplat { - vec3 center; - uint flags; - vec3 scales; - int index; - vec4 quaternion; - vec4 rgba; - }; - const uint GSPLAT_FLAG_ACTIVE = 1u << 0u; - - bool isGsplatActive(uint flags) { - return (flags & GSPLAT_FLAG_ACTIVE) != 0u; - } -`); - -export const definePackedSplats = unindent(` - struct PackedSplats { - usampler2DArray texture; - int numSplats; - vec4 rgbMinMaxLnScaleMinMax; - }; -`); - -export class NumPackedSplats extends UnaryOp< - typeof TPackedSplats, - "int", - "numSplats" -> { - constructor({ - packedSplats, - }: { packedSplats: DynoVal }) { - super({ a: packedSplats, outKey: "numSplats", outTypeFunc: () => "int" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.numSplats} = ${inputs.a}.numSplats;`, - ]; - } -} - -const defineReadPackedSplat = unindent(` - bool readPackedSplat(usampler2DArray texture, int numSplats, vec4 rgbMinMaxLnScaleMinMax, int index, out Gsplat gsplat) { - if ((index >= 0) && (index < numSplats)) { - uvec4 packed = texelFetch(texture, splatTexCoord(index), 0); - unpackSplatEncoding(packed, gsplat.center, gsplat.scales, gsplat.quaternion, gsplat.rgba, rgbMinMaxLnScaleMinMax); - return true; - } else { - return false; - } - } -`); - -export class ReadPackedSplat - extends Dyno< - { packedSplats: typeof TPackedSplats; index: "int" }, - { gsplat: typeof Gsplat } - > - implements HasDynoOut -{ - constructor({ - packedSplats, - index, - }: { packedSplats?: DynoVal; index?: DynoVal<"int"> }) { - super({ - inTypes: { packedSplats: TPackedSplats, index: "int" }, - outTypes: { gsplat: Gsplat }, - inputs: { packedSplats, index }, - globals: () => [defineGsplat, definePackedSplats, defineReadPackedSplat], - statements: ({ inputs, outputs }) => { - const { gsplat } = outputs; - if (!gsplat) { - return []; - } - const { packedSplats, index } = inputs; - let statements: string[]; - if (packedSplats && index) { - statements = unindentLines(` - if (readPackedSplat(${packedSplats}.texture, ${packedSplats}.numSplats, ${packedSplats}.rgbMinMaxLnScaleMinMax, ${index}, ${gsplat})) { - bool zeroSize = all(equal(${gsplat}.scales, vec3(0.0, 0.0, 0.0))); - ${gsplat}.flags = zeroSize ? 0u : GSPLAT_FLAG_ACTIVE; - } else { - ${gsplat}.flags = 0u; - } - `); - } else { - statements = [`${gsplat}.flags = 0u;`]; - } - statements.push(`${gsplat}.index = ${index ?? "0"};`); - return statements; - }, - }); - } - - dynoOut(): DynoValue { - return new DynoOutput(this, "gsplat"); - } -} - -export class ReadPackedSplatRange - extends Dyno< - { - packedSplats: typeof TPackedSplats; - index: "int"; - base: "int"; - count: "int"; - }, - { gsplat: typeof Gsplat } - > - implements HasDynoOut -{ - constructor({ - packedSplats, - index, - base, - count, - }: { - packedSplats?: DynoVal; - index?: DynoVal<"int">; - base?: DynoVal<"int">; - count?: DynoVal<"int">; - }) { - super({ - inTypes: { - packedSplats: TPackedSplats, - index: "int", - base: "int", - count: "int", - }, - outTypes: { gsplat: Gsplat }, - inputs: { packedSplats, index, base, count }, - globals: () => [defineGsplat, definePackedSplats, defineReadPackedSplat], - statements: ({ inputs, outputs }) => { - const { gsplat } = outputs; - if (!gsplat) { - return []; - } - const { packedSplats, index, base, count } = inputs; - let statements: string[]; - if (packedSplats && index && base && count) { - statements = unindentLines(` - ${gsplat}.flags = 0u; - if ((${index} >= ${base}) && (${index} < (${base} + ${count}))) { - if (readPackedSplat(${packedSplats}.texture, ${packedSplats}.numSplats, ${packedSplats}.rgbMinMaxLnScaleMinMax, ${index}, ${gsplat})) { - bool zeroSize = all(equal(${gsplat}.scales, vec3(0.0, 0.0, 0.0))); - ${gsplat}.flags = zeroSize ? 0u : GSPLAT_FLAG_ACTIVE; - } - } - `); - } else { - statements = [`${gsplat}.flags = 0u;`]; - } - statements.push(`${gsplat}.index = ${index ?? "0"};`); - return statements; - }, - }); - } - - dynoOut(): DynoValue { - return new DynoOutput(this, "gsplat"); - } -} - -export class SplitGsplat extends Dyno< - { gsplat: typeof Gsplat }, - { - flags: "uint"; - active: "bool"; - index: "int"; - center: "vec3"; - scales: "vec3"; - quaternion: "vec4"; - rgba: "vec4"; - rgb: "vec3"; - opacity: "float"; - x: "float"; - y: "float"; - z: "float"; - r: "float"; - g: "float"; - b: "float"; - } -> { - constructor({ gsplat }: { gsplat?: DynoVal }) { - super({ - inTypes: { gsplat: Gsplat }, - outTypes: { - flags: "uint", - active: "bool", - index: "int", - center: "vec3", - scales: "vec3", - quaternion: "vec4", - rgba: "vec4", - rgb: "vec3", - opacity: "float", - x: "float", - y: "float", - z: "float", - r: "float", - g: "float", - b: "float", - }, - inputs: { gsplat }, - globals: () => [defineGsplat], - statements: ({ inputs, outputs }) => { - const { gsplat } = inputs; - const { - flags, - active, - index, - center, - scales, - quaternion, - rgba, - rgb, - opacity, - x, - y, - z, - r, - g, - b, - } = outputs; - return [ - !flags ? null : `${flags} = ${gsplat ? `${gsplat}.flags` : "0u"};`, - !active - ? null - : `${active} = isGsplatActive(${gsplat ? `${gsplat}.flags` : "0u"});`, - !index ? null : `${index} = ${gsplat ? `${gsplat}.index` : "0"};`, - !center - ? null - : `${center} = ${gsplat ? `${gsplat}.center` : "vec3(0.0, 0.0, 0.0)"};`, - !scales - ? null - : `${scales} = ${gsplat ? `${gsplat}.scales` : "vec3(0.0, 0.0, 0.0)"};`, - !quaternion - ? null - : `${quaternion} = ${gsplat ? `${gsplat}.quaternion` : "vec4(0.0, 0.0, 0.0, 1.0)"};`, - !rgba - ? null - : `${rgba} = ${gsplat ? `${gsplat}.rgba` : "vec4(0.0, 0.0, 0.0, 0.0)"};`, - !rgb - ? null - : `${rgb} = ${gsplat ? `${gsplat}.rgba.rgb` : "vec3(0.0, 0.0, 0.0)"};`, - !opacity - ? null - : `${opacity} = ${gsplat ? `${gsplat}.rgba.a` : "0.0"};`, - !x ? null : `${x} = ${gsplat ? `${gsplat}.center.x` : "0.0"};`, - !y ? null : `${y} = ${gsplat ? `${gsplat}.center.y` : "0.0"};`, - !z ? null : `${z} = ${gsplat ? `${gsplat}.center.z` : "0.0"};`, - !r ? null : `${r} = ${gsplat ? `${gsplat}.rgba.r` : "0.0"};`, - !g ? null : `${g} = ${gsplat ? `${gsplat}.rgba.g` : "0.0"};`, - !b ? null : `${b} = ${gsplat ? `${gsplat}.rgba.b` : "0.0"};`, - ].filter(Boolean) as string[]; - }, - }); - } -} - -export class CombineGsplat - extends Dyno< - { - gsplat: typeof Gsplat; - flags: "uint"; - index: "int"; - center: "vec3"; - scales: "vec3"; - quaternion: "vec4"; - rgba: "vec4"; - rgb: "vec3"; - opacity: "float"; - x: "float"; - y: "float"; - z: "float"; - r: "float"; - g: "float"; - b: "float"; - }, - { gsplat: typeof Gsplat } - > - implements HasDynoOut -{ - constructor({ - gsplat, - flags, - index, - center, - scales, - quaternion, - rgba, - rgb, - opacity, - x, - y, - z, - r, - g, - b, - }: { - gsplat?: DynoVal; - flags?: DynoVal<"uint">; - index?: DynoVal<"int">; - center?: DynoVal<"vec3">; - scales?: DynoVal<"vec3">; - quaternion?: DynoVal<"vec4">; - rgba?: DynoVal<"vec4">; - rgb?: DynoVal<"vec3">; - opacity?: DynoVal<"float">; - x?: DynoVal<"float">; - y?: DynoVal<"float">; - z?: DynoVal<"float">; - r?: DynoVal<"float">; - g?: DynoVal<"float">; - b?: DynoVal<"float">; - }) { - super({ - inTypes: { - gsplat: Gsplat, - flags: "uint", - index: "int", - center: "vec3", - scales: "vec3", - quaternion: "vec4", - rgba: "vec4", - rgb: "vec3", - opacity: "float", - x: "float", - y: "float", - z: "float", - r: "float", - g: "float", - b: "float", - }, - outTypes: { gsplat: Gsplat }, - inputs: { - gsplat, - flags, - index, - center, - scales, - quaternion, - rgba, - rgb, - opacity, - x, - y, - z, - r, - g, - b, - }, - globals: () => [defineGsplat], - statements: ({ inputs, outputs }) => { - const { gsplat: outGsplat } = outputs; - if (!outGsplat) { - return []; - } - const { - gsplat, - flags, - index, - center, - scales, - quaternion, - rgba, - rgb, - opacity, - x, - y, - z, - r, - g, - b, - } = inputs; - return [ - `${outGsplat}.flags = ${flags ?? (gsplat ? `${gsplat}.flags` : "0u")};`, - `${outGsplat}.index = ${index ?? (gsplat ? `${gsplat}.index` : "0")};`, - `${outGsplat}.center = ${center ?? (gsplat ? `${gsplat}.center` : "vec3(0.0, 0.0, 0.0)")};`, - `${outGsplat}.scales = ${scales ?? (gsplat ? `${gsplat}.scales` : "vec3(0.0, 0.0, 0.0)")};`, - `${outGsplat}.quaternion = ${quaternion ?? (gsplat ? `${gsplat}.quaternion` : "vec4(0.0, 0.0, 0.0, 1.0)")};`, - `${outGsplat}.rgba = ${rgba ?? (gsplat ? `${gsplat}.rgba` : "vec4(0.0, 0.0, 0.0, 0.0)")};`, - !rgb ? null : `${outGsplat}.rgba.rgb = ${rgb};`, - !opacity ? null : `${outGsplat}.rgba.a = ${opacity};`, - !x ? null : `${outGsplat}.center.x = ${x};`, - !y ? null : `${outGsplat}.center.y = ${y};`, - !z ? null : `${outGsplat}.center.z = ${z};`, - !r ? null : `${outGsplat}.rgba.r = ${r};`, - !g ? null : `${outGsplat}.rgba.g = ${g};`, - !b ? null : `${outGsplat}.rgba.b = ${b};`, - ].filter(Boolean) as string[]; - }, - }); - } - - dynoOut(): DynoValue { - return new DynoOutput(this, "gsplat"); - } -} - -export const defineGsplatNormal = unindent(` - vec3 gsplatNormal(vec3 scales, vec4 quaternion) { - float minScale = min(scales.x, min(scales.y, scales.z)); - vec3 normal; - if (scales.z == minScale) { - normal = vec3(0.0, 0.0, 1.0); - } else if (scales.y == minScale) { - normal = vec3(0.0, 1.0, 0.0); - } else { - normal = vec3(1.0, 0.0, 0.0); - } - return quatVec(quaternion, normal); - } -`); - -export class GsplatNormal extends UnaryOp { - constructor({ gsplat }: { gsplat: DynoVal }) { - super({ a: gsplat, outKey: "normal", outTypeFunc: () => "vec3" }); - this.globals = () => [defineGsplat, defineGsplatNormal]; - this.statements = ({ inputs, outputs }) => [ - `${outputs.normal} = gsplatNormal(${inputs.a}.scales, ${inputs.a}.quaternion);`, - ]; - } -} - -export class TransformGsplat - extends Dyno< - { - gsplat: typeof Gsplat; - scale: "float"; - rotate: "vec4"; - translate: "vec3"; - recolor: "vec4"; - }, - { gsplat: typeof Gsplat } - > - implements HasDynoOut -{ - constructor({ - gsplat, - scale, - rotate, - translate, - recolor, - }: { - gsplat?: DynoVal; - scale?: DynoVal<"float">; - rotate?: DynoVal<"vec4">; - translate?: DynoVal<"vec3">; - recolor?: DynoVal<"vec4">; - }) { - super({ - inTypes: { - gsplat: Gsplat, - scale: "float", - rotate: "vec4", - translate: "vec3", - recolor: "vec4", - }, - outTypes: { gsplat: Gsplat }, - inputs: { gsplat, scale, rotate, translate, recolor }, - globals: () => [defineGsplat], - statements: ({ inputs, outputs, compile }) => { - const { gsplat } = outputs; - if (!gsplat || !inputs.gsplat) { - return []; - } - const { scale, rotate, translate, recolor } = inputs; - const indent = compile.indent; - const statements = [ - `${gsplat} = ${inputs.gsplat};`, - `if (isGsplatActive(${gsplat}.flags)) {`, - - scale ? `${indent}${gsplat}.center *= ${scale};` : null, - rotate - ? `${indent}${gsplat}.center = quatVec(${rotate}, ${gsplat}.center);` - : null, - translate ? `${indent}${gsplat}.center += ${translate};` : null, - - scale ? `${indent}${gsplat}.scales *= ${scale};` : null, - - rotate - ? `${indent}${gsplat}.quaternion = quatQuat(${rotate}, ${gsplat}.quaternion);` - : null, - recolor ? `${indent}${gsplat}.rgba *= ${recolor};` : null, - "}", - ].filter(Boolean) as string[]; - return statements; - }, - }); - } - - dynoOut(): DynoValue { - return new DynoOutput(this, "gsplat"); - } -} diff --git a/src/dyno/texture.ts b/src/dyno/texture.ts deleted file mode 100644 index 92b924a..0000000 --- a/src/dyno/texture.ts +++ /dev/null @@ -1,239 +0,0 @@ -import { Dyno } from "./base"; -import type { - AllSamplerTypes, - IsamplerTypes, - NormalSamplerTypes, - Sampler2DArrayTypes, - Sampler2DTypes, - Sampler3DTypes, - SamplerCubeTypes, - SamplerShadowTypes, - SamplerTypes, - UsamplerTypes, -} from "./types"; -import { - DynoOutput, - type DynoVal, - type DynoValue, - type HasDynoOut, - valType, -} from "./value"; - -export const textureSize = ( - texture: DynoVal, - lod?: DynoVal<"int">, -): DynoVal> => new TextureSize({ texture, lod }); -export const texture = ( - texture: DynoVal, - coord: DynoVal>, - bias?: DynoVal<"float">, -): DynoVal> => new Texture({ texture, coord, bias }); -export const texelFetch = ( - texture: DynoVal, - coord: DynoVal>, - lod?: DynoVal<"int">, -): DynoVal> => new TexelFetch({ texture, coord, lod }); - -export class TextureSize - extends Dyno<{ texture: T; lod: "int" }, { size: TextureSizeType }> - implements HasDynoOut> -{ - constructor({ texture, lod }: { texture: DynoVal; lod?: DynoVal<"int"> }) { - const textureType = valType(texture); - super({ - inTypes: { texture: textureType, lod: "int" }, - outTypes: { size: textureSizeType(textureType) }, - inputs: { texture, lod }, - statements: ({ inputs, outputs }) => [ - `${outputs.size} = textureSize(${inputs.texture}, ${inputs.lod ?? "0"});`, - ], - }); - } - - dynoOut(): DynoValue> { - return new DynoOutput(this, "size"); - } -} - -export class Texture - extends Dyno< - { texture: T; coord: TextureCoordType; bias: "float" }, - { sample: TextureReturnType } - > - implements HasDynoOut> -{ - constructor({ - texture, - coord, - bias, - }: { - texture: DynoVal; - coord: DynoVal>; - bias?: DynoVal<"float">; - }) { - const textureType = valType(texture); - super({ - inTypes: { - texture: textureType, - coord: textureCoordType(textureType), - bias: "float", - }, - outTypes: { sample: textureReturnType(textureType) }, - inputs: { texture, coord, bias }, - statements: ({ inputs, outputs }) => [ - `${outputs.sample} = texture(${inputs.texture}, ${inputs.coord}${inputs.bias ? `, ${inputs.bias}` : ""});`, - ], - }); - } - - dynoOut(): DynoValue> { - return new DynoOutput(this, "sample"); - } -} - -export class TexelFetch - extends Dyno< - { texture: T; coord: TextureSizeType; lod: "int" }, - { texel: TextureReturnType } - > - implements HasDynoOut> -{ - constructor({ - texture, - coord, - lod, - }: { - texture: DynoVal; - coord: DynoVal>; - lod?: DynoVal<"int">; - }) { - const textureType = valType(texture); - super({ - inTypes: { - texture: textureType, - coord: textureSizeType(textureType), - lod: "int", - }, - outTypes: { texel: textureReturnType(textureType) }, - inputs: { texture, coord, lod }, - statements: ({ inputs, outputs }) => [ - `${outputs.texel} = texelFetch(${inputs.texture}, ${inputs.coord}, ${inputs.lod ?? "0"});`, - ], - }); - } - - dynoOut(): DynoValue> { - return new DynoOutput(this, "texel"); - } -} - -type TextureSizeType = T extends - | Sampler2DTypes - | SamplerCubeTypes - ? "ivec2" - : T extends Sampler3DTypes | Sampler2DArrayTypes - ? "ivec3" - : never; - -function textureSizeType( - textureType: T, -): TextureSizeType { - switch (textureType) { - case "sampler2D": - case "usampler2D": - case "isampler2D": - case "samplerCube": - case "usamplerCube": - case "isamplerCube": - case "sampler2DShadow": - case "samplerCubeShadow": - return "ivec2" as TextureSizeType; - case "sampler3D": - case "usampler3D": - case "isampler3D": - case "sampler2DArray": - case "usampler2DArray": - case "isampler2DArray": - case "sampler2DArrayShadow": - return "ivec3" as TextureSizeType; - default: - throw new Error(`Invalid texture type: ${textureType}`); - } -} - -type TextureCoordType = T extends Sampler2DTypes - ? "vec2" - : T extends - | Sampler3DTypes - | Sampler2DArrayTypes - | SamplerCubeTypes - | Sampler2DArrayTypes - ? "vec3" - : T extends "samperCubeShadow" | "sampler2DArrayShadow" - ? "vec4" - : never; - -function textureCoordType( - textureType: T, -): TextureCoordType { - switch (textureType) { - case "sampler2D": - case "usampler2D": - case "isampler2D": - return "vec2" as TextureCoordType; - case "sampler3D": - case "usampler3D": - case "isampler3D": - case "samplerCube": - case "usamplerCube": - case "isamplerCube": - case "sampler2DArray": - case "usampler2DArray": - case "isampler2DArray": - case "sampler2DShadow": - return "vec3" as TextureCoordType; - case "samplerCubeShadow": - case "sampler2DArrayShadow": - return "vec4" as TextureCoordType; - default: - throw new Error(`Invalid texture type: ${textureType}`); - } -} - -type TextureReturnType = T extends SamplerTypes - ? "vec4" - : T extends UsamplerTypes - ? "uvec4" - : T extends IsamplerTypes - ? "ivec4" - : T extends SamplerShadowTypes - ? "float" - : never; - -function textureReturnType( - textureType: T, -): TextureReturnType { - switch (textureType) { - case "sampler2D": - case "sampler2DArray": - case "sampler3D": - case "samplerCube": - case "sampler2DShadow": - return "vec4" as TextureReturnType; - case "usampler2D": - case "usampler2DArray": - case "usampler3D": - case "usamplerCube": - return "uvec4" as TextureReturnType; - case "isampler2D": - case "isampler2DArray": - case "isampler3D": - case "isamplerCube": - return "ivec4" as TextureReturnType; - case "samplerCubeShadow": - case "sampler2DArrayShadow": - return "float" as TextureReturnType; - default: - throw new Error(`Invalid texture type: ${textureType}`); - } -} diff --git a/src/dyno/transform.ts b/src/dyno/transform.ts deleted file mode 100644 index 15146df..0000000 --- a/src/dyno/transform.ts +++ /dev/null @@ -1,155 +0,0 @@ -import { Dyno } from "./base"; -import type { DynoVal } from "./value"; - -export const transformPos = ( - position: DynoVal<"vec3">, - { - scale, - scales, - rotate, - translate, - }: { - scale?: DynoVal<"float">; - scales?: DynoVal<"vec3">; - rotate?: DynoVal<"vec4">; - translate?: DynoVal<"vec3">; - }, -): DynoVal<"vec3"> => { - return new TransformPosition({ position, scale, scales, rotate, translate }) - .outputs.position; -}; -export const transformDir = ( - dir: DynoVal<"vec3">, - { - scale, - scales, - rotate, - }: { - scale?: DynoVal<"float">; - scales?: DynoVal<"vec3">; - rotate?: DynoVal<"vec4">; - }, -): DynoVal<"vec3"> => { - return new TransformDir({ dir, scale, scales, rotate }).outputs.dir; -}; -export const transformQuat = ( - quaternion: DynoVal<"vec4">, - { rotate }: { rotate?: DynoVal<"vec4"> }, -): DynoVal<"vec4"> => { - return new TransformQuaternion({ quaternion, rotate }).outputs.quaternion; -}; - -export class TransformPosition extends Dyno< - { - position: "vec3"; - scale: "float"; - scales: "vec3"; - rotate: "vec4"; - translate: "vec3"; - }, - { position: "vec3" } -> { - constructor({ - position, - scale, - scales, - rotate, - translate, - }: { - position?: DynoVal<"vec3">; - scale?: DynoVal<"float">; - scales?: DynoVal<"vec3">; - rotate?: DynoVal<"vec4">; - translate?: DynoVal<"vec3">; - }) { - super({ - inTypes: { - position: "vec3", - scale: "float", - scales: "vec3", - rotate: "vec4", - translate: "vec3", - }, - outTypes: { position: "vec3" }, - inputs: { position, scale, scales, rotate, translate }, - statements: ({ inputs, outputs }) => { - const { position } = outputs; - if (!position) { - return []; - } - const { scale, scales, rotate, translate } = inputs; - return [ - `${position} = ${inputs.position ?? "vec3(0.0, 0.0, 0.0)"};`, - !scale ? null : `${position} *= ${scale};`, - !scales ? null : `${position} *= ${scales};`, - !rotate ? null : `${position} = quatVec(${rotate}, ${position});`, - !translate ? null : `${position} += ${translate};`, - ].filter(Boolean) as string[]; - }, - }); - } -} - -export class TransformDir extends Dyno< - { dir: "vec3"; scale: "float"; scales: "vec3"; rotate: "vec4" }, - { dir: "vec3" } -> { - constructor({ - dir, - scale, - scales, - rotate, - }: { - dir?: DynoVal<"vec3">; - scale?: DynoVal<"float">; - scales?: DynoVal<"vec3">; - rotate?: DynoVal<"vec4">; - }) { - super({ - inTypes: { dir: "vec3", scale: "float", scales: "vec3", rotate: "vec4" }, - outTypes: { dir: "vec3" }, - inputs: { dir, scale, scales, rotate }, - statements: ({ inputs, outputs }) => { - const { dir } = outputs; - if (!dir) { - return []; - } - const { scale, scales, rotate } = inputs; - return [ - `${dir} = ${inputs.dir ?? "vec3(0.0, 0.0, 0.0)"};`, - !scale ? null : `${dir} *= ${scale};`, - !scales ? null : `${dir} *= ${scales};`, - !rotate ? null : `${dir} = quatVec(${rotate}, ${dir});`, - ].filter(Boolean) as string[]; - }, - }); - } -} - -export class TransformQuaternion extends Dyno< - { quaternion: "vec4"; rotate: "vec4" }, - { quaternion: "vec4" } -> { - constructor({ - quaternion, - rotate, - }: { quaternion?: DynoVal<"vec4">; rotate?: DynoVal<"vec4"> }) { - super({ - inTypes: { quaternion: "vec4", rotate: "vec4" }, - outTypes: { quaternion: "vec4" }, - inputs: { quaternion, rotate }, - statements: ({ inputs, outputs }) => { - const { quaternion } = outputs; - if (!quaternion) { - return []; - } - return [ - `${quaternion} = ${inputs.quaternion ?? "vec4(0.0, 0.0, 0.0, 1.0)"};`, - !rotate - ? null - : `${quaternion} = quatQuat(${inputs.rotate}, ${quaternion});`, - ].filter(Boolean) as string[]; - }, - }); - } -} diff --git a/src/dyno/trig.ts b/src/dyno/trig.ts deleted file mode 100644 index ce9c008..0000000 --- a/src/dyno/trig.ts +++ /dev/null @@ -1,182 +0,0 @@ -import { BinaryOp, UnaryOp } from "./base"; -import type { FloatTypes } from "./types"; -import type { DynoVal } from "./value"; - -export const radians = ( - degrees: DynoVal, -): DynoVal => new Radians({ degrees }); -export const degrees = ( - radians: DynoVal, -): DynoVal => new Degrees({ radians }); - -export const sin = (radians: DynoVal): DynoVal => - new Sin({ radians }); -export const cos = (radians: DynoVal): DynoVal => - new Cos({ radians }); -export const tan = (radians: DynoVal): DynoVal => - new Tan({ radians }); - -export const asin = (sin: DynoVal): DynoVal => - new Asin({ sin }); -export const acos = (cos: DynoVal): DynoVal => - new Acos({ cos }); -export const atan = (tan: DynoVal): DynoVal => - new Atan({ tan }); -export const atan2 = ( - y: DynoVal, - x: DynoVal, -): DynoVal => new Atan2({ y, x }); - -export const sinh = (x: DynoVal): DynoVal => - new Sinh({ x }); -export const cosh = (x: DynoVal): DynoVal => - new Cosh({ x }); -export const tanh = (x: DynoVal): DynoVal => - new Tanh({ x }); - -export const asinh = (x: DynoVal): DynoVal => - new Asinh({ x }); -export const acosh = (x: DynoVal): DynoVal => - new Acosh({ x }); -export const atanh = (x: DynoVal): DynoVal => - new Atanh({ x }); - -export class Radians extends UnaryOp { - constructor({ degrees }: { degrees: DynoVal }) { - super({ a: degrees, outTypeFunc: (aType) => aType, outKey: "radians" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.radians} = radians(${inputs.a});`, - ]; - } -} - -export class Degrees extends UnaryOp { - constructor({ radians }: { radians: DynoVal }) { - super({ a: radians, outTypeFunc: (aType) => aType, outKey: "degrees" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.degrees} = degrees(${inputs.a});`, - ]; - } -} - -export class Sin extends UnaryOp { - constructor({ radians }: { radians: DynoVal }) { - super({ a: radians, outTypeFunc: (aType) => aType, outKey: "sin" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.sin} = sin(${inputs.a});`, - ]; - } -} - -export class Cos extends UnaryOp { - constructor({ radians }: { radians: DynoVal }) { - super({ a: radians, outTypeFunc: (aType) => aType, outKey: "cos" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.cos} = cos(${inputs.a});`, - ]; - } -} - -export class Tan extends UnaryOp { - constructor({ radians }: { radians: DynoVal }) { - super({ a: radians, outTypeFunc: (aType) => aType, outKey: "tan" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.tan} = tan(${inputs.a});`, - ]; - } -} - -export class Asin extends UnaryOp { - constructor({ sin }: { sin: DynoVal }) { - super({ a: sin, outTypeFunc: (aType) => aType, outKey: "asin" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.asin} = asin(${inputs.a});`, - ]; - } -} - -export class Acos extends UnaryOp { - constructor({ cos }: { cos: DynoVal }) { - super({ a: cos, outTypeFunc: (aType) => aType, outKey: "acos" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.acos} = acos(${inputs.a});`, - ]; - } -} - -export class Atan extends UnaryOp { - constructor({ tan }: { tan: DynoVal }) { - super({ a: tan, outTypeFunc: (aType) => aType, outKey: "atan" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.atan} = atan(${inputs.a});`, - ]; - } -} - -export class Atan2 extends BinaryOp { - constructor({ y, x }: { y: DynoVal; x: DynoVal }) { - super({ - a: y, - b: x, - outTypeFunc: (aType, bType) => aType, - outKey: "atan2", - }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.atan2} = atan2(${inputs.a}, ${inputs.b});`, - ]; - } -} - -export class Sinh extends UnaryOp { - constructor({ x }: { x: DynoVal }) { - super({ a: x, outTypeFunc: (aType) => aType, outKey: "sinh" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.sinh} = sinh(${inputs.a});`, - ]; - } -} - -export class Cosh extends UnaryOp { - constructor({ x }: { x: DynoVal }) { - super({ a: x, outTypeFunc: (aType) => aType, outKey: "cosh" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.cosh} = cosh(${inputs.a});`, - ]; - } -} - -export class Tanh extends UnaryOp { - constructor({ x }: { x: DynoVal }) { - super({ a: x, outTypeFunc: (aType) => aType, outKey: "tanh" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.tanh} = tanh(${inputs.a});`, - ]; - } -} - -export class Asinh extends UnaryOp { - constructor({ x }: { x: DynoVal }) { - super({ a: x, outTypeFunc: (aType) => aType, outKey: "asinh" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.asinh} = asinh(${inputs.a});`, - ]; - } -} - -export class Acosh extends UnaryOp { - constructor({ x }: { x: DynoVal }) { - super({ a: x, outTypeFunc: (aType) => aType, outKey: "acosh" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.acosh} = acosh(${inputs.a});`, - ]; - } -} - -export class Atanh extends UnaryOp { - constructor({ x }: { x: DynoVal }) { - super({ a: x, outTypeFunc: (aType) => aType, outKey: "atanh" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.atanh} = atanh(${inputs.a});`, - ]; - } -} diff --git a/src/dyno/types.ts b/src/dyno/types.ts deleted file mode 100644 index 2fbad55..0000000 --- a/src/dyno/types.ts +++ /dev/null @@ -1,420 +0,0 @@ -import type * as THREE from "three"; - -export type BoolTypes = "bool" | "bvec2" | "bvec3" | "bvec4"; -export type IntTypes = "int" | "ivec2" | "ivec3" | "ivec4"; -export type UintTypes = "uint" | "uvec2" | "uvec3" | "uvec4"; -export type AllIntTypes = IntTypes | UintTypes; -export type FloatTypes = "float" | "vec2" | "vec3" | "vec4"; -export type ScalarTypes = "uint" | "int" | "float"; -export type Vector2Types = "vec2" | "ivec2" | "uvec2"; -export type Vector3Types = "vec3" | "ivec3" | "uvec3"; -export type Vector4Types = "vec4" | "ivec4" | "uvec4"; -export type VectorTypes = Vector2Types | Vector3Types | Vector4Types; -export type MatFloatTypes = - | "mat2" - | "mat2x2" - | "mat2x3" - | "mat2x4" - | "mat3" - | "mat3x2" - | "mat3x3" - | "mat3x4" - | "mat4" - | "mat4x2" - | "mat4x3" - | "mat4x4"; -export type SquareMatTypes = - | "mat2" - | "mat3" - | "mat4" - | "mat2x2" - | "mat3x3" - | "mat4x4"; -export type AllFloatTypes = FloatTypes | MatFloatTypes; -export type SignedTypes = IntTypes | FloatTypes; -export type AllSignedTypes = SignedTypes | MatFloatTypes; -export type ValueTypes = FloatTypes | IntTypes | UintTypes; -export type AllValueTypes = AllFloatTypes | IntTypes | UintTypes; -export type SimpleTypes = BoolTypes | AllValueTypes; - -export type VectorElementType = A extends FloatTypes - ? "float" - : A extends IntTypes - ? "int" - : A extends UintTypes - ? "uint" - : never; - -export type SameSizeVec = T extends ScalarTypes - ? "float" - : T extends "vec2" | "ivec2" | "uvec2" - ? "vec2" - : T extends "vec3" | "ivec3" | "uvec3" - ? "vec3" - : T extends "vec4" | "ivec4" | "uvec4" - ? "vec4" - : never; - -export type SameSizeUvec = T extends ScalarTypes - ? "uint" - : T extends "vec2" | "ivec2" | "uvec2" - ? "uvec2" - : T extends "vec3" | "ivec3" | "uvec3" - ? "uvec3" - : T extends "vec4" | "ivec4" | "uvec4" - ? "uvec4" - : never; - -export type SameSizeIvec = T extends ScalarTypes - ? "int" - : T extends "vec2" | "ivec2" | "uvec2" - ? "ivec2" - : T extends "vec3" | "ivec3" | "uvec3" - ? "ivec3" - : T extends "vec4" | "ivec4" | "uvec4" - ? "ivec4" - : never; - -export type SamplerTypes = - | "sampler2D" - | "sampler2DArray" - | "sampler3D" - | "samplerCube"; -export type UsamplerTypes = - | "usampler2D" - | "usampler2DArray" - | "usampler3D" - | "usamplerCube"; -export type IsamplerTypes = - | "isampler2D" - | "isampler2DArray" - | "isampler3D" - | "isamplerCube"; -export type NormalSamplerTypes = SamplerTypes | UsamplerTypes | IsamplerTypes; -export type SamplerShadowTypes = - | "sampler2DShadow" - | "sampler2DArrayShadow" - | "samplerCubeShadow"; -export type AllSamplerTypes = NormalSamplerTypes | SamplerShadowTypes; -export type Sampler2DTypes = - | "sampler2D" - | "usampler2D" - | "isampler2D" - | "sampler2DShadow"; -export type Sampler2DArrayTypes = - | "sampler2DArray" - | "usampler2DArray" - | "isampler2DArray" - | "sampler2DArrayShadow"; -export type Sampler3DTypes = "sampler3D" | "usampler3D" | "isampler3D"; -export type SamplerCubeTypes = - | "samplerCube" - | "usamplerCube" - | "isamplerCube" - | "samplerCubeShadow"; - -export function isBoolType(type: DynoType): boolean { - return ( - type === "bool" || type === "bvec2" || type === "bvec3" || type === "bvec4" - ); -} - -export function isScalarType(type: DynoType): boolean { - return type === "int" || type === "uint" || type === "float"; -} - -export function isIntType(type: DynoType): boolean { - return ( - type === "int" || type === "ivec2" || type === "ivec3" || type === "ivec4" - ); -} - -export function isUintType(type: DynoType): boolean { - return ( - type === "uint" || type === "uvec2" || type === "uvec3" || type === "uvec4" - ); -} - -export function isFloatType(type: DynoType): boolean { - return ( - type === "float" || type === "vec2" || type === "vec3" || type === "vec4" - ); -} - -export function isMatFloatType(type: DynoType): boolean { - return ( - type === "mat2" || - type === "mat2x2" || - type === "mat2x3" || - type === "mat2x4" || - type === "mat3" || - type === "mat3x2" || - type === "mat3x3" || - type === "mat3x4" || - type === "mat4" || - type === "mat4x2" || - type === "mat4x3" || - type === "mat4x4" - ); -} - -export function isAllFloatType(type: DynoType): boolean { - return isFloatType(type) || isMatFloatType(type); -} - -export function isVector2Type(type: DynoType): boolean { - return type === "vec2" || type === "ivec2" || type === "uvec2"; -} - -export function isVector3Type(type: DynoType): boolean { - return type === "vec3" || type === "ivec3" || type === "uvec3"; -} - -export function isVector4Type(type: DynoType): boolean { - return type === "vec4" || type === "ivec4" || type === "uvec4"; -} - -export function isVectorType(type: DynoType): boolean { - return isVector2Type(type) || isVector3Type(type) || isVector4Type(type); -} - -export function isMat2(type: DynoType): boolean { - return type === "mat2" || type === "mat2x2"; -} -export function isMat3(type: DynoType): boolean { - return type === "mat3" || type === "mat3x3"; -} -export function isMat4(type: DynoType): boolean { - return type === "mat4" || type === "mat4x4"; -} - -export function vectorElementType( - type: A, -): VectorElementType { - switch (type) { - case "vec2": - return "float" as VectorElementType; - case "vec3": - return "float" as VectorElementType; - case "vec4": - return "float" as VectorElementType; - case "ivec2": - return "int" as VectorElementType; - case "ivec3": - return "int" as VectorElementType; - case "ivec4": - return "int" as VectorElementType; - case "uvec2": - return "uint" as VectorElementType; - case "uvec3": - return "uint" as VectorElementType; - case "uvec4": - return "uint" as VectorElementType; - default: - throw new Error(`Invalid vector type: ${type}`); - } -} - -export function vectorDim(type: A): number { - switch (type) { - case "vec2": - case "ivec2": - case "uvec2": - return 2; - case "vec3": - case "ivec3": - case "uvec3": - return 3; - case "vec4": - case "ivec4": - case "uvec4": - return 4; - default: - throw new Error(`Invalid vector type: ${type}`); - } -} - -export function sameSizeVec(type: T): SameSizeVec { - if (isScalarType(type)) { - return "float" as SameSizeVec; - } - if (isVector2Type(type)) { - return "vec2" as SameSizeVec; - } - if (isVector3Type(type)) { - return "vec3" as SameSizeVec; - } - if (isVector4Type(type)) { - return "vec4" as SameSizeVec; - } - throw new Error(`Invalid vector type: ${type}`); -} - -export function sameSizeUvec(type: T): SameSizeUvec { - if (isScalarType(type)) { - return "uint" as SameSizeUvec; - } - if (isVector2Type(type)) { - return "uvec2" as SameSizeUvec; - } - if (isVector3Type(type)) { - return "uvec3" as SameSizeUvec; - } - if (isVector4Type(type)) { - return "uvec4" as SameSizeUvec; - } - throw new Error(`Invalid vector type: ${type}`); -} - -export function sameSizeIvec(type: T): SameSizeIvec { - if (isScalarType(type)) { - return "int" as SameSizeIvec; - } - if (isVector2Type(type)) { - return "ivec2" as SameSizeIvec; - } - if (isVector3Type(type)) { - return "ivec3" as SameSizeIvec; - } - if (isVector4Type(type)) { - return "ivec4" as SameSizeIvec; - } - throw new Error(`Invalid vector type: ${type}`); -} - -export type BaseType = SimpleTypes | AllSamplerTypes; -export type UserType = { type: string }; -export type DynoType = BaseType | UserType; - -export type DynoJsType = T extends "bool" - ? boolean - : T extends "uint" - ? number - : T extends "int" - ? number - : T extends "float" - ? number - : T extends "bvec2" - ? [boolean, boolean] - : T extends "uvec2" - ? THREE.Vector2 | [number, number] | Uint32Array - : T extends "ivec2" - ? THREE.Vector2 | [number, number] | Int32Array - : T extends "vec2" - ? THREE.Vector2 | [number, number] | Float32Array - : T extends "bvec3" - ? [boolean, boolean, boolean] - : T extends "uvec3" - ? THREE.Vector3 | [number, number, number] | Uint32Array - : T extends "ivec3" - ? THREE.Vector3 | [number, number, number] | Int32Array - : T extends "vec3" - ? - | THREE.Vector3 - | THREE.Color - | [number, number, number] - | Float32Array - : T extends "bvec4" - ? [boolean, boolean, boolean, boolean] - : T extends "uvec4" - ? - | THREE.Vector4 - | [number, number, number, number] - | Uint32Array - : T extends "ivec4" - ? - | THREE.Vector4 - | [number, number, number, number] - | Int32Array - : T extends "vec4" - ? - | THREE.Vector4 - | THREE.Quaternion - | [number, number, number, number] - | Float32Array - : T extends "mat2" - ? THREE.Matrix2 | Float32Array - : T extends "mat2x2" - ? THREE.Matrix2 | Float32Array - : T extends "mat2x3" - ? Float32Array - : T extends "mat2x4" - ? Float32Array - : T extends "mat3" - ? THREE.Matrix3 | Float32Array - : T extends "mat3x2" - ? Float32Array - : T extends "mat3x3" - ? THREE.Matrix3 | Float32Array - : T extends "mat3x4" - ? Float32Array - : T extends "mat4" - ? THREE.Matrix4 | Float32Array - : T extends "mat4x2" - ? Float32Array - : T extends "mat4x3" - ? Float32Array - : T extends "mat4x4" - ? - | THREE.Matrix4 - | Float32Array - : T extends "usampler2D" - ? THREE.Texture - : T extends "isampler2D" - ? THREE.Texture - : T extends "sampler2D" - ? THREE.Texture - : T extends "sampler2DShadow" - ? THREE.Texture - : T extends "usampler2DArray" - ? THREE.DataArrayTexture - : T extends "isampler2DArray" - ? THREE.DataArrayTexture - : T extends "sampler2DArray" - ? THREE.DataArrayTexture - : T extends "sampler2DArrayShadow" - ? THREE.Texture - : T extends "usampler3D" - ? THREE.DataArrayTexture - : T extends "isampler3D" - ? THREE.DataArrayTexture - : T extends "sampler3D" - ? THREE.DataArrayTexture - : T extends "usamplerCube" - ? THREE.DataArrayTexture - : T extends "isamplerCube" - ? THREE.DataArrayTexture - : T extends "samplerCube" - ? THREE.DataArrayTexture - : T extends "samplerCubeShadow" - ? THREE.Texture - : unknown; - -export function typeLiteral(type: DynoType): string { - if (typeof type === "string") { - return type; - } - if (typeof type === "object" && type.type) { - return type.type; - } - throw new Error(`Invalid DynoType: ${String(type)}`); -} - -export function numberAsInt(value: number): string { - return Math.trunc(value).toString(); -} - -export function numberAsUint(value: number): string { - const v = Math.max(0, Math.trunc(value)); - return `${v.toString()}u`; -} - -export function numberAsFloat(value: number): string { - return value === Number.POSITIVE_INFINITY - ? "INFINITY" - : value === Number.NEGATIVE_INFINITY - ? "-INFINITY" - : Number.isInteger(value) - ? value.toFixed(1) - : value.toString(); -} diff --git a/src/dyno/uniforms.ts b/src/dyno/uniforms.ts deleted file mode 100644 index 136591f..0000000 --- a/src/dyno/uniforms.ts +++ /dev/null @@ -1,826 +0,0 @@ -import type { IUniform } from "three"; -import { Dyno, dynoDeclare } from "./base"; -import type { DynoJsType, DynoType } from "./types"; -import { DynoOutput, type DynoValue, type HasDynoOut } from "./value"; - -export const uniform = >( - key: string, - type: DynoType, - value: V, -) => new DynoUniform({ key, type, value }); -export const dynoBool = (value = false, key?: string) => - new DynoBool({ key, value }); -export const dynoUint = (value = 0, key?: string) => - new DynoUint({ key, value }); -export const dynoInt = (value = 0, key?: string) => new DynoInt({ key, value }); -export const dynoFloat = (value = 0.0, key?: string) => - new DynoFloat({ key, value }); - -export const dynoBvec2 = >( - value: V, - key?: string, -) => new DynoBvec2({ key, value }); -export const dynoUvec2 = >( - value: V, - key?: string, -) => new DynoUvec2({ key, value }); -export const dynoIvec2 = >( - value: V, - key?: string, -) => new DynoIvec2({ key, value }); -export const dynoVec2 = >( - value: V, - key?: string, -) => new DynoVec2({ key, value }); - -export const dynoBvec3 = >( - value: V, - key?: string, -) => new DynoBvec3({ key, value }); -export const dynoUvec3 = >( - value: V, - key?: string, -) => new DynoUvec3({ key, value }); -export const dynoIvec3 = >( - value: V, - key?: string, -) => new DynoIvec3({ key, value }); -export const dynoVec3 = >( - value: V, - key?: string, -) => new DynoVec3({ key, value }); - -export const dynoBvec4 = >( - value: V, - key?: string, -) => new DynoBvec4({ key, value }); -export const dynoUvec4 = >( - value: V, - key?: string, -) => new DynoUvec4({ key, value }); -export const dynoIvec4 = >( - value: V, - key?: string, -) => new DynoIvec4({ key, value }); -export const dynoVec4 = >( - value: V, - key?: string, -) => new DynoVec4({ key, value }); - -export const dynoMat2 = >( - value: V, - key?: string, -) => new DynoMat2({ key, value }); -export const dynoMat2x2 = >( - value: V, - key?: string, -) => new DynoMat2x2({ key, value }); -export const dynoMat2x3 = >( - value: V, - key?: string, -) => new DynoMat2x3({ key, value }); -export const dynoMat2x4 = >( - value: V, - key?: string, -) => new DynoMat2x4({ key, value }); - -export const dynoMat3 = >( - value: V, - key?: string, -) => new DynoMat3({ key, value }); -export const dynoMat3x2 = >( - value: V, - key?: string, -) => new DynoMat3x2({ key, value }); -export const dynoMat3x3 = >( - value: V, - key?: string, -) => new DynoMat3x3({ key, value }); -export const dynoMat3x4 = >( - value: V, - key?: string, -) => new DynoMat3x4({ key, value }); - -export const dynoMat4 = >( - value: V, - key?: string, -) => new DynoMat4({ key, value }); -export const dynoMat4x2 = >( - value: V, - key?: string, -) => new DynoMat4x2({ key, value }); -export const dynoMat4x3 = >( - value: V, - key?: string, -) => new DynoMat4x3({ key, value }); -export const dynoMat4x4 = >( - value: V, - key?: string, -) => new DynoMat4x4({ key, value }); - -export const dynoUsampler2D = >( - value: V, - key?: string, -) => new DynoUsampler2D({ key, value }); -export const dynoIsampler2D = >( - value: V, - key?: string, -) => new DynoIsampler2D({ key, value }); -export const dynoSampler2D = >( - value: V, - key?: string, -) => new DynoSampler2D({ key, value }); - -export const dynoUsampler2DArray = >( - value: V, - key?: string, -) => new DynoUsampler2DArray({ key, value }); -export const dynoIsampler2DArray = >( - key: string, - value: V, -) => new DynoIsampler2DArray({ key, value }); -export const dynoSampler2DArray = >( - value: V, - key?: string, -) => new DynoSampler2DArray({ key, value }); - -export const dynoUsampler3D = >( - value: V, - key?: string, -) => new DynoUsampler3D({ key, value }); -export const dynoIsampler3D = >( - value: V, - key?: string, -) => new DynoIsampler3D({ key, value }); -export const dynoSampler3D = >( - value: V, - key?: string, -) => new DynoSampler3D({ key, value }); - -export const dynoUsamplerCube = >( - value: V, - key?: string, -) => new DynoUsamplerCube({ key, value }); -export const dynoIsamplerCube = >( - value: V, - key?: string, -) => new DynoIsamplerCube({ key, value }); -export const dynoSamplerCube = >( - value: V, - key?: string, -) => new DynoSamplerCube({ key, value }); - -export const dynoSampler2DShadow = >( - value: V, - key?: string, -) => new DynoSampler2DShadow({ key, value }); -export const dynoSampler2DArrayShadow = < - V extends DynoJsType<"sampler2DArrayShadow">, ->( - value: V, - key?: string, -) => new DynoSampler2DArrayShadow({ key, value }); -export const dynoSamplerCubeShadow = < - V extends DynoJsType<"samplerCubeShadow">, ->( - value: V, - key?: string, -) => new DynoSamplerCubeShadow({ key, value }); - -export class DynoUniform< - T extends DynoType, - K extends string = "value", - V extends DynoJsType = DynoJsType, - > - extends Dyno, { [key in K]: T }> - implements HasDynoOut -{ - public type: T; - public count?: number; - public outKey: K; - public value: V; - public uniform: { value: V; type?: string }; - - constructor({ - key, - type, - count, - value, - update, - globals, - }: { - key?: K; - type: T; - count?: number; - value: V; - update?: (value: V) => V | undefined; - globals?: ({ - inputs, - outputs, - }: { inputs: unknown; outputs: { [key in K]?: string } }) => string[]; - }) { - key = (key ?? "value") as K; - super({ - outTypes: { [key]: type } as { [key in K]: T }, - update: () => { - if (update) { - const value = update(this.value); - if (value !== undefined) { - this.value = value; - } - } - this.uniform.value = this.value; - }, - generate: ({ inputs, outputs }) => { - const allGlobals = globals?.({ inputs, outputs }) ?? []; - const uniforms: Record = {}; - const name = outputs[key]; - if (name) { - allGlobals.push(`uniform ${dynoDeclare(name, type, count)};`); - uniforms[name] = this.uniform; - } - return { globals: allGlobals, uniforms }; - }, - }); - this.type = type; - this.count = count; - this.value = value; - this.uniform = { value }; - this.outKey = key; - } - - dynoOut(): DynoValue { - return new DynoOutput(this, this.outKey); - } -} - -export class DynoBool extends DynoUniform< - "bool", - K, - boolean -> { - constructor({ - key, - value, - update, - }: { - key?: K; - value: boolean; - update?: (value: boolean) => boolean | undefined; - }) { - super({ key, type: "bool", value, update }); - } -} - -export class DynoUint extends DynoUniform<"uint", K, number> { - constructor({ - key, - value, - update, - }: { - key?: K; - value: number; - update?: (value: number) => number | undefined; - }) { - super({ key, type: "uint", value, update }); - } -} - -export class DynoInt extends DynoUniform<"int", K, number> { - constructor({ - key, - value, - update, - }: { - key?: K; - value: number; - update?: (value: number) => number | undefined; - }) { - super({ key, type: "int", value, update }); - } -} - -export class DynoFloat extends DynoUniform< - "float", - K, - number -> { - constructor({ - key, - value, - update, - }: { - key?: K; - value: number; - update?: (value: number) => number | undefined; - }) { - super({ key, type: "float", value, update }); - } -} - -export class DynoBvec2< - K extends string, - V extends DynoJsType<"bvec2">, -> extends DynoUniform<"bvec2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "bvec2", value, update }); - } -} - -export class DynoUvec2< - K extends string, - V extends DynoJsType<"uvec2">, -> extends DynoUniform<"uvec2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "uvec2", value, update }); - } -} - -export class DynoIvec2< - K extends string, - V extends DynoJsType<"ivec2">, -> extends DynoUniform<"ivec2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "ivec2", value, update }); - } -} - -export class DynoVec2< - K extends string, - V extends DynoJsType<"vec2">, -> extends DynoUniform<"vec2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "vec2", value, update }); - } -} - -export class DynoBvec3< - K extends string, - V extends DynoJsType<"bvec3">, -> extends DynoUniform<"bvec3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "bvec3", value, update }); - } -} - -export class DynoUvec3< - V extends DynoJsType<"uvec3">, - K extends string = "value", -> extends DynoUniform<"uvec3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "uvec3", value, update }); - } -} - -export class DynoIvec3< - V extends DynoJsType<"ivec3">, - K extends string = "value", -> extends DynoUniform<"ivec3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "ivec3", value, update }); - } -} - -export class DynoVec3< - V extends DynoJsType<"vec3">, - K extends string = "value", -> extends DynoUniform<"vec3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "vec3", value, update }); - } -} - -export class DynoBvec4< - K extends string, - V extends DynoJsType<"bvec4">, -> extends DynoUniform<"bvec4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "bvec4", value, update }); - } -} - -export class DynoUvec4< - K extends string, - V extends DynoJsType<"uvec4">, -> extends DynoUniform<"uvec4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "uvec4", value, update }); - } -} - -export class DynoIvec4< - K extends string, - V extends DynoJsType<"ivec4">, -> extends DynoUniform<"ivec4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "ivec4", value, update }); - } -} - -export class DynoVec4< - V extends DynoJsType<"vec4">, - K extends string = "value", -> extends DynoUniform<"vec4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "vec4", value, update }); - } -} - -export class DynoMat2< - K extends string, - V extends DynoJsType<"mat2">, -> extends DynoUniform<"mat2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat2", value, update }); - } -} - -export class DynoMat2x2< - K extends string, - V extends DynoJsType<"mat2x2">, -> extends DynoUniform<"mat2x2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat2x2", value, update }); - } -} - -export class DynoMat2x3< - K extends string, - V extends DynoJsType<"mat2x3">, -> extends DynoUniform<"mat2x3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat2x3", value, update }); - } -} - -export class DynoMat2x4< - K extends string, - V extends DynoJsType<"mat2x4">, -> extends DynoUniform<"mat2x4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat2x4", value, update }); - } -} - -export class DynoMat3< - K extends string, - V extends DynoJsType<"mat3">, -> extends DynoUniform<"mat3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat3", value, update }); - } -} - -export class DynoMat3x2< - K extends string, - V extends DynoJsType<"mat3x2">, -> extends DynoUniform<"mat3x2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat3x2", value, update }); - } -} - -export class DynoMat3x3< - K extends string, - V extends DynoJsType<"mat3x3">, -> extends DynoUniform<"mat3x3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat3x3", value, update }); - } -} - -export class DynoMat3x4< - K extends string, - V extends DynoJsType<"mat3x4">, -> extends DynoUniform<"mat3x4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat3x4", value, update }); - } -} - -export class DynoMat4< - K extends string, - V extends DynoJsType<"mat4">, -> extends DynoUniform<"mat4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat4", value, update }); - } -} - -export class DynoMat4x2< - K extends string, - V extends DynoJsType<"mat4x2">, -> extends DynoUniform<"mat4x2", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat4x2", value, update }); - } -} - -export class DynoMat4x3< - K extends string, - V extends DynoJsType<"mat4x3">, -> extends DynoUniform<"mat4x3", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat4x3", value, update }); - } -} - -export class DynoMat4x4< - K extends string, - V extends DynoJsType<"mat4x4">, -> extends DynoUniform<"mat4x4", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "mat4x4", value, update }); - } -} - -export class DynoUsampler2D< - K extends string, - V extends DynoJsType<"usampler2D">, -> extends DynoUniform<"usampler2D", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "usampler2D", value, update }); - } -} - -export class DynoIsampler2D< - K extends string, - V extends DynoJsType<"isampler2D">, -> extends DynoUniform<"isampler2D", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "isampler2D", value, update }); - } -} - -export class DynoSampler2D< - K extends string, - V extends DynoJsType<"sampler2D">, -> extends DynoUniform<"sampler2D", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "sampler2D", value, update }); - } -} - -export class DynoUsampler2DArray< - K extends string, - V extends DynoJsType<"usampler2DArray">, -> extends DynoUniform<"usampler2DArray", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "usampler2DArray", value, update }); - } -} - -export class DynoIsampler2DArray< - K extends string, - V extends DynoJsType<"isampler2DArray">, -> extends DynoUniform<"isampler2DArray", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "isampler2DArray", value, update }); - } -} - -export class DynoSampler2DArray< - K extends string, - V extends DynoJsType<"sampler2DArray">, -> extends DynoUniform<"sampler2DArray", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "sampler2DArray", value, update }); - } -} - -export class DynoUsampler3D< - K extends string, - V extends DynoJsType<"usampler3D">, -> extends DynoUniform<"usampler3D", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "usampler3D", value, update }); - } -} - -export class DynoIsampler3D< - K extends string, - V extends DynoJsType<"isampler3D">, -> extends DynoUniform<"isampler3D", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "isampler3D", value, update }); - } -} - -export class DynoSampler3D< - K extends string, - V extends DynoJsType<"sampler3D">, -> extends DynoUniform<"sampler3D", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "sampler3D", value, update }); - } -} - -export class DynoUsamplerCube< - K extends string, - V extends DynoJsType<"usamplerCube">, -> extends DynoUniform<"usamplerCube", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "usamplerCube", value, update }); - } -} - -export class DynoIsamplerCube< - K extends string, - V extends DynoJsType<"isamplerCube">, -> extends DynoUniform<"isamplerCube", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "isamplerCube", value, update }); - } -} - -export class DynoSamplerCube< - K extends string, - V extends DynoJsType<"samplerCube">, -> extends DynoUniform<"samplerCube", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "samplerCube", value, update }); - } -} - -export class DynoSampler2DShadow< - K extends string, - V extends DynoJsType<"sampler2DShadow">, -> extends DynoUniform<"sampler2DShadow", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "sampler2DShadow", value, update }); - } -} - -export class DynoSampler2DArrayShadow< - K extends string, - V extends DynoJsType<"sampler2DArrayShadow">, -> extends DynoUniform<"sampler2DArrayShadow", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "sampler2DArrayShadow", value, update }); - } -} - -export class DynoSamplerCubeShadow< - K extends string, - V extends DynoJsType<"samplerCubeShadow">, -> extends DynoUniform<"samplerCubeShadow", K, V> { - constructor({ - key, - value, - update, - }: { key?: K; value: V; update?: (value: V) => V | undefined }) { - super({ key, type: "samplerCubeShadow", value, update }); - } -} diff --git a/src/dyno/util.ts b/src/dyno/util.ts deleted file mode 100644 index df1c778..0000000 --- a/src/dyno/util.ts +++ /dev/null @@ -1,441 +0,0 @@ -import { Dyno, DynoBlock, unindent } from "./base"; -import { float, vec2, vec3, vec4 } from "./convert"; -import { mul } from "./math"; -import { type ValueTypes, isIntType, isUintType, sameSizeUvec } from "./types"; -import { - DynoOutput, - type DynoVal, - type DynoValue, - type HasDynoOut, - dynoConst, - valType, -} from "./value"; -import { combine } from "./vecmat"; - -export const remapIndex = ( - index: DynoVal<"int">, - from: DynoVal<"int">, - to: DynoVal<"int">, -): DynoVal<"int"> => { - return new DynoRemapIndex({ index, from, to }); -}; -export const pcgMix = ( - value: DynoVal, -): DynoVal<"uint"> => { - return new PcgMix({ value }); -}; -export const pcgNext = (state: DynoVal<"uint">): DynoVal<"uint"> => { - return new PcgNext({ state }); -}; -export const pcgHash = (state: DynoVal<"uint">): DynoVal<"uint"> => { - return new PcgHash({ state }); -}; -export const hash = ( - value: DynoVal, -): DynoVal<"uint"> => { - return new Hash({ value }); -}; -export const hash2 = ( - value: DynoVal, -): DynoVal<"uvec2"> => { - return new Hash2({ value }); -}; -export const hash3 = ( - value: DynoVal, -): DynoVal<"uvec3"> => { - return new Hash3({ value }); -}; -export const hash4 = ( - value: DynoVal, -): DynoVal<"uvec4"> => { - return new Hash4({ value }); -}; -export const hashFloat = ( - value: DynoVal, -): DynoVal<"float"> => { - return new HashFloat({ value }); -}; -export const hashVec2 = ( - value: DynoVal, -): DynoVal<"vec2"> => { - return new HashVec2({ value }); -}; -export const hashVec3 = ( - value: DynoVal, -): DynoVal<"vec3"> => { - return new HashVec3({ value }); -}; -export const hashVec4 = ( - value: DynoVal, -): DynoVal<"vec4"> => { - return new HashVec4({ value }); -}; -export const normalizedDepth = ( - z: DynoVal<"float">, - zNear: DynoVal<"float">, - zFar: DynoVal<"float">, -): DynoVal<"float"> => { - return new NormalizedDepth({ z, zNear, zFar }).outputs.depth; -}; - -export class DynoRemapIndex - extends Dyno<{ from: "int"; to: "int"; index: "int" }, { index: "int" }> - implements HasDynoOut<"int"> -{ - constructor({ - from, - to, - index, - }: { from: DynoVal<"int">; to: DynoVal<"int">; index: DynoVal<"int"> }) { - super({ - inTypes: { from: "int", to: "int", index: "int" }, - outTypes: { index: "int" }, - inputs: { from, to, index }, - statements: ({ inputs, outputs }) => { - return [ - `${outputs.index} = ${inputs.index} - ${inputs.from} + ${inputs.to};`, - ]; - }, - }); - } - - dynoOut(): DynoValue<"int"> { - return new DynoOutput(this, "index"); - } -} - -export class PcgNext - extends Dyno<{ state: T }, { state: "uint" }> - implements HasDynoOut<"uint"> -{ - constructor({ state }: { state: DynoVal }) { - const type = valType(state); - super({ - inTypes: { state: type }, - outTypes: { state: "uint" }, - inputs: { state }, - globals: () => [ - unindent(` - uint pcg_next(uint state) { - return state * 747796405u + 2891336453u; - } - `), - ], - statements: ({ inputs, outputs }) => { - const toUint = - type === "uint" - ? `${inputs.state}` - : type === "int" - ? `uint(${inputs.state})` - : `floatBitsToUint(${inputs.state})`; - return [`${outputs.state} = pcg_next(${toUint});`]; - }, - }); - } - dynoOut(): DynoValue<"uint"> { - return new DynoOutput(this, "state"); - } -} - -export class PcgHash - extends Dyno<{ state: "uint" }, { hash: "uint" }> - implements HasDynoOut<"uint"> -{ - constructor({ state }: { state: DynoVal<"uint"> }) { - super({ - inTypes: { state: "uint" }, - outTypes: { hash: "uint" }, - inputs: { state }, - globals: () => [ - unindent(` - uint pcg_hash(uint state) { - uint hash = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; - return (hash >> 22u) ^ hash; - } - `), - ], - statements: ({ inputs, outputs }) => [ - `${outputs.hash} = pcg_hash(${inputs.state});`, - ], - }); - } - dynoOut(): DynoValue<"uint"> { - return new DynoOutput(this, "hash"); - } -} - -export class PcgMix - extends Dyno<{ value: T }, { state: "uint" }> - implements HasDynoOut<"uint"> -{ - constructor({ value }: { value: DynoVal }) { - const type = valType(value); - const tempType = sameSizeUvec(type); - super({ - inTypes: { value: type }, - outTypes: { state: "uint" }, - inputs: { value }, - globals: () => [ - unindent(` - uint pcg_mix(uint value) { - return value; - } - uint pcg_mix(uvec2 value) { - return value.x + 0x9e3779b9u * value.y; - } - uint pcg_mix(uvec3 value) { - return value.x + 0x9e3779b9u * value.y + 0x85ebca6bu * value.z; - } - uint pcg_mix(uvec4 value) { - return value.x + 0x9e3779b9u * value.y + 0x85ebca6bu * value.z + 0xc2b2ae35u * value.w; - } - `), - ], - statements: ({ inputs, outputs }) => { - const toUvec = isUintType(type) - ? `${inputs.value}` - : isIntType(type) - ? `${tempType}(${inputs.value})` - : `floatBitsToUint(${inputs.value})`; - return [ - `${tempType} bits = ${toUvec};`, - `${outputs.state} = pcg_mix(bits);`, - ]; - }, - }); - } - dynoOut(): DynoValue<"uint"> { - return new DynoOutput(this, "state"); - } -} - -export class Hash - extends DynoBlock<{ value: T }, { hash: "uint" }> - implements HasDynoOut<"uint"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "uint" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - let state = new PcgMix({ value: value }).outputs.state; - state = new PcgNext({ state }).outputs.state; - return new PcgHash({ state }).outputs; - }, - }); - } - dynoOut(): DynoValue<"uint"> { - return new DynoOutput(this, "hash"); - } -} - -export class Hash2 - extends DynoBlock<{ value: T }, { hash: "uvec2" }> - implements HasDynoOut<"uvec2"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "uvec2" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - let state = new PcgMix({ value: value }).outputs.state; - state = new PcgNext({ state }).outputs.state; - const x = new PcgHash({ state }).outputs.hash; - state = new PcgNext({ state }).outputs.state; - const y = new PcgHash({ state }).outputs.hash; - return { hash: combine({ vectorType: "uvec2", x, y }) }; - }, - }); - } - dynoOut(): DynoValue<"uvec2"> { - return new DynoOutput(this, "hash"); - } -} - -export class Hash3 - extends DynoBlock<{ value: T }, { hash: "uvec3" }> - implements HasDynoOut<"uvec3"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "uvec3" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - let state = new PcgMix({ value: value }).outputs.state; - state = new PcgNext({ state }).outputs.state; - const x = new PcgHash({ state }).outputs.hash; - state = new PcgNext({ state }).outputs.state; - const y = new PcgHash({ state }).outputs.hash; - state = new PcgNext({ state }).outputs.state; - const z = new PcgHash({ state }).outputs.hash; - return { hash: combine({ vectorType: "uvec3", x, y, z }) }; - }, - }); - } - dynoOut(): DynoValue<"uvec3"> { - return new DynoOutput(this, "hash"); - } -} - -export class Hash4 - extends DynoBlock<{ value: T }, { hash: "uvec4" }> - implements HasDynoOut<"uvec4"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "uvec4" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - let state = new PcgMix({ value: value }).outputs.state; - state = new PcgNext({ state }).outputs.state; - const x = new PcgHash({ state }).outputs.hash; - state = new PcgNext({ state }).outputs.state; - const y = new PcgHash({ state }).outputs.hash; - state = new PcgNext({ state }).outputs.state; - const z = new PcgHash({ state }).outputs.hash; - state = new PcgNext({ state }).outputs.state; - const w = new PcgHash({ state }).outputs.hash; - return { hash: combine({ vectorType: "uvec4", x, y, z, w }) }; - }, - }); - } - dynoOut(): DynoValue<"uvec4"> { - return new DynoOutput(this, "hash"); - } -} - -export class HashFloat - extends DynoBlock<{ value: T }, { hash: "float" }> - implements HasDynoOut<"float"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "float" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - const word = hash(value); - return { hash: mul(float(word), dynoConst("float", 1 / 2 ** 32)) }; - }, - }); - } - dynoOut(): DynoValue<"float"> { - return new DynoOutput(this, "hash"); - } -} - -export class HashVec2 - extends DynoBlock<{ value: T }, { hash: "vec2" }> - implements HasDynoOut<"vec2"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "vec2" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - const words = hash2(value); - return { hash: mul(vec2(words), dynoConst("float", 1 / 2 ** 32)) }; - }, - }); - } - dynoOut(): DynoValue<"vec2"> { - return new DynoOutput(this, "hash"); - } -} - -export class HashVec3 - extends DynoBlock<{ value: T }, { hash: "vec3" }> - implements HasDynoOut<"vec3"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "vec3" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - const words = hash3(value); - return { hash: mul(vec3(words), dynoConst("float", 1 / 2 ** 32)) }; - }, - }); - } - dynoOut(): DynoValue<"vec3"> { - return new DynoOutput(this, "hash"); - } -} - -export class HashVec4 - extends DynoBlock<{ value: T }, { hash: "vec4" }> - implements HasDynoOut<"vec4"> -{ - constructor({ value }: { value: DynoVal }) { - super({ - inTypes: { value: valType(value) }, - outTypes: { hash: "vec4" }, - inputs: { value }, - construct: ({ value }) => { - if (!value) { - throw new Error("value is required"); - } - const words = hash4(value); - return { hash: mul(vec4(words), dynoConst("float", 1 / 2 ** 32)) }; - }, - }); - } - dynoOut(): DynoValue<"vec4"> { - return new DynoOutput(this, "hash"); - } -} - -export class NormalizedDepth - extends Dyno< - { z: "float"; zNear: "float"; zFar: "float" }, - { depth: "float" } - > - implements HasDynoOut<"float"> -{ - constructor({ - z, - zNear, - zFar, - }: { z: DynoVal<"float">; zNear: DynoVal<"float">; zFar: DynoVal<"float"> }) { - super({ - inTypes: { z: "float", zNear: "float", zFar: "float" }, - outTypes: { depth: "float" }, - inputs: { z, zNear, zFar }, - statements: ({ inputs, outputs }) => [ - `float clamped = clamp(${inputs.z}, ${inputs.zNear}, ${inputs.zFar});`, - `${outputs.depth} = (log2(clamped + 1.0) - log2(${inputs.zNear} + 1.0)) / (log2(${inputs.zFar} + 1.0) - log2(${inputs.zNear} + 1.0));`, - ], - }); - } - - dynoOut(): DynoValue<"float"> { - return new DynoOutput(this, "depth"); - } -} diff --git a/src/dyno/value.ts b/src/dyno/value.ts deleted file mode 100644 index 88227bb..0000000 --- a/src/dyno/value.ts +++ /dev/null @@ -1,289 +0,0 @@ -import * as THREE from "three"; - -import type { Dyno, IOTypes } from "./base"; -import { - type DynoJsType, - type DynoType, - type SimpleTypes, - isAllFloatType, - isBoolType, - isIntType, - isUintType, - numberAsFloat, - numberAsInt, - numberAsUint, -} from "./types"; - -export type DynoVal = DynoValue | HasDynoOut; - -export function valType(val: DynoVal): T { - if (val instanceof DynoValue) { - return val.type; - } - const value = val.dynoOut(); - return value.type; -} - -export interface HasDynoOut { - dynoOut(): DynoValue; -} - -export class DynoValue { - type: T; - // This field prevents TypeScript structural matching on objects with a "type" field - private __isDynoValue = true; - - constructor(type: T) { - this.type = type; - } -} - -export class DynoOutput< - T extends DynoType, - InTypes extends IOTypes, - OutTypes extends IOTypes, -> extends DynoValue { - dyno: Dyno; - key: string; - - constructor(dyno: Dyno, key: string) { - super(dyno.outTypes[key] as T); - this.dyno = dyno; - this.key = key; - } -} - -export class DynoLiteral extends DynoValue { - literal: string; - - constructor(type: T, literal: string) { - super(type); - this.literal = literal; - } - - getLiteral(): string { - return this.literal; - } -} - -export function dynoLiteral( - type: T, - literal: string, -): DynoLiteral { - return new DynoLiteral(type, literal); -} - -export class DynoConst extends DynoLiteral { - value: DynoJsType; - - constructor(type: T, value: DynoJsType) { - super(type, ""); - this.value = value; - } - - getLiteral(): string { - const { type, value } = this; - switch (type) { - case "bool": - return value ? "true" : "false"; - case "uint": - return numberAsUint(value as number); - case "int": - return numberAsInt(value as number); - case "float": - return numberAsFloat(value as number); - case "bvec2": { - const v = value as [boolean, boolean]; - return `bvec2(${v[0]}, ${v[1]})`; - } - case "uvec2": { - if (value instanceof THREE.Vector2) { - return `uvec2(${numberAsUint(value.x)}, ${numberAsUint(value.y)})`; - } - const v = value as [number, number] | Uint32Array; - return `uvec2(${numberAsUint(v[0])}, ${numberAsUint(v[1])})`; - } - case "ivec2": { - if (value instanceof THREE.Vector2) { - return `ivec2(${numberAsInt(value.x)}, ${numberAsInt(value.y)})`; - } - const v = value as [number, number] | Int32Array; - return `ivec2(${numberAsInt(v[0])}, ${numberAsInt(v[1])})`; - } - case "vec2": { - if (value instanceof THREE.Vector2) { - return `vec2(${numberAsFloat(value.x)}, ${numberAsFloat(value.y)})`; - } - const v = value as [number, number] | Float32Array; - return `vec2(${numberAsFloat(v[0])}, ${numberAsFloat(v[1])})`; - } - case "bvec3": { - const v = value as [boolean, boolean, boolean]; - return `bvec3(${v[0]}, ${v[1]}, ${v[2]})`; - } - case "uvec3": { - if (value instanceof THREE.Vector3) { - return `uvec3(${numberAsUint(value.x)}, ${numberAsUint(value.y)}, ${numberAsUint(value.z)})`; - } - const v = value as [number, number, number] | Uint32Array; - return `uvec3(${numberAsUint(v[0])}, ${numberAsUint(v[1])}, ${numberAsUint(v[2])})`; - } - case "ivec3": { - if (value instanceof THREE.Vector3) { - return `ivec3(${numberAsInt(value.x)}, ${numberAsInt(value.y)}, ${numberAsInt(value.z)})`; - } - const v = value as [number, number, number] | Int32Array; - return `ivec3(${numberAsInt(v[0])}, ${numberAsInt(v[1])}, ${numberAsInt(v[2])})`; - } - case "vec3": { - if (value instanceof THREE.Vector3) { - return `vec3(${numberAsFloat(value.x)}, ${numberAsFloat(value.y)}, ${numberAsFloat(value.z)})`; - } - const v = value as [number, number, number] | Float32Array; - return `vec3(${numberAsFloat(v[0])}, ${numberAsFloat(v[1])}, ${numberAsFloat(v[2])})`; - } - case "bvec4": { - const v = value as [boolean, boolean, boolean, boolean]; - return `bvec4(${v[0]}, ${v[1]}, ${v[2]}, ${v[3]})`; - } - case "uvec4": { - if (value instanceof THREE.Vector4) { - return `uvec4(${numberAsUint(value.x)}, ${numberAsUint(value.y)}, ${numberAsUint(value.z)}, ${numberAsUint(value.w)})`; - } - const v = value as [number, number, number, number] | Uint32Array; - return `uvec4(${numberAsUint(v[0])}, ${numberAsUint(v[1])}, ${numberAsUint(v[2])}, ${numberAsUint(v[3])})`; - } - case "ivec4": { - if (value instanceof THREE.Vector4) { - return `ivec4(${numberAsInt(value.x)}, ${numberAsInt(value.y)}, ${numberAsInt(value.z)}, ${numberAsInt(value.w)})`; - } - const v = value as [number, number, number, number] | Int32Array; - return `ivec4(${numberAsInt(v[0])}, ${numberAsInt(v[1])}, ${numberAsInt(v[2])}, ${numberAsInt(v[3])})`; - } - case "vec4": { - if (value instanceof THREE.Vector4) { - return `vec4(${numberAsFloat(value.x)}, ${numberAsFloat(value.y)}, ${numberAsFloat(value.z)}, ${numberAsFloat(value.w)})`; - } - if (value instanceof THREE.Quaternion) { - return `vec4(${numberAsFloat(value.x)}, ${numberAsFloat(value.y)}, ${numberAsFloat(value.z)}, ${numberAsFloat(value.w)})`; - } - const v = value as [number, number, number, number] | Float32Array; - return `vec4(${numberAsFloat(v[0])}, ${numberAsFloat(v[1])}, ${numberAsFloat(v[2])}, ${numberAsFloat(v[3])})`; - } - case "mat2": - case "mat2x2": { - const m = value as DynoJsType<"mat2">; - const e = - m instanceof THREE.Matrix2 ? m.elements : (value as Float32Array); - const arg = new Array(4).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat2x3": { - const e = value as DynoJsType<"mat2x3">; - const arg = new Array(6).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat2x4": { - const e = value as DynoJsType<"mat2x4">; - const arg = new Array(8).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat3": - case "mat3x3": { - const m = value as DynoJsType<"mat3">; - const e = - m instanceof THREE.Matrix3 ? m.elements : (value as Float32Array); - const arg = new Array(9).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat3x2": { - const e = value as DynoJsType<"mat3x2">; - const arg = new Array(6).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat3x4": { - const e = value as DynoJsType<"mat3x4">; - const arg = new Array(12).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat4": - case "mat4x4": { - const m = value as DynoJsType<"mat4">; - const e = - m instanceof THREE.Matrix4 ? m.elements : (value as Float32Array); - const arg = new Array(16).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat4x2": { - const e = value as DynoJsType<"mat4x2">; - const arg = new Array(8).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - case "mat4x3": { - const e = value as DynoJsType<"mat4x3">; - const arg = new Array(12).fill(0).map((_, i) => numberAsFloat(e[i])); - return `${type as string}(${arg.join(", ")})`; - } - default: - throw new Error(`Type not implemented: ${String(type)}`); - } - } -} - -export function dynoConst( - type: T, - value: DynoJsType, -): DynoConst { - return new DynoConst(type, value); -} - -export function literalZero(type: SimpleTypes): string { - const typeString = String(type); - if (isBoolType(type)) { - return `${typeString}(false)`; - } - if (isAllFloatType(type)) { - return `${typeString}(0.0)`; - } - if (isIntType(type)) { - return `${typeString}(0)`; - } - if (isUintType(type)) { - return `${typeString}(0u)`; - } - throw new Error(`Type not implemented: ${typeString}`); -} - -export function literalOne(type: SimpleTypes): string { - const typeString = String(type); - if (isBoolType(type)) { - return `${typeString}(true)`; - } - if (isAllFloatType(type)) { - return `${typeString}(1.0)`; - } - if (isIntType(type)) { - return `${typeString}(1)`; - } - if (isUintType(type)) { - return `${typeString}(1u)`; - } - throw new Error(`Type not implemented: ${typeString}`); -} - -export function literalNegOne(type: SimpleTypes): string { - const typeString = String(type); - if (isBoolType(type)) { - return `${typeString}(true)`; - } - if (isAllFloatType(type)) { - return `${typeString}(-1.0)`; - } - if (isIntType(type)) { - return `${typeString}(-1)`; - } - if (isUintType(type)) { - return `${typeString}(0xFFFFFFFFu)`; - } - throw new Error(`Type not implemented: ${typeString}`); -} diff --git a/src/dyno/vecmat.ts b/src/dyno/vecmat.ts deleted file mode 100644 index 56e831f..0000000 --- a/src/dyno/vecmat.ts +++ /dev/null @@ -1,835 +0,0 @@ -import { BinaryOp, Dyno, TrinaryOp, UnaryOp } from "./base"; -import { - type FloatTypes, - type IntTypes, - type MatFloatTypes, - type SquareMatTypes, - type UintTypes, - type VectorElementType, - type VectorTypes, - isFloatType, - isIntType, - isUintType, - vectorDim, - vectorElementType, -} from "./types"; -import { - DynoOutput, - type DynoVal, - type DynoValue, - type HasDynoOut, - literalZero, - valType, -} from "./value"; - -export const length = ( - a: DynoVal, -): DynoVal<"float"> => new Length({ a }); -export const distance = ( - a: DynoVal, - b: DynoVal, -): DynoVal<"float"> => new Distance({ a, b }); -export const dot = ( - a: DynoVal, - b: DynoVal, -): DynoVal<"float"> => new Dot({ a, b }); -export const cross = ( - a: DynoVal<"vec3">, - b: DynoVal<"vec3">, -): DynoVal<"vec3"> => new Cross({ a, b }); -export const normalize = ( - a: DynoVal, -): DynoVal => new Normalize({ a }); -export const faceforward = ( - a: DynoVal, - b: DynoVal, - c: DynoVal, -): DynoVal => new FaceForward({ a, b, c }); -export const reflectVec = ( - incident: DynoVal, - normal: DynoVal, -): DynoVal => new ReflectVec({ incident, normal }); -export const refractVec = ( - incident: DynoVal, - normal: DynoVal, - eta: DynoVal<"float">, -): DynoVal => new RefractVec({ incident, normal, eta }); -export const split = (vector: DynoVal): Split => - new Split({ vector }); -export const combine = >({ - vector, - vectorType, - x, - y, - z, - w, - r, - g, - b, - a, -}: { - vector?: DynoVal; - vectorType?: V; - x?: DynoVal; - y?: DynoVal; - z?: DynoVal; - w?: DynoVal; - r?: DynoVal; - g?: DynoVal; - b?: DynoVal; - a?: DynoVal; -}): DynoVal => new Combine({ vector, vectorType, x, y, z, w, r, g, b, a }); -export const projectH = ( - a: DynoVal, -): DynoVal> => new ProjectH({ a }); -export const extendVec = ( - a: DynoVal, - b: DynoVal<"float">, -): DynoVal> => new ExtendVec({ a, b }); -export const swizzle = ( - a: DynoVal, - select: S, -): DynoVal>> => - new Swizzle({ vector: a, select }); -export const compMult = ( - a: DynoVal, - b: DynoVal, -): DynoVal => new CompMult({ a, b }); -export const outer = < - A extends "vec2" | "vec3" | "vec4", - B extends "vec2" | "vec3" | "vec4", ->( - a: DynoVal, - b: DynoVal, -): DynoVal> => new Outer({ a, b }); -export const transpose = ( - a: DynoVal, -): DynoVal> => new Transpose({ a }); -export const determinant = ( - a: DynoVal, -): DynoVal<"float"> => new Determinant({ a }); -export const inverse = (a: DynoVal): DynoVal => - new Inverse({ a }); - -export class Length extends UnaryOp< - A, - "float", - "length" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outTypeFunc: (aType) => "float", outKey: "length" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.length} = length(${inputs.a});`, - ]; - } -} - -export class Distance extends BinaryOp< - A, - A, - "float", - "distance" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "distance", outTypeFunc: (aType, bType) => "float" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.distance} = distance(${inputs.a}, ${inputs.b});`, - ]; - } -} - -export class Dot extends BinaryOp< - A, - A, - "float", - "dot" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "dot", outTypeFunc: (aType, bType) => "float" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.dot} = dot(${inputs.a}, ${inputs.b});`, - ]; - } -} - -export class Cross extends BinaryOp<"vec3", "vec3", "vec3", "cross"> { - constructor({ a, b }: { a: DynoVal<"vec3">; b: DynoVal<"vec3"> }) { - super({ a, b, outKey: "cross", outTypeFunc: (aType, bType) => "vec3" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.cross} = cross(${inputs.a}, ${inputs.b});`, - ]; - } -} - -export class Normalize extends UnaryOp< - A, - A, - "normalize" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outTypeFunc: (aType) => aType, outKey: "normalize" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.normalize} = normalize(${inputs.a});`, - ]; - } -} - -type ProjectHOutput = A extends "vec3" - ? "vec2" - : A extends "vec4" - ? "vec3" - : never; - -function projectHOutputType( - type: A, -): ProjectHOutput { - if (type === "vec3") { - return "vec2" as ProjectHOutput; - } - if (type === "vec4") { - return "vec3" as ProjectHOutput; - } - throw new Error("Invalid type"); -} - -export class ProjectH extends UnaryOp< - A, - ProjectHOutput, - "projected" -> { - constructor({ a }: { a: DynoVal }) { - super({ - a, - outTypeFunc: (aType) => projectHOutputType(aType), - outKey: "projected", - }); - this.statements = ({ inputs, outputs }) => { - if (this.inTypes.a === "vec3") { - return [`${outputs.projected} = ${inputs.a}.xy / ${inputs.a}.z;`]; - } - if (this.inTypes.a === "vec4") { - return [`${outputs.projected} = ${inputs.a}.xyz / ${inputs.a}.w;`]; - } - throw new Error("Invalid type"); - }; - } -} - -type ExtendVecOutput = A extends "float" - ? "vec2" - : A extends "vec2" - ? "vec3" - : A extends "vec3" - ? "vec4" - : never; - -function extendVecOutputType( - type: A, -): ExtendVecOutput { - if (type === "float") return "vec2" as ExtendVecOutput; - if (type === "vec2") return "vec3" as ExtendVecOutput; - if (type === "vec3") return "vec4" as ExtendVecOutput; - throw new Error("Invalid type"); -} - -export class ExtendVec extends BinaryOp< - A, - "float", - ExtendVecOutput, - "extend" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal<"float"> }) { - const type = valType(a); - const outType = extendVecOutputType(type); - super({ a, b, outKey: "extend", outTypeFunc: () => outType }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.extend} = ${outType}(${inputs.a}, ${inputs.b});`, - ]; - } -} - -export class FaceForward extends TrinaryOp< - A, - A, - A, - A, - "forward" -> { - constructor({ a, b, c }: { a: DynoVal; b: DynoVal; c: DynoVal }) { - super({ - a, - b, - c, - outKey: "forward", - outTypeFunc: (aType, bType, cType) => aType, - }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.forward} = faceforward(${inputs.a}, ${inputs.b}, ${inputs.c});`, - ]; - } -} - -export class ReflectVec extends BinaryOp< - A, - A, - A, - "reflection" -> { - constructor({ - incident, - normal, - }: { incident: DynoVal; normal: DynoVal }) { - super({ - a: incident, - b: normal, - outKey: "reflection", - outTypeFunc: (aType, bType) => aType, - }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.reflection} = reflect(${inputs.a}, ${inputs.b});`, - ]; - } -} - -export class RefractVec extends TrinaryOp< - A, - A, - "float", - A, - "refraction" -> { - constructor({ - incident, - normal, - eta, - }: { incident: DynoVal; normal: DynoVal; eta: DynoVal<"float"> }) { - super({ - a: incident, - b: normal, - c: eta, - outKey: "refraction", - outTypeFunc: (aType, bType, cType) => aType, - }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.refraction} = refract(${inputs.a}, ${inputs.b}, ${inputs.c});`, - ]; - } -} - -export class CompMult extends BinaryOp< - A, - A, - A, - "product" -> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "product", outTypeFunc: (aType, bType) => aType }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.product} = matrixCompMult(${a}, ${b});`, - ]; - } -} - -type OuterOutput< - A extends "vec2" | "vec3" | "vec4", - B extends "vec2" | "vec3" | "vec4", -> = A extends "vec2" - ? B extends "vec2" - ? "mat2" - : B extends "vec3" - ? "mat3x2" - : B extends "vec4" - ? "mat4x2" - : never - : A extends "vec3" - ? B extends "vec2" - ? "mat2x3" - : B extends "vec3" - ? "mat3" - : B extends "vec4" - ? "mat4x3" - : never - : A extends "vec4" - ? B extends "vec2" - ? "mat2x4" - : B extends "vec3" - ? "mat3x4" - : B extends "vec4" - ? "mat4" - : never - : never; - -function outerOutputType< - A extends "vec2" | "vec3" | "vec4", - B extends "vec2" | "vec3" | "vec4", ->(aType: A, bType: B): OuterOutput { - if (aType === "vec2") { - if (bType === "vec2") return "mat2" as OuterOutput; - if (bType === "vec3") return "mat3x2" as OuterOutput; - if (bType === "vec4") return "mat4x2" as OuterOutput; - } - if (aType === "vec3") { - if (bType === "vec2") return "mat2x3" as OuterOutput; - if (bType === "vec3") return "mat3" as OuterOutput; - if (bType === "vec4") return "mat4x3" as OuterOutput; - } - if (aType === "vec4") { - if (bType === "vec2") return "mat2x4" as OuterOutput; - if (bType === "vec3") return "mat3x4" as OuterOutput; - if (bType === "vec4") return "mat4" as OuterOutput; - } - throw new Error(`Invalid outer type: ${aType}, ${bType}`); -} - -export class Outer< - A extends "vec2" | "vec3" | "vec4", - B extends "vec2" | "vec3" | "vec4", -> extends BinaryOp, "outer"> { - constructor({ a, b }: { a: DynoVal; b: DynoVal }) { - super({ a, b, outKey: "outer", outTypeFunc: outerOutputType }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.outer} = outerProduct(${inputs.a}, ${inputs.b});`, - ]; - } -} - -type TransposeOutput = A extends SquareMatTypes - ? A - : A extends "mat2x3" - ? "mat3x2" - : A extends "mat2x4" - ? "mat4x2" - : A extends "mat3x2" - ? "mat2x3" - : A extends "mat3x4" - ? "mat4x3" - : A extends "mat4x2" - ? "mat2x4" - : A extends "mat4x3" - ? "mat3x4" - : never; - -function transposeOutputType( - type: A, -): TransposeOutput { - if (type === "mat2") return "mat2" as TransposeOutput; - if (type === "mat3") return "mat3" as TransposeOutput; - if (type === "mat4") return "mat4" as TransposeOutput; - if (type === "mat2x2") return "mat2x2" as TransposeOutput; - if (type === "mat2x3") return "mat3x2" as TransposeOutput; - if (type === "mat2x4") return "mat4x2" as TransposeOutput; - if (type === "mat3x2") return "mat2x3" as TransposeOutput; - if (type === "mat3x3") return "mat3x3" as TransposeOutput; - if (type === "mat3x4") return "mat4x3" as TransposeOutput; - if (type === "mat4x2") return "mat2x4" as TransposeOutput; - if (type === "mat4x3") return "mat3x4" as TransposeOutput; - if (type === "mat4x4") return "mat4x4" as TransposeOutput; - throw new Error(`Invalid transpose type: ${type}`); -} - -export class Transpose extends UnaryOp< - A, - TransposeOutput, - "transpose" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "transpose", outTypeFunc: transposeOutputType }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.transpose} = transpose(${inputs.a});`, - ]; - } -} - -export class Determinant extends UnaryOp< - A, - "float", - "det" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "det", outTypeFunc: (aType) => "float" }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.det} = determinant(${inputs.a});`, - ]; - } -} - -export class Inverse extends UnaryOp< - A, - A, - "inverse" -> { - constructor({ a }: { a: DynoVal }) { - super({ a, outKey: "inverse", outTypeFunc: (aType) => aType }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.inverse} = inverse(${a});`, - ]; - } -} - -type SplitOutTypes = A extends "vec2" - ? { x: "float"; y: "float"; r: "float"; g: "float" } - : A extends "vec3" - ? { x: "float"; y: "float"; z: "float"; r: "float"; g: "float"; b: "float" } - : A extends "vec4" - ? { - x: "float"; - y: "float"; - z: "float"; - w: "float"; - r: "float"; - g: "float"; - b: "float"; - a: "float"; - } - : A extends "ivec2" - ? { x: "int"; y: "int"; r: "int"; g: "int" } - : A extends "ivec3" - ? { x: "int"; y: "int"; z: "int"; r: "int"; g: "int"; b: "int" } - : A extends "ivec4" - ? { - x: "int"; - y: "int"; - z: "int"; - w: "int"; - r: "int"; - g: "int"; - b: "int"; - a: "int"; - } - : A extends "uvec2" - ? { x: "uint"; y: "uint"; r: "uint"; g: "uint" } - : A extends "uvec3" - ? { - x: "uint"; - y: "uint"; - z: "uint"; - r: "uint"; - g: "uint"; - b: "uint"; - } - : A extends "uvec4" - ? { - x: "uint"; - y: "uint"; - z: "uint"; - w: "uint"; - r: "uint"; - g: "uint"; - b: "uint"; - a: "uint"; - } - : never; - -function splitOutTypes(type: A): SplitOutTypes { - const result = (value: unknown) => value as SplitOutTypes; - switch (type) { - case "vec2": - return result({ x: "float", y: "float", r: "float", g: "float" }); - case "vec3": - return result({ - x: "float", - y: "float", - z: "float", - r: "float", - g: "float", - b: "float", - }); - case "vec4": - return result({ - x: "float", - y: "float", - z: "float", - w: "float", - r: "float", - g: "float", - b: "float", - a: "float", - }); - case "ivec2": - return result({ x: "int", y: "int", r: "int", g: "int" }); - case "ivec3": - return result({ - x: "int", - y: "int", - z: "int", - r: "int", - g: "int", - b: "int", - }); - case "ivec4": - return result({ - x: "int", - y: "int", - z: "int", - w: "int", - r: "int", - g: "int", - b: "int", - a: "int", - }); - case "uvec2": - return result({ x: "uint", y: "uint", r: "uint", g: "uint" }); - case "uvec3": - return result({ - x: "uint", - y: "uint", - z: "uint", - r: "uint", - g: "uint", - b: "uint", - }); - case "uvec4": - return result({ - x: "uint", - y: "uint", - z: "uint", - w: "uint", - r: "uint", - g: "uint", - b: "uint", - a: "uint", - }); - default: - throw new Error(`Invalid vector type: ${type}`); - } -} - -export class Split extends Dyno< - { vector: V }, - SplitOutTypes -> { - constructor({ vector }: { vector: DynoVal }) { - const type = valType(vector); - const inTypes = { vector: type }; - const outTypes = splitOutTypes(inTypes.vector); - super({ inTypes, outTypes, inputs: { vector } }); - this.statements = ({ inputs, outputs }) => { - const { x, y, z, w, r, g, b, a } = outputs as unknown as Record< - string, - string - >; - const { vector } = inputs; - return [ - x ? `${x} = ${vector}.x;` : null, - y ? `${y} = ${vector}.y;` : null, - z ? `${z} = ${vector}.z;` : null, - w ? `${w} = ${vector}.w;` : null, - r ? `${r} = ${vector}.r;` : null, - g ? `${g} = ${vector}.g;` : null, - b ? `${b} = ${vector}.b;` : null, - a ? `${a} = ${vector}.a;` : null, - ].filter(Boolean) as string[]; - }; - } -} - -export class Combine> - extends Dyno & { vector: V }, { vector: V }> - implements HasDynoOut -{ - constructor({ - vector, - vectorType, - x, - y, - z, - w, - r, - g, - b, - a, - }: { - vector?: DynoVal; - vectorType?: V; - x?: DynoVal; - y?: DynoVal; - z?: DynoVal; - w?: DynoVal; - r?: DynoVal; - g?: DynoVal; - b?: DynoVal; - a?: DynoVal; - }) { - if (!vector && !vectorType) { - throw new Error("Either vector or vectorType must be provided"); - } - const vType = vectorType ?? valType(vector as DynoVal); - const elType = vectorElementType(vType); - const dim = vectorDim(vType); - - const inTypes = { - vector: vType, - x: elType, - y: elType, - r: elType, - g: elType, - } as unknown as SplitOutTypes & { vector: V }; - const inputs = { vector, x, y, r, g }; - if (dim >= 3) { - Object.assign(inTypes, { z: elType, b: elType }); - Object.assign(inputs, { z, b }); - } - if (dim >= 4) { - Object.assign(inTypes, { w: elType, a: elType }); - Object.assign(inputs, { w, a }); - } - // @ts-ignore - super({ inTypes, outTypes: { vector: vType }, inputs }); - this.statements = ({ inputs, outputs }) => { - const { vector } = outputs; - const { - vector: input, - x, - y, - z, - w, - r, - g, - b, - a, - } = inputs as Record; - const statements = [ - `${vector}.x = ${x ?? r ?? (input ? `${input}.x` : literalZero(elType))};`, - `${vector}.y = ${y ?? g ?? (input ? `${input}.y` : literalZero(elType))};`, - ]; - if (dim >= 3) - statements.push( - `${vector}.z = ${z ?? b ?? (input ? `${input}.z` : literalZero(elType))};`, - ); - if (dim >= 4) - statements.push( - `${vector}.w = ${w ?? a ?? (input ? `${input}.w` : literalZero(elType))};`, - ); - return statements; - }; - } - - dynoOut(): DynoValue { - return new DynoOutput & { vector: V }, { vector: V }>( - this, - "vector", - ); - } -} - -type SwizzleOutput< - A extends VectorTypes, - Len extends number, -> = A extends FloatTypes - ? Len extends 1 - ? "float" - : Len extends 2 - ? "vec2" - : Len extends 3 - ? "vec3" - : Len extends 4 - ? "vec4" - : never - : A extends IntTypes - ? Len extends 1 - ? "int" - : Len extends 2 - ? "ivec2" - : Len extends 3 - ? "ivec3" - : Len extends 4 - ? "ivec4" - : never - : A extends UintTypes - ? Len extends 1 - ? "uint" - : Len extends 2 - ? "uvec2" - : Len extends 3 - ? "uvec3" - : Len extends 4 - ? "uvec4" - : never - : never; - -type SwizzleSelectLen = S extends Swizzle1Select - ? 1 - : S extends Swizzle2Select - ? 2 - : S extends Swizzle3Select - ? 3 - : S extends Swizzle4Select - ? 4 - : never; - -function swizzleOutputType( - type: A, - swizzle: S, -): SwizzleOutput> { - let result = null; - if (isFloatType(type)) { - result = - swizzle.length === 1 - ? "float" - : swizzle.length === 2 - ? "vec2" - : swizzle.length === 3 - ? "vec3" - : swizzle.length === 4 - ? "vec4" - : null; - } else if (isIntType(type)) { - result = - swizzle.length === 1 - ? "int" - : swizzle.length === 2 - ? "ivec2" - : swizzle.length === 3 - ? "ivec3" - : swizzle.length === 4 - ? "ivec4" - : null; - } else if (isUintType(type)) { - result = - swizzle.length === 1 - ? "uint" - : swizzle.length === 2 - ? "uvec2" - : swizzle.length === 3 - ? "uvec3" - : swizzle.length === 4 - ? "uvec4" - : null; - } - if (result == null) { - throw new Error(`Invalid swizzle: ${swizzle}`); - } - return result as SwizzleOutput>; -} - -type Swizzle1Select = `${"x" | "y" | "z" | "w"}|${"r" | "g" | "b" | "a"}`; -type Swizzle2Select = - | `${"x" | "y" | "z" | "w"}${"x" | "y" | "z" | "w"}` - | `${"r" | "g" | "b" | "a"}${"r" | "g" | "b" | "a"}`; -type Swizzle3Select = - | `${"x" | "y" | "z" | "w"}${"x" | "y" | "z" | "w"}${"x" | "y" | "z" | "w"}` - | `${"r" | "g" | "b" | "a"}${"r" | "g" | "b" | "a"}${"r" | "g" | "b" | "a"}`; -type Swizzle4Select = - | `${"x" | "y" | "z" | "w"}${"x" | "y" | "z" | "w"}${"x" | "y" | "z" | "w"}${"x" | "y" | "z" | "w"}` - | `${"r" | "g" | "b" | "a"}${"r" | "g" | "b" | "a"}${"r" | "g" | "b" | "a"}${"r" | "g" | "b" | "a"}`; -type SwizzleSelect = - | Swizzle1Select - | Swizzle2Select - | Swizzle3Select - | Swizzle4Select; - -export class Swizzle< - A extends VectorTypes, - S extends SwizzleSelect, -> extends UnaryOp>, "swizzle"> { - constructor({ vector, select }: { vector: DynoVal; select: S }) { - super({ - a: vector, - outKey: "swizzle", - outTypeFunc: (aType) => swizzleOutputType(aType, select), - }); - this.statements = ({ inputs, outputs }) => [ - `${outputs.swizzle} = ${inputs.a}.${select};`, - ]; - } -} diff --git a/src/encoding/ExtendedSplats.ts b/src/encoding/ExtendedSplats.ts new file mode 100644 index 0000000..fedf1e8 --- /dev/null +++ b/src/encoding/ExtendedSplats.ts @@ -0,0 +1,464 @@ +import * as THREE from "three"; +import type { IterableSplatData, SplatCallback, SplatData } from "../Splat"; +import { + LN_SCALE_MAX, + LN_SCALE_MIN, + SH_DEGREE_TO_NUM_COEFF, + SPLAT_TEX_WIDTH, +} from "../defines"; +import { + computeMaxSplats, + decodeQuatOctXy88R8, + encodeQuatOctXy88R8, + floatBitsToUint, + floatToUint8, + getTextureSize, + uintBitsToFloat, +} from "../utils"; +import type { ResizableSplatEncoder } from "./encoder"; + +export type ExtendedSplatsOptions = { + // Reserve space for at least this many splats when constructing the collection + // initially. The array will automatically resize past maxSplats so setting it is + // an optional optimization. (default: 0) + maxSplats?: number; + // Override number of splats in packed array to use only a subset. + // (default: length of packed array / 4) + numSplats?: number; + numSh?: number; +}; + +export class ExtendedSplats implements IterableSplatData { + maxSplats = 0; + numSplats = 0; + numSh = 0; + packedArray1: Uint32Array; + packedArray2: Uint32Array; + packedShArray: Uint8Array | null; + + private splatTexture1: THREE.DataArrayTexture | null = null; + private splatTexture2: THREE.DataArrayTexture | null = null; + private shTexture: THREE.DataArrayTexture | null = null; + private needsUpdate = false; + + constructor( + packedArray1: Uint32Array, + packedArray2: Uint32Array, + packedShArray: Uint8Array | null, + options: ExtendedSplatsOptions, + ) { + this.packedArray1 = packedArray1; + this.packedArray2 = packedArray2; + this.packedShArray = packedShArray; + // Calculate number of horizontal texture rows that could fit in array. + // A properly initialized packedArray should already take into account the + // width and height of the texture and be rounded up with padding. + this.maxSplats = Math.floor(this.packedArray1.length / 4); + this.maxSplats = + Math.floor(this.maxSplats / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; + this.numSplats = Math.min( + this.maxSplats, + options.numSplats ?? Number.POSITIVE_INFINITY, + ); + // FIXME: Derive from packedShArray length or make required argument? + this.numSh = options.numSh ?? 0; + } + + setupMaterial(material: THREE.ShaderMaterial) { + material.defines.USE_EXTENDED_SPLAT = true; + material.defines.SPLAT_DECODE_FN = "decodeExtendedSplatDefault"; + material.defines.SPLAT_SH_DECODE_FN = "decodePackedSphericalHarmonics"; + material.defines.NUM_PACKED_SH = this.numSh; + + if (!material.uniforms.packedSplats1) { + material.uniforms.splatTexture1 = { value: null }; + material.uniforms.splatTexture2 = { value: null }; + material.uniforms.shTexture = { value: null }; + material.uniforms.rgbMinMaxLnScaleMinMax = { value: new THREE.Vector4() }; + } + material.uniforms.splatTexture1.value = this.getTexture( + "splatTexture1", + "packedArray1", + ); + material.uniforms.splatTexture2.value = this.getTexture( + "splatTexture2", + "packedArray2", + ); + if (this.packedShArray) { + material.uniforms.shTexture.value = this.getTexture( + "shTexture", + "packedShArray", + ); + } + } + + getTexture( + textureKey: "splatTexture1" | "splatTexture2" | "shTexture", + arrayKey: "packedArray1" | "packedArray2" | "packedShArray", + ): THREE.DataArrayTexture | null { + if (this.needsUpdate || !this[textureKey]) { + this.needsUpdate = false; + + if (!this[arrayKey]) { + throw new Error("No packed splats"); + } + + if (this[textureKey]) { + const { width, height, depth } = this[textureKey].image; + if (this.maxSplats !== width * height * depth) { + // The existing source texture isn't the right size, so dispose it + this[textureKey].dispose(); + this[textureKey] = null; + } + } + + if (!this[textureKey]) { + // Allocate a new source texture of the right size + let { width, height, depth } = getTextureSize(this.maxSplats); + if (textureKey === "shTexture") { + width *= this.numSh; + } + this[textureKey] = new THREE.DataArrayTexture( + new Uint32Array(this[arrayKey].buffer), + width, + height, + depth, + ); + this[textureKey].format = THREE.RGBAIntegerFormat; + this[textureKey].type = THREE.UnsignedIntType; + this[textureKey].internalFormat = "RGBA32UI"; + this[textureKey].needsUpdate = true; + } else if (this[arrayKey].buffer !== this[textureKey].image.data.buffer) { + // The source texture is the right size, update the data + this[textureKey].image.data = new Uint8Array(this[arrayKey].buffer); + } + } + + return this[textureKey]; + } + + iterateCenters( + callback: (index: number, x: number, y: number, z: number) => void, + ) { + for (let i = 0; i < this.numSplats; i++) { + const i4 = i * 4; + callback( + i, + uintBitsToFloat(this.packedArray1[i4 + 0]), + uintBitsToFloat(this.packedArray1[i4 + 1]), + uintBitsToFloat(this.packedArray1[i4 + 2]), + ); + } + } + + iterateSplats(callback: SplatCallback) { + const shCoeffients = SH_DEGREE_TO_NUM_COEFF[this.numSh]; + const sh = this.numSh > 0 ? new Float32Array(shCoeffients) : undefined; + + for (let i = 0; i < this.numSplats; i++) { + const i4 = i * 4; + const word0 = this.packedArray1[i4 + 0]; + const word1 = this.packedArray1[i4 + 1]; + const word2 = this.packedArray1[i4 + 2]; + const word3 = this.packedArray1[i4 + 3]; + const word4 = this.packedArray2[i4 + 0]; + const word5 = this.packedArray2[i4 + 1]; + const word6 = this.packedArray2[i4 + 2]; + const word7 = this.packedArray2[i4 + 3]; + + const r = (word5 & 0xff) / 255; + const g = ((word5 >>> 8) & 0xff) / 255; + const b = ((word5 >>> 16) & 0xff) / 255; + const a = ((word5 >>> 24) & 0xff) / 255; + + const lnScaleScale = (LN_SCALE_MAX - LN_SCALE_MIN) / 1023.0; + const uScalesX = word3 & 0x3ff; + const scaleX = Math.exp(LN_SCALE_MIN + uScalesX * lnScaleScale); + const uScalesY = (word3 >>> 10) & 0x3ff; + const scaleY = Math.exp(LN_SCALE_MIN + uScalesY * lnScaleScale); + const uScalesZ = (word3 >>> 20) & 0x3ff; + const scaleZ = Math.exp(LN_SCALE_MIN + uScalesZ * lnScaleScale); + + decodeQuatOctXy88R8(word4, tempQuaternion); + + if (sh && this.packedShArray) { + for (let j = 0; j < shCoeffients; j++) { + sh[j] = (this.packedShArray[i * shCoeffients + j] - 127) / 127; + } + } + + callback( + i, + uintBitsToFloat(word0), + uintBitsToFloat(word1), + uintBitsToFloat(word2), + scaleX, + scaleY, + scaleZ, + tempQuaternion.x, + tempQuaternion.y, + tempQuaternion.z, + tempQuaternion.w, + a, + r, + g, + b, + sh, + ); + } + } + + dispose(): void { + if (this.splatTexture1) { + this.splatTexture1.dispose(); + this.splatTexture1.source.data = null; + } + if (this.splatTexture2) { + this.splatTexture2.dispose(); + this.splatTexture2.source.data = null; + } + if (this.shTexture) { + this.shTexture.dispose(); + this.shTexture.source.data = null; + } + this.packedArray1 = EMPTY_UINT32_ARRAY; + this.packedArray2 = EMPTY_UINT32_ARRAY; + this.packedShArray = null; + this.numSplats = -1; + this.maxSplats = 0; + } + + static encodingName = "extended"; + + static createSplatEncoder(): ResizableSplatEncoder { + const context: EncodedExtendedSplats = { + numSplats: 0, + maxSplats: 0, + numSh: 0, + packedArray1: new Uint32Array(), + packedArray2: new Uint32Array(), + packedShArray: null, + }; + // Keep track of the head when pushing splats + let head = 0; + + return { + allocate(numSplats: number, numShBands: number) { + context.numSplats = numSplats; + context.maxSplats = computeMaxSplats(numSplats); + context.numSh = numShBands; + context.packedArray1 = new Uint32Array(context.maxSplats * 4); + context.packedArray2 = new Uint32Array(context.maxSplats * 4); + if (numShBands > 0) { + context.packedShArray = new Uint8Array( + context.maxSplats * 16 * numShBands, + ); + } + }, + + setSplat( + i, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ) { + this.setSplatCenter(i, x, y, z); + this.setSplatScales(i, scaleX, scaleY, scaleZ); + this.setSplatQuat(i, quatX, quatY, quatZ, quatW); + this.setSplatRgba(i, r, g, b, opacity); + }, + + pushSplat( + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ) { + const i = head++; + context.numSplats = head; + if (head > context.maxSplats) { + // Resize + context.maxSplats = computeMaxSplats( + Math.max(context.maxSplats, 1) * 2, + ); + context.packedArray1 = new Uint32Array( + context.packedArray1.buffer.transfer(context.maxSplats * 16), + ); + context.packedArray2 = new Uint32Array( + context.packedArray2.buffer.transfer(context.maxSplats * 16), + ); + if (context.packedShArray) { + context.packedShArray = new Uint8Array( + context.packedShArray.buffer.transfer( + context.maxSplats * 16 * context.numSh, + ), + ); + } + } + this.setSplat( + i, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ); + }, + + setSplatCenter(index, x, y, z) { + const i4 = index * 4; + context.packedArray1[i4 + 0] = floatBitsToUint(x); + context.packedArray1[i4 + 1] = floatBitsToUint(y); + context.packedArray1[i4 + 2] = floatBitsToUint(z); + }, + + setSplatScales(index, scaleX, scaleY, scaleZ) { + const lnScaleMin = LN_SCALE_MIN; + const lnScaleMax = LN_SCALE_MAX; + const lnScaleScale = 1023.0 / (lnScaleMax - lnScaleMin); + const uScaleX = THREE.MathUtils.clamp( + Math.round((Math.log(scaleX) - lnScaleMin) * lnScaleScale), + 0, + 1023, + ); + const uScaleY = THREE.MathUtils.clamp( + Math.round((Math.log(scaleY) - lnScaleMin) * lnScaleScale), + 0, + 1023, + ); + const uScaleZ = THREE.MathUtils.clamp( + Math.round((Math.log(scaleZ) - lnScaleMin) * lnScaleScale), + 0, + 1023, + ); + + const i4 = index * 4; + context.packedArray1[i4 + 3] = + uScaleX | (uScaleY << 10) | (uScaleZ << 20); + }, + + setSplatQuat(index, quatX, quatY, quatZ, quatW) { + const uQuat = encodeQuatOctXy88R8( + tempQuaternion.set(quatX, quatY, quatZ, quatW), + ); + const uQuatX = uQuat & 0xff; + const uQuatY = (uQuat >>> 8) & 0xff; + const uQuatZ = (uQuat >>> 16) & 0xff; + + const i4 = index * 4; + context.packedArray2[i4 + 0] = uQuatX | (uQuatY << 8) | (uQuatZ << 16); + }, + + setSplatRgba(index, r, g, b, a) { + // FIXME: Extended range + const uR = floatToUint8(r); + const uG = floatToUint8(g); + const uB = floatToUint8(b); + const uA = floatToUint8(a); + const i4 = index * 4; + context.packedArray2[i4 + 1] = uR | (uG << 8) | (uB << 16) | (uA << 24); + }, + + setSplatRgb(index, r, g, b) { + // FIXME: Extended range + const uR = floatToUint8(r); + const uG = floatToUint8(g); + const uB = floatToUint8(b); + + const i4 = index * 4; + context.packedArray2[i4 + 1] = + uR | + (uG << 8) | + (uB << 16) | + (context.packedArray2[i4 + 1] & 0xff000000); + }, + + setSplatAlpha(index, a) { + const uA = floatToUint8(a); + + const i4 = index * 4; + context.packedArray2[i4 + 1] = + (context.packedArray2[i4 + 1] & 0x00ffffff) | (uA << 24); + }, + + setSplatSh(index, sh) { + if (!context.packedShArray) { + throw new Error( + "No array for spherical harmonics has been allocated", + ); + } + + const stride = context.numSh * 16; + const startIndex = index * stride; + for (let i = 0; i < SH_DEGREE_TO_NUM_COEFF[context.numSh]; i++) { + context.packedShArray[startIndex + i] = Math.max( + -127, + Math.min(127, sh[i] * 127), + ); + } + }, + + closeTransferable() { + return context; + }, + + close() { + return ExtendedSplats.fromTransferable(context); + }, + }; + } + + static fromTransferable(context: EncodedExtendedSplats) { + return new ExtendedSplats( + context.packedArray1, + context.packedArray2, + context.packedShArray, + { + numSplats: context.numSplats, + numSh: context.numSh, + }, + ); + } +} + +export type EncodedExtendedSplats = { + numSplats: number; + maxSplats: number; + numSh: number; + packedArray1: Uint32Array; + packedArray2: Uint32Array; + packedShArray: Uint8Array | null; +}; + +const tempQuaternion = new THREE.Quaternion(); +const EMPTY_UINT32_ARRAY = new Uint32Array(0); diff --git a/src/encoding/PackedSplats.ts b/src/encoding/PackedSplats.ts new file mode 100644 index 0000000..a957bd1 --- /dev/null +++ b/src/encoding/PackedSplats.ts @@ -0,0 +1,549 @@ +import * as THREE from "three"; +import type { IterableSplatData, SplatCallback, SplatData } from "../Splat"; +import { + LN_SCALE_MAX, + LN_SCALE_MIN, + SCALE_ZERO, + SH_DEGREE_TO_NUM_COEFF, + SPLAT_TEX_HEIGHT, + SPLAT_TEX_WIDTH, +} from "../defines"; +import { + computeMaxSplats, + decodeQuatOctXy88R8, + encodeQuatOctXy88R8, + floatToUint8, + fromHalf, + getTextureSize, + toHalf, +} from "../utils"; +import type { ResizableSplatEncoder } from "./encoder"; + +export type SplatEncoding = { + rgbMin?: number; + rgbMax?: number; + lnScaleMin?: number; + lnScaleMax?: number; + sh1Min?: number; + sh1Max?: number; + sh2Min?: number; + sh2Max?: number; + sh3Min?: number; + sh3Max?: number; +}; + +export const DEFAULT_SPLAT_ENCODING: SplatEncoding = { + rgbMin: 0, + rgbMax: 1, + lnScaleMin: LN_SCALE_MIN, + lnScaleMax: LN_SCALE_MAX, + sh1Min: -1, + sh1Max: 1, + sh2Min: -1, + sh2Max: 1, + sh3Min: -1, + sh3Max: 1, +}; + +export type PackedSplatsOptions = { + // Reserve space for at least this many splats when constructing the collection + // initially. The array will automatically resize past maxSplats so setting it is + // an optional optimization. (default: 0) + maxSplats?: number; + // Override number of splats in packed array to use only a subset. + // (default: length of packed array / 4) + numSplats?: number; + numSh?: number; + // Override the default splat encoding ranges for the PackedSplats. + // (default: undefined) + splatEncoding?: SplatEncoding; +}; + +export class PackedSplats implements IterableSplatData { + maxSplats = 0; + numSplats = 0; + numSh = 0; + packedArray: Uint32Array; + private shArray: Uint8Array | null = null; + readonly splatEncoding?: SplatEncoding; + + private texture: THREE.DataArrayTexture | null = null; + private shTexture: THREE.DataArrayTexture | null = null; + private needsUpdate = false; + + constructor( + packedArray: Uint32Array, + shArray?: Uint8Array | null, + options?: PackedSplatsOptions, + ) { + this.packedArray = packedArray; + this.shArray = shArray ?? null; + // Calculate number of horizontal texture rows that could fit in array. + // A properly initialized packedArray should already take into account the + // width and height of the texture and be rounded up with padding. + this.maxSplats = Math.floor(this.packedArray.length / 4); + this.maxSplats = + Math.floor(this.maxSplats / SPLAT_TEX_WIDTH) * SPLAT_TEX_WIDTH; + this.numSplats = Math.min( + this.maxSplats, + options?.numSplats ?? Number.POSITIVE_INFINITY, + ); + this.numSh = options?.numSh ?? 0; + this.splatEncoding = options?.splatEncoding ?? DEFAULT_SPLAT_ENCODING; + } + + setupMaterial(material: THREE.ShaderMaterial) { + material.defines.USE_PACKED_SPLAT = true; + material.defines.SPLAT_DECODE_FN = "decodePackedSplatDefault"; + material.defines.SPLAT_SH_DECODE_FN = "decodePackedSphericalHarmonics"; + material.defines.NUM_PACKED_SH = this.numSh; + + if (!material.uniforms.packedSplats) { + material.uniforms.packedSplats = { value: null }; + material.uniforms.packedShTexture = { value: null }; + material.uniforms.rgbMinMaxLnScaleMinMax = { value: new THREE.Vector4() }; + } + material.uniforms.packedSplats.value = this.getTexture(); + material.uniforms.packedShTexture.value = this.getShTexture(); + material.uniforms.rgbMinMaxLnScaleMinMax.value.set( + this.splatEncoding?.rgbMin ?? 0.0, + this.splatEncoding?.rgbMax ?? 1.0, + this.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, + this.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, + ); + } + + getTexture(): THREE.DataArrayTexture | null { + if (this.needsUpdate || !this.texture) { + this.needsUpdate = false; + + if (!this.packedArray) { + throw new Error("No packed splats"); + } + + if (this.texture) { + const { width, height, depth } = this.texture.image; + if (this.maxSplats !== width * height * depth) { + // The existing source texture isn't the right size, so dispose it + this.texture.dispose(); + this.texture = null; + } + } + + if (!this.texture) { + // Allocate a new source texture of the right size + const { width, height, depth } = getTextureSize(this.maxSplats); + this.texture = new THREE.DataArrayTexture( + this.packedArray, + width, + height, + depth, + ); + this.texture.format = THREE.RGBAIntegerFormat; + this.texture.type = THREE.UnsignedIntType; + this.texture.internalFormat = "RGBA32UI"; + this.texture.needsUpdate = true; + } else if (this.packedArray.buffer !== this.texture.image.data.buffer) { + // The source texture is the right size, update the data + this.texture.image.data = new Uint8Array(this.packedArray.buffer); + } + } + + return this.texture; + } + + getShTexture(): THREE.DataArrayTexture | null { + if (this.needsUpdate || !this.shTexture) { + if (!this.shArray) { + return null; + } + + if (this.shTexture) { + const { width, height, depth } = this.shTexture.image; + if (this.maxSplats !== width * height * depth) { + // The existing source texture isn't the right size, so dispose it + this.shTexture.dispose(); + this.shTexture = null; + } + } + + if (!this.shTexture) { + // Allocate a new source texture of the right size + let { width, height, depth } = getTextureSize(this.maxSplats); + width *= this.numSh; + this.shTexture = new THREE.DataArrayTexture( + new Uint32Array(this.shArray.buffer), + width, + height, + depth, + ); + this.shTexture.format = THREE.RGBAIntegerFormat; + this.shTexture.type = THREE.UnsignedIntType; + this.shTexture.internalFormat = "RGBA32UI"; + this.shTexture.needsUpdate = true; + } else if (this.shArray.buffer !== this.shTexture.image.data.buffer) { + // The source texture is the right size, update the data + this.shTexture.image.data = new Uint32Array(this.shArray.buffer); + } + } + + return this.shTexture; + } + + iterateCenters( + callback: (index: number, x: number, y: number, z: number) => void, + ) { + for (let i = 0; i < this.numSplats; i++) { + const i4 = i * 4; + const word1 = this.packedArray[i4 + 1]; + const word2 = this.packedArray[i4 + 2]; + + callback( + i, + fromHalf(word1 & 0xffff), + fromHalf((word1 >>> 16) & 0xffff), + fromHalf(word2 & 0xffff), + ); + } + } + + iterateSplats(callback: SplatCallback) { + const shCoeffients = SH_DEGREE_TO_NUM_COEFF[this.numSh]; + const sh = this.numSh > 0 ? new Float32Array(shCoeffients) : undefined; + + for (let i = 0; i < this.numSplats; i++) { + const i4 = i * 4; + const word0 = this.packedArray[i4 + 0]; + const word1 = this.packedArray[i4 + 1]; + const word2 = this.packedArray[i4 + 2]; + const word3 = this.packedArray[i4 + 3]; + + const rgbMin = this.splatEncoding?.rgbMin ?? 0.0; + const rgbMax = this.splatEncoding?.rgbMax ?? 1.0; + const rgbRange = rgbMax - rgbMin; + const r = rgbMin + ((word0 & 0xff) / 255) * rgbRange; + const g = rgbMin + (((word0 >>> 8) & 0xff) / 255) * rgbRange; + const b = rgbMin + (((word0 >>> 16) & 0xff) / 255) * rgbRange; + const a = ((word0 >>> 24) & 0xff) / 255; + + const x = fromHalf(word1 & 0xffff); + const y = fromHalf((word1 >>> 16) & 0xffff); + const z = fromHalf(word2 & 0xffff); + + const lnScaleMin = this.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN; + const lnScaleMax = this.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX; + const lnScaleScale = (lnScaleMax - lnScaleMin) / 254.0; + const uScalesX = word3 & 0xff; + const scaleX = + uScalesX === 0 + ? 0.0 + : Math.exp(lnScaleMin + (uScalesX - 1) * lnScaleScale); + const uScalesY = (word3 >>> 8) & 0xff; + const scaleY = + uScalesY === 0 + ? 0.0 + : Math.exp(lnScaleMin + (uScalesY - 1) * lnScaleScale); + const uScalesZ = (word3 >>> 16) & 0xff; + const scaleZ = + uScalesZ === 0 + ? 0.0 + : Math.exp(lnScaleMin + (uScalesZ - 1) * lnScaleScale); + + const uQuat = ((word2 >>> 16) & 0xffff) | ((word3 >>> 8) & 0xff0000); + decodeQuatOctXy88R8(uQuat, tempQuaternion); + + if (sh && this.shArray) { + for (let j = 0; j < shCoeffients; j++) { + sh[j] = (this.shArray[i * shCoeffients + j] - 127) / 127; + } + } + + callback( + i, + fromHalf(word1 & 0xffff), + fromHalf((word1 >>> 16) & 0xffff), + fromHalf(word2 & 0xffff), + scaleX, + scaleY, + scaleZ, + tempQuaternion.x, + tempQuaternion.y, + tempQuaternion.z, + tempQuaternion.w, + a, + r, + g, + b, + sh, + ); + } + } + + dispose() { + if (this.texture) { + this.texture.dispose(); + this.texture.source.data = null; + } + if (this.shTexture) { + this.shTexture.dispose(); + this.shTexture.source.data = null; + } + this.packedArray = EMPTY_UINT32_ARRAY; + this.shArray = null; + this.numSplats = -1; + this.maxSplats = 0; + } + + static encodingName = "packed"; + + static createSplatEncoder( + encoding: SplatEncoding = DEFAULT_SPLAT_ENCODING, + ): ResizableSplatEncoder { + const context: EncodedPackedSplats = { + numSplats: -1, + maxSplats: 0, + numSh: 0, + packedArray: new Uint32Array(), + shArray: null, + }; + // Keep track of the head when pushing splats + let head = 0; + + return { + allocate(numSplats: number, numShBands: number) { + if (context.numSplats !== -1) { + throw new Error("Storage already allocated"); + } + + context.numSplats = numSplats; + context.maxSplats = computeMaxSplats(numSplats); + context.numSh = numShBands; + context.packedArray = new Uint32Array( + context.packedArray.buffer.transfer(context.maxSplats * 16), + ); + + // Allocate one RGBA32UI pixel per numShBands. + // Each pixels can hold 16 sint8 coefficients, so + // 1 band => 9 coefficients (<16) + // 2 bands => 21 coefficients (<32) + // 3 bands => 45 coefficients (<48) + if (numShBands >= 1) + context.shArray = new Uint8Array(context.maxSplats * 16 * numShBands); + }, + + setSplat( + i, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ) { + this.setSplatCenter(i, x, y, z); + this.setSplatScales(i, scaleX, scaleY, scaleZ); + this.setSplatQuat(i, quatX, quatY, quatZ, quatW); + this.setSplatRgba(i, r, g, b, opacity); + }, + + pushSplat( + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ) { + const i = head++; + context.numSplats = head; + if (head > context.maxSplats) { + // Resize + context.maxSplats = computeMaxSplats( + Math.max(context.maxSplats, 1) * 2, + ); + context.packedArray = new Uint32Array( + context.packedArray.buffer.transfer(context.maxSplats * 16), + ); + if (context.shArray) { + context.shArray = new Uint8Array( + context.shArray.buffer.transfer( + context.maxSplats * 16 * context.numSh, + ), + ); + } + } + this.setSplat( + i, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ); + }, + + setSplatCenter(index, x, y, z) { + const uCenterX = toHalf(x); + const uCenterY = toHalf(y); + const uCenterZ = toHalf(z); + + const i4 = index * 4; + context.packedArray[i4 + 1] = uCenterX | (uCenterY << 16); + context.packedArray[i4 + 2] = + uCenterZ | (context.packedArray[i4 + 2] & 0xffff0000); + }, + + setSplatScales(index, scaleX, scaleY, scaleZ) { + // Allow scales below LN_SCALE_MIN to be encoded as 0, which signifies a 2DGS + const lnScaleMin = encoding?.lnScaleMin ?? LN_SCALE_MIN; + const lnScaleMax = encoding?.lnScaleMax ?? LN_SCALE_MAX; + const lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); + const uScaleX = + scaleX < SCALE_ZERO + ? 0 + : THREE.MathUtils.clamp( + Math.round((Math.log(scaleX) - lnScaleMin) * lnScaleScale) + 1, + 1, + 255, + ); + const uScaleY = + scaleY < SCALE_ZERO + ? 0 + : THREE.MathUtils.clamp( + Math.round((Math.log(scaleY) - lnScaleMin) * lnScaleScale) + 1, + 1, + 255, + ); + const uScaleZ = + scaleZ < SCALE_ZERO + ? 0 + : THREE.MathUtils.clamp( + Math.round((Math.log(scaleZ) - lnScaleMin) * lnScaleScale) + 1, + 1, + 255, + ); + + const i4 = index * 4; + context.packedArray[i4 + 3] = + uScaleX | + (uScaleY << 8) | + (uScaleZ << 16) | + (context.packedArray[i4 + 3] & 0xff000000); + }, + + setSplatQuat(index, quatX, quatY, quatZ, quatW) { + const uQuat = encodeQuatOctXy88R8( + tempQuaternion.set(quatX, quatY, quatZ, quatW), + ); + const uQuatX = uQuat & 0xff; + const uQuatY = (uQuat >>> 8) & 0xff; + const uQuatZ = (uQuat >>> 16) & 0xff; + + const i4 = index * 4; + context.packedArray[i4 + 2] = + (context.packedArray[i4 + 2] & 0x0000ffff) | + (uQuatX << 16) | + (uQuatY << 24); + context.packedArray[i4 + 3] = + (context.packedArray[i4 + 3] & 0x00ffffff) | (uQuatZ << 24); + }, + + setSplatRgba(index, r, g, b, a) { + const rgbMin = encoding?.rgbMin ?? 0.0; + const rgbMax = encoding?.rgbMax ?? 1.0; + const rgbRange = rgbMax - rgbMin; + const uR = floatToUint8((r - rgbMin) / rgbRange); + const uG = floatToUint8((g - rgbMin) / rgbRange); + const uB = floatToUint8((b - rgbMin) / rgbRange); + const uA = floatToUint8(a); + const i4 = index * 4; + context.packedArray[i4] = uR | (uG << 8) | (uB << 16) | (uA << 24); + }, + + setSplatRgb(index, r, g, b) { + const rgbMin = encoding?.rgbMin ?? 0.0; + const rgbMax = encoding?.rgbMax ?? 1.0; + const rgbRange = rgbMax - rgbMin; + const uR = floatToUint8((r - rgbMin) / rgbRange); + const uG = floatToUint8((g - rgbMin) / rgbRange); + const uB = floatToUint8((b - rgbMin) / rgbRange); + + const i4 = index * 4; + context.packedArray[i4] = + uR | (uG << 8) | (uB << 16) | (context.packedArray[i4] & 0xff000000); + }, + + setSplatAlpha(index, a) { + const uA = floatToUint8(a); + + const i4 = index * 4; + context.packedArray[i4] = + (context.packedArray[i4] & 0x00ffffff) | (uA << 24); + }, + + setSplatSh(index, sh) { + if (context.shArray) { + const stride = context.numSh * 16; + const startIndex = index * stride; + for (let i = 0; i < SH_DEGREE_TO_NUM_COEFF[context.numSh]; i++) { + context.shArray[startIndex + i] = Math.max( + -127, + Math.min(127, sh[i] * 127), + ); + } + } + }, + + closeTransferable() { + return context; + }, + + close() { + return PackedSplats.fromTransferable(context); + }, + }; + } + + static fromTransferable(context: EncodedPackedSplats) { + return new PackedSplats(context.packedArray, context.shArray, { + numSplats: context.numSplats, + numSh: context.numSh, + }); + } +} + +export type EncodedPackedSplats = { + numSplats: number; + maxSplats: number; + numSh: number; + packedArray: Uint32Array; + shArray: Uint8Array | null; +}; + +const tempQuaternion = new THREE.Quaternion(); +const EMPTY_UINT32_ARRAY = new Uint32Array(0); diff --git a/src/encoding/encoder.ts b/src/encoding/encoder.ts new file mode 100644 index 0000000..d9dcd2d --- /dev/null +++ b/src/encoding/encoder.ts @@ -0,0 +1,122 @@ +import type { SplatData } from "../Splat"; +import { ExtendedSplats } from "./ExtendedSplats"; +import { PackedSplats } from "./PackedSplats"; + +/** + * Interface for encoding raw splat values into a specific encoding. + * Used during loading and when procedurally generating splats. + */ +export interface SplatEncoder { + /** + * Ensures that there is enough space allocated for a given amount + * of splats and spherical harmonics. Should be called before + * setting individual splat values. + * @param numSplats The number of splats to hold + * @param numShBands The number of spherical harmonics + */ + allocate(numSplats: number, numShBands: number): void; + + setSplat( + i: number, + x: number, + y: number, + z: number, + scaleX: number, + scaleY: number, + scaleZ: number, + quatX: number, + quatY: number, + quatZ: number, + quatW: number, + opacity: number, + r: number, + g: number, + b: number, + ): void; + + setSplatCenter(i: number, x: number, y: number, z: number): void; + setSplatScales( + i: number, + scaleX: number, + scaleY: number, + scaleZ: number, + ): void; + setSplatQuat( + i: number, + quatX: number, + quatY: number, + quatZ: number, + quatW: number, + ): void; + setSplatRgba(i: number, r: number, g: number, b: number, a: number): void; + setSplatRgb(i: number, r: number, g: number, b: number): void; + setSplatAlpha(i: number, a: number): void; + + setSplatSh(i: number, sh: ArrayLike): void; + + /** + * Finalizes the splat encoding and returns the encoded result in a transferable representation. + */ + closeTransferable(): T; + + /** + * Finalizes the splat encoding and returns the SplatData. + */ + close(): SplatData; +} + +/** + * Specialized splat encoder type that supports dynamically growing + * the amount of splats. + */ +export interface ResizableSplatEncoder extends SplatEncoder { + pushSplat( + x: number, + y: number, + z: number, + scaleX: number, + scaleY: number, + scaleZ: number, + quatX: number, + quatY: number, + quatZ: number, + quatW: number, + opacity: number, + r: number, + g: number, + b: number, + ): void; +} + +export type UnpackResult = { + unpacked: T; + numSplats: number; +}; + +// biome-ignore lint: Generic is used to constraint transferable type +export type SplatEncodingClass = { + encodingName: string; + createSplatEncoder: ( + options?: Record, + ) => ResizableSplatEncoder; + fromTransferable: (transferable: T) => SplatData; +}; + +export function createSplatEncoder( + name: string, + options?: Record, +): SplatEncoder { + switch (name) { + case "packed": + return PackedSplats.createSplatEncoder(options); + case "extended": + return ExtendedSplats.createSplatEncoder(); + default: + throw new Error(`Unknown splat encoding: ${name}`); + } +} + +/** + * The default splat encoding to use when loading or creating SplatData. + */ +export const DefaultSplatEncoding: SplatEncodingClass = ExtendedSplats; // TODO: Make configurable? diff --git a/src/formats/antisplat.ts b/src/formats/antisplat.ts new file mode 100644 index 0000000..ecfbc91 --- /dev/null +++ b/src/formats/antisplat.ts @@ -0,0 +1,52 @@ +import type { SplatEncoder, UnpackResult } from "../encoding/encoder"; + +export function unpackAntiSplat( + fileBytes: Uint8Array, + splatEncoder: SplatEncoder, +): UnpackResult { + const numSplats = Math.floor(fileBytes.length / 32); // 32 bytes per splat + if (numSplats * 32 !== fileBytes.length) { + throw new Error("Invalid .splat file size"); + } + + splatEncoder.allocate(numSplats, 0); + + const f32 = new Float32Array(fileBytes.buffer); + for (let i = 0; i < numSplats; ++i) { + const i32 = i * 32; + const i8 = i * 8; + const x = f32[i8 + 0]; + const y = f32[i8 + 1]; + const z = f32[i8 + 2]; + const scaleX = f32[i8 + 3]; + const scaleY = f32[i8 + 4]; + const scaleZ = f32[i8 + 5]; + const r = fileBytes[i32 + 24] / 255; + const g = fileBytes[i32 + 25] / 255; + const b = fileBytes[i32 + 26] / 255; + const opacity = fileBytes[i32 + 27] / 255; + const quatW = (fileBytes[i32 + 28] - 128) / 128; + const quatX = (fileBytes[i32 + 29] - 128) / 128; + const quatY = (fileBytes[i32 + 30] - 128) / 128; + const quatZ = (fileBytes[i32 + 31] - 128) / 128; + splatEncoder.setSplat( + i, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ); + } + + return { unpacked: splatEncoder.closeTransferable(), numSplats }; +} diff --git a/src/formats/ksplat.ts b/src/formats/ksplat.ts new file mode 100644 index 0000000..aded2c5 --- /dev/null +++ b/src/formats/ksplat.ts @@ -0,0 +1,303 @@ +import { SH_DEGREE_TO_NUM_COEFF } from "../defines"; +import type { SplatEncoder, UnpackResult } from "../encoding/encoder"; +import { fromHalf } from "../utils"; + +type KsplatCompression = { + bytesPerCenter: number; + bytesPerScale: number; + bytesPerRotation: number; + bytesPerColor: number; + bytesPerSphericalHarmonicsComponent: number; + scaleOffsetBytes: number; + rotationOffsetBytes: number; + colorOffsetBytes: number; + sphericalHarmonicsOffsetBytes: number; + scaleRange: number; +}; + +const KSPLAT_COMPRESSION: Record = { + 0: { + bytesPerCenter: 12, + bytesPerScale: 12, + bytesPerRotation: 16, + bytesPerColor: 4, + bytesPerSphericalHarmonicsComponent: 4, + scaleOffsetBytes: 12, + rotationOffsetBytes: 24, + colorOffsetBytes: 40, + sphericalHarmonicsOffsetBytes: 44, + scaleRange: 1, + }, + 1: { + bytesPerCenter: 6, + bytesPerScale: 6, + bytesPerRotation: 8, + bytesPerColor: 4, + bytesPerSphericalHarmonicsComponent: 2, + scaleOffsetBytes: 6, + rotationOffsetBytes: 12, + colorOffsetBytes: 20, + sphericalHarmonicsOffsetBytes: 24, + scaleRange: 32767, + }, + 2: { + bytesPerCenter: 6, + bytesPerScale: 6, + bytesPerRotation: 8, + bytesPerColor: 4, + bytesPerSphericalHarmonicsComponent: 1, + scaleOffsetBytes: 6, + rotationOffsetBytes: 12, + colorOffsetBytes: 20, + sphericalHarmonicsOffsetBytes: 24, + scaleRange: 32767, + }, +}; + +export function unpackKsplat( + fileBytes: Uint8Array, + splatEncoder: SplatEncoder, +): UnpackResult { + const HEADER_BYTES = 4096; + const SECTION_BYTES = 1024; + + let headerOffset = 0; + const header = new DataView(fileBytes.buffer, headerOffset, HEADER_BYTES); + headerOffset += HEADER_BYTES; + + const versionMajor = header.getUint8(0); + const versionMinor = header.getUint8(1); + if (versionMajor !== 0 || versionMinor < 1) { + throw new Error( + `Unsupported .ksplat version: ${versionMajor}.${versionMinor}`, + ); + } + const maxSectionCount = header.getUint32(4, true); + // const sectionCount = header.getUint32(8, true); + // const maxSplatCount = header.getUint32(12, true); + const splatCount = header.getUint32(16, true); + const compressionLevel = header.getUint16(20, true); + if (compressionLevel < 0 || compressionLevel > 2) { + throw new Error(`Invalid .ksplat compression level: ${compressionLevel}`); + } + // const sceneCenterX = header.getFloat32(24, true); + // const sceneCenterY = header.getFloat32(28, true); + // const sceneCenterZ = header.getFloat32(32, true); + const minSphericalHarmonicsCoeff = header.getFloat32(36, true) || -1.5; + const maxSphericalHarmonicsCoeff = header.getFloat32(40, true) || 1.5; + + let sectionBase = HEADER_BYTES + maxSectionCount * SECTION_BYTES; + + for (let sectionIndex = 0; sectionIndex < maxSectionCount; ++sectionIndex) { + const section = new DataView(fileBytes.buffer, headerOffset, SECTION_BYTES); + headerOffset += SECTION_BYTES; + + const sectionSplatCount = section.getUint32(0, true); + const sectionMaxSplatCount = section.getUint32(4, true); + const bucketSize = section.getUint32(8, true); + const bucketCount = section.getUint32(12, true); + const bucketBlockSize = section.getFloat32(16, true); + const bucketStorageSizeBytes = section.getUint16(20, true); + const compressionScaleRange = + (section.getUint32(24, true) || + KSPLAT_COMPRESSION[compressionLevel]?.scaleRange) ?? + 1; + const fullBucketCount = section.getUint32(32, true); + const fullBucketSplats = fullBucketCount * bucketSize; + const partiallyFilledBucketCount = section.getUint32(36, true); + const bucketsMetaDataSizeBytes = partiallyFilledBucketCount * 4; + const bucketsStorageSizeBytes = + bucketStorageSizeBytes * bucketCount + bucketsMetaDataSizeBytes; + const sphericalHarmonicsDegree = section.getUint16(40, true); + const shComponents = SH_DEGREE_TO_NUM_COEFF[sphericalHarmonicsDegree]; + + // SH degrees are known, allocate splats + if (sectionIndex === 0) { + splatEncoder.allocate(splatCount, sphericalHarmonicsDegree); + } + + const { + bytesPerCenter, + bytesPerScale, + bytesPerRotation, + bytesPerColor, + bytesPerSphericalHarmonicsComponent, + scaleOffsetBytes, + rotationOffsetBytes, + colorOffsetBytes, + sphericalHarmonicsOffsetBytes, + } = KSPLAT_COMPRESSION[compressionLevel]; + const bytesPerSplat = + bytesPerCenter + + bytesPerScale + + bytesPerRotation + + bytesPerColor + + shComponents * bytesPerSphericalHarmonicsComponent; + const splatDataStorageSizeBytes = bytesPerSplat * sectionMaxSplatCount; + const storageSizeBytes = + splatDataStorageSizeBytes + bucketsStorageSizeBytes; + + const shIndex = [ + // Sh1 + 0, 3, 6, 1, 4, 7, 2, 5, 8, + // Sh2 + 9, 14, 19, 10, 15, 20, 11, 16, 21, 12, 17, 22, 13, 18, 23, + // Sh3 + 24, 31, 38, 25, 32, 39, 26, 33, 40, 27, 34, 41, 28, 35, 42, 29, 36, 43, + 30, 37, 44, + ].slice(0, shComponents); + const sh = new Float32Array(shComponents); + + const compressionScaleFactor = bucketBlockSize / 2 / compressionScaleRange; + const bucketsBase = sectionBase + bucketsMetaDataSizeBytes; + const dataBase = sectionBase + bucketsStorageSizeBytes; + const data = new DataView( + fileBytes.buffer, + dataBase, + splatDataStorageSizeBytes, + ); + const bucketArray = new Float32Array( + fileBytes.buffer, + bucketsBase, + bucketCount * 3, + ); + const partiallyFilledBucketLengths = new Uint32Array( + fileBytes.buffer, + sectionBase, + partiallyFilledBucketCount, + ); + + function getSh(splatOffset: number, component: number) { + if (compressionLevel === 0) { + return data.getFloat32( + splatOffset + sphericalHarmonicsOffsetBytes + component * 4, + true, + ); + } + if (compressionLevel === 1) { + return fromHalf( + data.getUint16( + splatOffset + sphericalHarmonicsOffsetBytes + component * 2, + true, + ), + ); + } + const t = + data.getUint8(splatOffset + sphericalHarmonicsOffsetBytes + component) / + 255; + return ( + minSphericalHarmonicsCoeff + + t * (maxSphericalHarmonicsCoeff - minSphericalHarmonicsCoeff) + ); + } + + let partialBucketIndex = fullBucketCount; + let partialBucketBase = fullBucketSplats; + + for (let i = 0; i < sectionSplatCount; ++i) { + const splatOffset = i * bytesPerSplat; + + let bucketIndex: number; + if (i < fullBucketSplats) { + bucketIndex = Math.floor(i / bucketSize); + } else { + const bucketLength = + partiallyFilledBucketLengths[partialBucketIndex - fullBucketCount]; + if (i >= partialBucketBase + bucketLength) { + partialBucketIndex += 1; + partialBucketBase += bucketLength; + } + bucketIndex = partialBucketIndex; + } + + const x = + compressionLevel === 0 + ? data.getFloat32(splatOffset + 0, true) + : (data.getUint16(splatOffset + 0, true) - compressionScaleRange) * + compressionScaleFactor + + bucketArray[3 * bucketIndex + 0]; + const y = + compressionLevel === 0 + ? data.getFloat32(splatOffset + 4, true) + : (data.getUint16(splatOffset + 2, true) - compressionScaleRange) * + compressionScaleFactor + + bucketArray[3 * bucketIndex + 1]; + const z = + compressionLevel === 0 + ? data.getFloat32(splatOffset + 8, true) + : (data.getUint16(splatOffset + 4, true) - compressionScaleRange) * + compressionScaleFactor + + bucketArray[3 * bucketIndex + 2]; + + const scaleX = + compressionLevel === 0 + ? data.getFloat32(splatOffset + scaleOffsetBytes + 0, true) + : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 0, true)); + const scaleY = + compressionLevel === 0 + ? data.getFloat32(splatOffset + scaleOffsetBytes + 4, true) + : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 2, true)); + const scaleZ = + compressionLevel === 0 + ? data.getFloat32(splatOffset + scaleOffsetBytes + 8, true) + : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 4, true)); + + const quatW = + compressionLevel === 0 + ? data.getFloat32(splatOffset + rotationOffsetBytes + 0, true) + : fromHalf( + data.getUint16(splatOffset + rotationOffsetBytes + 0, true), + ); + const quatX = + compressionLevel === 0 + ? data.getFloat32(splatOffset + rotationOffsetBytes + 4, true) + : fromHalf( + data.getUint16(splatOffset + rotationOffsetBytes + 2, true), + ); + const quatY = + compressionLevel === 0 + ? data.getFloat32(splatOffset + rotationOffsetBytes + 8, true) + : fromHalf( + data.getUint16(splatOffset + rotationOffsetBytes + 4, true), + ); + const quatZ = + compressionLevel === 0 + ? data.getFloat32(splatOffset + rotationOffsetBytes + 12, true) + : fromHalf( + data.getUint16(splatOffset + rotationOffsetBytes + 6, true), + ); + + const r = data.getUint8(splatOffset + colorOffsetBytes + 0) / 255; + const g = data.getUint8(splatOffset + colorOffsetBytes + 1) / 255; + const b = data.getUint8(splatOffset + colorOffsetBytes + 2) / 255; + const opacity = data.getUint8(splatOffset + colorOffsetBytes + 3) / 255; + + splatEncoder.setSplat( + i, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ); + + if (sphericalHarmonicsDegree >= 1) { + for (const [i, key] of shIndex.entries()) { + sh[i] = getSh(splatOffset, key); + } + splatEncoder.setSplatSh(i, sh); + } + } + sectionBase += storageSizeBytes; + } + return { unpacked: splatEncoder.closeTransferable(), numSplats: splatCount }; +} diff --git a/src/pcsogs.ts b/src/formats/pcsogs.ts similarity index 74% rename from src/pcsogs.ts rename to src/formats/pcsogs.ts index f82230f..d899008 100644 --- a/src/pcsogs.ts +++ b/src/formats/pcsogs.ts @@ -1,30 +1,72 @@ import { unzip } from "fflate"; -import type { SplatEncoding } from "./PackedSplats"; -import { - type PcSogsJson, - type PcSogsV2Json, - tryPcSogsZip, -} from "./SplatLoader"; -import { - computeMaxSplats, - encodeSh1Rgb, - encodeSh2Rgb, - encodeSh3Rgb, - setPackedSplatCenter, - setPackedSplatQuat, - setPackedSplatRgba, - setPackedSplatScales, -} from "./utils"; - -export async function unpackPcSogs( +import { tryPcSogsZip } from "../SplatLoader"; +import { NUM_COEFF_TO_SH_DEGREE, SH_C0 } from "../defines"; +import type { SplatEncoder, UnpackResult } from "../encoding/encoder"; + +export type PcSogsJson = { + means: { + shape: number[]; + dtype: string; + mins: number[]; + maxs: number[]; + files: string[]; + }; + scales: { + shape: number[]; + dtype: string; + mins: number[]; + maxs: number[]; + files: string[]; + }; + quats: { shape: number[]; dtype: string; encoding?: string; files: string[] }; + sh0: { + shape: number[]; + dtype: string; + mins: number[]; + maxs: number[]; + files: string[]; + }; + shN?: { + shape: number[]; + dtype: string; + mins: number; + maxs: number; + quantization: number; + files: string[]; + }; +}; + +export type PcSogsV2Json = { + version: 2; + count: number; + antialias?: boolean; + means: { + mins: number[]; + maxs: number[]; + files: string[]; + }; + scales: { + codebook: number[]; + files: string[]; + }; + quats: { files: string[] }; + sh0: { + codebook: number[]; + files: string[]; + }; + shN?: { + count: number; + bands: number; + codebook: number[]; + files: string[]; + }; +}; + +export async function unpackPcSogs( json: PcSogsJson | PcSogsV2Json, extraFiles: Record, - splatEncoding: SplatEncoding, -): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { + splatEncoder: SplatEncoder, +): Promise> { const isVersion2 = "version" in json; if (!isVersion2 && json.quats.encoding !== "quaternion_packed") { @@ -32,9 +74,11 @@ export async function unpackPcSogs( } const numSplats = isVersion2 ? json.count : json.means.shape[0]; - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; + const numShBands = + (isVersion2 + ? json.shN?.bands + : NUM_COEFF_TO_SH_DEGREE[json.shN?.shape[1] ?? 0]) ?? 0; + splatEncoder.allocate(numSplats, numShBands); const meansPromise = Promise.all([ decodeImageRgba(extraFiles[json.means.files[0]]), @@ -54,7 +98,7 @@ export async function unpackPcSogs( x = Math.sign(x) * (Math.exp(Math.abs(x)) - 1); y = Math.sign(y) * (Math.exp(Math.abs(y)) - 1); z = Math.sign(z) * (Math.exp(Math.abs(z)) - 1); - setPackedSplatCenter(packedArray, i, x, y, z); + splatEncoder.setSplatCenter(i, x, y, z); } }); @@ -98,13 +142,11 @@ export async function unpackPcSogs( for (let i = 0; i < numSplats; ++i) { const i4 = i * 4; - setPackedSplatScales( - packedArray, + splatEncoder.setSplatScales( i, xLookup[scales[i4 + 0]], yLookup[scales[i4 + 1]], zLookup[scales[i4 + 2]], - splatEncoding, ); } }, @@ -128,13 +170,12 @@ export async function unpackPcSogs( const quatY = rOrder <= 1 ? r1 : rOrder === 2 ? rr : r2; const quatZ = rOrder <= 2 ? r2 : rr; const quatW = rOrder === 0 ? rr : r0; - setPackedSplatQuat(packedArray, i, quatX, quatY, quatZ, quatW); + splatEncoder.setSplatQuat(i, quatX, quatY, quatZ, quatW); } }, ); const sh0Promise = decodeImageRgba(extraFiles[json.sh0.files[0]]).then( (sh0) => { - const SH_C0 = 0.28209479177387814; let rLookup: number[]; let gLookup: number[]; let bLookup: number[]; @@ -183,14 +224,12 @@ export async function unpackPcSogs( for (let i = 0; i < numSplats; ++i) { const i4 = i * 4; - setPackedSplatRgba( - packedArray, + splatEncoder.setSplatRgba( i, rLookup[sh0[i4 + 0]], gLookup[sh0[i4 + 1]], bLookup[sh0[i4 + 2]], aLookup[sh0[i4 + 3]], - splatEncoding, ); } }, @@ -198,23 +237,8 @@ export async function unpackPcSogs( const promises = [meansPromise, scalesPromise, quatsPromise, sh0Promise]; if (json.shN) { - const useSH3 = isVersion2 - ? json.shN.bands >= 3 - : json.shN.shape[1] >= 48 - 3; - const useSH2 = isVersion2 - ? json.shN.bands >= 2 - : json.shN.shape[1] >= 27 - 3; - const useSH1 = isVersion2 - ? json.shN.bands >= 1 - : json.shN.shape[1] >= 12 - 3; - - if (useSH1) extra.sh1 = new Uint32Array(numSplats * 2); - if (useSH2) extra.sh2 = new Uint32Array(numSplats * 4); - if (useSH3) extra.sh3 = new Uint32Array(numSplats * 4); - - const sh1 = new Float32Array(9); - const sh2 = new Float32Array(15); - const sh3 = new Float32Array(21); + const numCoefficients = [3, 8, 15][numShBands as 1 | 2 | 3]; + const sh = new Float32Array(numCoefficients * 3); const shN = json.shN; const shNPromise = Promise.all([ @@ -235,32 +259,13 @@ export async function unpackPcSogs( const row = label >>> 6; const offset = row * centroids.width + col; - for (let d = 0; d < 3; ++d) { - if (useSH1) { - for (let k = 0; k < 3; ++k) { - sh1[k * 3 + d] = lookup[centroids.rgba[(offset + k) * 4 + d]]; - } - } - - if (useSH2) { - for (let k = 0; k < 5; ++k) { - sh2[k * 3 + d] = lookup[centroids.rgba[(offset + 3 + k) * 4 + d]]; - } - } - - if (useSH3) { - for (let k = 0; k < 7; ++k) { - sh3[k * 3 + d] = lookup[centroids.rgba[(offset + 8 + k) * 4 + d]]; - } + for (let k = 0; k < numCoefficients; ++k) { + for (let d = 0; d < 3; ++d) { + sh[k * 3 + d] = lookup[centroids.rgba[(offset + k) * 4 + d]]; } } - if (useSH1) - encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding); - if (useSH2) - encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding); - if (useSH3) - encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding); + splatEncoder.setSplatSh(i, sh); } }); promises.push(shNPromise); @@ -268,7 +273,7 @@ export async function unpackPcSogs( await Promise.all(promises); - return { packedArray, numSplats, extra }; + return { unpacked: splatEncoder.closeTransferable(), numSplats }; } // WebGL context for reading raw pixel data of WebP images @@ -328,14 +333,10 @@ async function decodeImageRgba(fileBytes: ArrayBuffer) { return rgba; } -export async function unpackPcSogsZip( +export async function unpackPcSogsZip( fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { + splatEncoder: SplatEncoder, +): Promise> { const nameJson = tryPcSogsZip(fileBytes); if (!nameJson) { throw new Error("Invalid PC SOGS zip file"); @@ -358,14 +359,12 @@ export async function unpackPcSogsZip( fileMap.set(prefix + file, file); } - const unzipped = await new Promise>( + const unzipped = await new Promise>( (resolve, reject) => { unzip( fileBytes, { - filter: ({ name }) => { - return fileMap.has(name); - }, + filter: ({ name }) => fileMap.has(name), }, (err, files) => { if (err) { @@ -380,8 +379,8 @@ export async function unpackPcSogsZip( const extraFiles: Record = {}; for (const [full, name] of fileMap.entries()) { - extraFiles[name] = unzipped[full]; + extraFiles[name] = unzipped[full].buffer as ArrayBuffer; } - return await unpackPcSogs(json, extraFiles, splatEncoding); + return await unpackPcSogs(json, extraFiles, splatEncoder); } diff --git a/src/ply.ts b/src/formats/ply.ts similarity index 91% rename from src/ply.ts rename to src/formats/ply.ts index 1946888..cfedd65 100644 --- a/src/ply.ts +++ b/src/formats/ply.ts @@ -1,6 +1,12 @@ // PLY file format reader -import { USE_COMPILED_PARSER_FUNCTION } from "./defines"; +import { + NUM_COEFF_TO_SH_DEGREE, + SH_C0, + SH_DEGREE_TO_NUM_COEFF, + USE_COMPILED_PARSER_FUNCTION, +} from "../defines"; +import type { SplatEncoder, UnpackResult } from "../encoding/encoder"; const PLY_PROPERTY_TYPES = [ "char", @@ -26,32 +32,19 @@ export type PlyProperty = { countType?: PlyPropertyType; }; -// Callback for parseSplats base Gsplat data -export type SplatCallback = ( - index: number, - x: number, - y: number, - z: number, - scaleX: number, - scaleY: number, - scaleZ: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - opacity: number, - r: number, - g: number, - b: number, -) => void; +export async function unpackPly( + fileBytes: Uint8Array, + splatEncoder: SplatEncoder, +): Promise> { + const ply = new PlyReader({ fileBytes }); + await ply.parseHeader(); + const numSplats = ply.numSplats; + splatEncoder.allocate(numSplats, ply.numShBands); -// Callback for parseSplats SH coefficients -export type SplatShCallback = ( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, -) => void; + ply.parseSplats(splatEncoder); + + return { unpacked: splatEncoder.closeTransferable(), numSplats }; +} // A PlyReader is used to parse PLY files for Gsplat data. // It takes a Uint8Array/ArrayBuffer as input fileBytes, parses the text header, @@ -68,6 +61,7 @@ export class PlyReader { static defaultPointScale = 0.001; numSplats = 0; + numShBands = 0; // Create a PlyReader from a Uint8Array/ArrayBuffer, no parsing done yet constructor({ fileBytes }: { fileBytes: Uint8Array | ArrayBuffer }) { @@ -80,7 +74,9 @@ export class PlyReader { // "vertex" contains the Gsplat data. async parseHeader() { const bufferStream = new ReadableStream({ - start: (controller: ReadableStreamController) => { + start: ( + controller: ReadableStreamController>, + ) => { // Assume the header is less than 64KB controller.enqueue(this.fileBytes.slice(0, 65536)); controller.close(); @@ -183,6 +179,10 @@ export class PlyReader { if (this.elements.vertex) { this.numSplats = this.elements.vertex.count; + this.numShBands = getNumSh(this.elements.vertex.properties); + } + if (this.elements.sh) { + this.numShBands = getNumSh(this.elements.sh.properties); } } @@ -218,7 +218,7 @@ export class PlyReader { // Parse all the Gsplat data in the PLY file in go, invoking the given // callbacks for each Gsplat. - parseSplats(splatCallback: SplatCallback, shCallback?: SplatShCallback) { + parseSplats(splatEncoder: SplatEncoder) { if (this.elements.vertex == null) { throw new Error("No vertex element found"); } @@ -227,32 +227,27 @@ export class PlyReader { const ssChunks: SSChunk[] = []; let numSh = 0; - let sh1Props: number[] = []; - let sh2Props: number[] = []; - let sh3Props: number[] = []; - let sh1: Float32Array | undefined = undefined; - let sh2: Float32Array | undefined = undefined; - let sh3: Float32Array | undefined = undefined; + let shProps: number[] = []; + let sh: Float32Array | undefined = undefined; function prepareSh() { // Prepare SH coefficient names and arrays for numSh total SH levels - const num_f_rest = NUM_SH_TO_NUM_F_REST[numSh]; - sh1Props = new Array(3) + const num_f_rest = SH_DEGREE_TO_NUM_COEFF[numSh]; + const sh1Props = new Array(3) .fill(null) .flatMap((_, k) => [0, 1, 2].map((_, d) => k + (d * num_f_rest) / 3)); - sh2Props = new Array(5) + const sh2Props = new Array(5) .fill(null) .flatMap((_, k) => [0, 1, 2].map((_, d) => 3 + k + (d * num_f_rest) / 3), ); - sh3Props = new Array(7) + const sh3Props = new Array(7) .fill(null) .flatMap((_, k) => [0, 1, 2].map((_, d) => 8 + k + (d * num_f_rest) / 3), ); - sh1 = numSh >= 1 ? new Float32Array(3 * 3) : undefined; - sh2 = numSh >= 2 ? new Float32Array(5 * 3) : undefined; - sh3 = numSh >= 3 ? new Float32Array(7 * 3) : undefined; + shProps = [...sh1Props, ...sh2Props, ...sh3Props]; + sh = numSh >= 1 ? new Float32Array(num_f_rest) : undefined; } function ssShCallback( @@ -260,25 +255,14 @@ export class PlyReader { item: Record, ) { // Decode SH for SuperSplat compressed data - if (!sh1) { - throw new Error("Missing sh1"); + if (!sh) { + throw new Error("Missing sh"); } - const sh = item.f_rest as number[]; - - for (let i = 0; i < sh1Props.length; i++) { - sh1[i] = (sh[sh1Props[i]] * 8) / 255 - 4; - } - if (sh2) { - for (let i = 0; i < sh2Props.length; i++) { - sh2[i] = (sh[sh2Props[i]] * 8) / 255 - 4; - } - } - if (sh3) { - for (let i = 0; i < sh3Props.length; i++) { - sh3[i] = (sh[sh3Props[i]] * 8) / 255 - 4; - } + const fRest = item.f_rest as number[]; + for (let i = 0; i < fRest.length; i++) { + sh[i] = (fRest[shProps[i]] * 8) / 255 - 4; } - shCallback?.(index, sh1, sh2, sh3); + splatEncoder.setSplatSh(index, sh); } function initSuperSplat(element: PlyElement) { @@ -361,7 +345,7 @@ export class PlyReader { function decodeSuperSplat(element: PlyElement) { // Decode SuperSplat compressed data in vertex and sh elements - if (shCallback && element.name === "sh") { + if (element.name === "sh") { numSh = getNumSh(element.properties); prepareSh(); return ssShCallback; @@ -459,7 +443,7 @@ export class PlyReader { (min_b ?? 0); const opacity = (packed_color & 255) / 255; - splatCallback( + splatEncoder.setSplat( index, x, y, @@ -573,7 +557,7 @@ export class PlyReader { ? (item.blue as number) / blueDiv : 1.0; - splatCallback( + splatEncoder.setSplat( index, item.x as number, item.y as number, @@ -591,24 +575,12 @@ export class PlyReader { b, ); - if (shCallback && sh1) { - const sh = item.f_rest as number[]; - if (sh1) { - for (let i = 0; i < sh1Props.length; i++) { - sh1[i] = sh[sh1Props[i]]; - } - } - if (sh2) { - for (let i = 0; i < sh2Props.length; i++) { - sh2[i] = sh[sh2Props[i]]; - } - } - if (sh3) { - for (let i = 0; i < sh3Props.length; i++) { - sh3[i] = sh[sh3Props[i]]; - } + if (sh) { + const fRest = item.f_rest as number[]; + for (let i = 0; i < fRest.length; i++) { + sh[i] = fRest[shProps[i]]; } - shCallback(index, sh1, sh2, sh3); + splatEncoder.setSplatSh(index, sh); } }; }; @@ -714,8 +686,6 @@ export class PlyReader { } } -export const SH_C0 = 0.28209479177387814; - type FieldParser = ( data: DataView, offset: number, @@ -844,19 +814,6 @@ const FIELD_SCALE: Record = { double: 1, }; -const NUM_F_REST_TO_NUM_SH: Record = { - 0: 0, - 9: 1, - 24: 2, - 45: 3, -}; -const NUM_SH_TO_NUM_F_REST: Record = { - 0: 0, - 1: 9, - 2: 24, - 3: 45, -}; - const F_REST_REGEX = /^f_rest_([0-9]{1,2})$/; function createEmptyItem( @@ -1060,7 +1017,7 @@ function getNumSh(properties: Record) { while (properties[`f_rest_${num_f_rest}`]) { num_f_rest += 1; } - const numSh = NUM_F_REST_TO_NUM_SH[num_f_rest]; + const numSh = NUM_COEFF_TO_SH_DEGREE[num_f_rest]; if (numSh == null) { throw new Error(`Unsupported number of SH coefficients: ${num_f_rest}`); } diff --git a/src/formats/spz.ts b/src/formats/spz.ts new file mode 100644 index 0000000..3e078f3 --- /dev/null +++ b/src/formats/spz.ts @@ -0,0 +1,433 @@ +import * as THREE from "three"; +import { SH_C0, SH_DEGREE_TO_NUM_COEFF } from "../defines"; +import type { SplatEncoder, UnpackResult } from "../encoding/encoder"; +import { GunzipReader, fromHalf } from "../utils"; + +export async function unpackSpz( + fileBytes: Uint8Array, + splatEncoder: SplatEncoder, +): Promise> { + const spz = new SpzReader({ fileBytes }); + await spz.parseHeader(); + const numSplats = spz.numSplats; + splatEncoder.allocate(numSplats, spz.shDegree); + + await spz.parseSplats(splatEncoder); + + return { unpacked: splatEncoder.closeTransferable(), numSplats }; +} + +// SPZ file format reader + +export class SpzReader { + private fileBytes: Uint8Array; + private reader: GunzipReader; + + version = -1; + numSplats = 0; + shDegree = 0; + fractionalBits = 0; + flags = 0; + flagAntiAlias = false; + reserved = 0; + private headerParsed = false; + private parsed = false; + + constructor({ + fileBytes, + }: { fileBytes: Uint8Array | ArrayBuffer }) { + this.fileBytes = + fileBytes instanceof ArrayBuffer ? new Uint8Array(fileBytes) : fileBytes; + this.reader = new GunzipReader(this.fileBytes); + } + + async parseHeader() { + if (this.headerParsed) { + throw new Error("SPZ file header already parsed"); + } + + const header = new DataView((await this.reader.read(16)).buffer); + if (header.getUint32(0, true) !== 0x5053474e) { + throw new Error("Invalid SPZ file"); + } + this.version = header.getUint32(4, true); + if (this.version < 1 || this.version > 3) { + throw new Error(`Unsupported SPZ version: ${this.version}`); + } + + this.numSplats = header.getUint32(8, true); + this.shDegree = header.getUint8(12); + this.fractionalBits = header.getUint8(13); + this.flags = header.getUint8(14); + this.flagAntiAlias = (this.flags & 0x01) !== 0; + this.reserved = header.getUint8(15); + this.headerParsed = true; + this.parsed = false; + } + + async parseSplats(splatEncoder: SplatEncoder) { + if (!this.headerParsed) { + throw new Error("SPZ file header must be parsed first"); + } + if (this.parsed) { + throw new Error("SPZ file already parsed"); + } + this.parsed = true; + + if (this.version === 1) { + // float16 centers + const centerBytes = await this.reader.read(this.numSplats * 3 * 2); + const centerUint16 = new Uint16Array(centerBytes.buffer); + for (let i = 0; i < this.numSplats; i++) { + const i3 = i * 3; + const x = fromHalf(centerUint16[i3]); + const y = fromHalf(centerUint16[i3 + 1]); + const z = fromHalf(centerUint16[i3 + 2]); + splatEncoder.setSplatCenter(i, x, y, z); + } + } else if (this.version === 2 || this.version === 3) { + // 24-bit fixed-point centers + const fixed = 1 << this.fractionalBits; + const centerBytes = await this.reader.read(this.numSplats * 3 * 3); + for (let i = 0; i < this.numSplats; i++) { + const i9 = i * 9; + const x = + (((centerBytes[i9 + 2] << 24) | + (centerBytes[i9 + 1] << 16) | + (centerBytes[i9] << 8)) >> + 8) / + fixed; + const y = + (((centerBytes[i9 + 5] << 24) | + (centerBytes[i9 + 4] << 16) | + (centerBytes[i9 + 3] << 8)) >> + 8) / + fixed; + const z = + (((centerBytes[i9 + 8] << 24) | + (centerBytes[i9 + 7] << 16) | + (centerBytes[i9 + 6] << 8)) >> + 8) / + fixed; + splatEncoder.setSplatCenter(i, x, y, z); + } + } else { + throw new Error("Unreachable"); + } + + { + const bytes = await this.reader.read(this.numSplats); + for (let i = 0; i < this.numSplats; i++) { + splatEncoder.setSplatAlpha(i, bytes[i] / 255); + } + } + { + const rgbBytes = await this.reader.read(this.numSplats * 3); + const scale = SH_C0 / 0.15; + for (let i = 0; i < this.numSplats; i++) { + const i3 = i * 3; + const r = (rgbBytes[i3] / 255 - 0.5) * scale + 0.5; + const g = (rgbBytes[i3 + 1] / 255 - 0.5) * scale + 0.5; + const b = (rgbBytes[i3 + 2] / 255 - 0.5) * scale + 0.5; + splatEncoder.setSplatRgb(i, r, g, b); + } + } + { + const scalesBytes = await this.reader.read(this.numSplats * 3); + for (let i = 0; i < this.numSplats; i++) { + const i3 = i * 3; + const scaleX = Math.exp(scalesBytes[i3] / 16 - 10); + const scaleY = Math.exp(scalesBytes[i3 + 1] / 16 - 10); + const scaleZ = Math.exp(scalesBytes[i3 + 2] / 16 - 10); + splatEncoder.setSplatScales(i, scaleX, scaleY, scaleZ); + } + } + if (this.version === 3) { + // Version 3 uses a trick called "smallest three" to compress the rotation quaternions + // achieving better precision. "Optimizing orientation" section at https://gafferongames.com/post/snapshot_compression/ A quaternion length must be 1: x^2+y^2+z^2+w^2 = 1 + // We can drop one component and reconstruct it with the identity above. + // Largest component is dropped for best numerical precision. + // Quaternion stored in 32 bits + // 10 bits singed integer for each of the 3 components + 2 bits indicating the index of dropped component. + // vs 8 bits for each component uncompressed (spz version < 3) + // Max Value after extracting largest component v is another component v + // (v,v,0,0) + // v^2 + v^2 = 1 + // v = 1 / sqrt(2); + const maxValue = 1 / Math.sqrt(2); // 0.7071 + const quatBytes = await this.reader.read(this.numSplats * 4); + for (let i = 0; i < this.numSplats; i++) { + const i3 = i * 4; + const quaternion = [0, 0, 0, 0]; + const values = [ + quatBytes[i3], + quatBytes[i3 + 1], + quatBytes[i3 + 2], + quatBytes[i3 + 3], + ]; + // all values are packed in 32 bits (10 per each of 3 components + 2 bits of index of larged value) + const combinedValues = + values[0] + (values[1] << 8) + (values[2] << 16) + (values[3] << 24); + // each component value is 9 bits + sign (1 bit) + const valueMask = (1 << 9) - 1; + // extract index of the largest element. 2 top bits. + const largestIndex = combinedValues >>> 30; + let remainingValues = combinedValues; + let sumSquares = 0; + + for (let i = 3; i >= 0; --i) { + if (i !== largestIndex) { + // extract current value and sign. + const value = remainingValues & valueMask; + const sign = (remainingValues >>> 9) & 0x1; + // each value is represented as 10 bits. Shift to next one. + remainingValues = remainingValues >>> 10; + // convert to range [0,1] and then to [0, 0.7071] + quaternion[i] = maxValue * (value / valueMask); + // apply sign. + quaternion[i] = sign === 0 ? quaternion[i] : -quaternion[i]; + // accumulate the sum of squares + sumSquares += quaternion[i] * quaternion[i]; + } + } + + // quartenion length must be 1 (x^2+y^2+z^2+w^2 = 1) + // so can reconstruct largest component from the other 3. + // w = sqrt(1 - x^2 - y^2 - z^2); + const square = 1 - sumSquares; + quaternion[largestIndex] = Math.sqrt(Math.max(square, 0)); + + splatEncoder.setSplatQuat( + i, + quaternion[0], + quaternion[1], + quaternion[2], + quaternion[3], + ); + } + } else { + const quatBytes = await this.reader.read(this.numSplats * 3); + for (let i = 0; i < this.numSplats; i++) { + const i3 = i * 3; + const quatX = quatBytes[i3] / 127.5 - 1; + const quatY = quatBytes[i3 + 1] / 127.5 - 1; + const quatZ = quatBytes[i3 + 2] / 127.5 - 1; + const quatW = Math.sqrt( + Math.max(0, 1 - quatX * quatX - quatY * quatY - quatZ * quatZ), + ); + splatEncoder.setSplatQuat(i, quatX, quatY, quatZ, quatW); + } + } + + if (this.shDegree >= 1) { + const shCoefficients = SH_DEGREE_TO_NUM_COEFF[this.shDegree]; + const sh = new Float32Array(shCoefficients); + const shBytes = await this.reader.read(this.numSplats * shCoefficients); + + for (let i = 0; i < this.numSplats; i++) { + for (let j = 0; j < shCoefficients; ++j) { + sh[j] = (shBytes[i * shCoefficients + j] - 128) / 128; + } + splatEncoder.setSplatSh(i, sh); + } + } + } +} + +// SPZ file format writer + +export const SPZ_MAGIC = 0x5053474e; // NGSP = Niantic gaussian splat +export const SPZ_VERSION = 3; +export const FLAG_ANTIALIASED = 0x1; + +export class SpzWriter { + private buffer: ArrayBuffer; + private view: DataView; + private numSplats: number; + readonly shDegree: number; + private fractionalBits: number; + private fraction: number; + private flagAntiAlias: boolean; + clippedCount = 0; + + constructor({ + numSplats, + shDegree, + fractionalBits = 12, + flagAntiAlias = true, + }: { + numSplats: number; + shDegree: number; + fractionalBits?: number; + flagAntiAlias?: boolean; + }) { + const splatSize = + 9 + // Position + 1 + // Opacity + 3 + // Scale + 3 + // DC-rgb + 4 + // Rotation + SH_DEGREE_TO_NUM_COEFF[shDegree]; + const bufferSize = 16 + numSplats * splatSize; + this.buffer = new ArrayBuffer(bufferSize); + this.view = new DataView(this.buffer); + + this.view.setUint32(0, SPZ_MAGIC, true); // NGSP + this.view.setUint32(4, SPZ_VERSION, true); + this.view.setUint32(8, numSplats, true); + this.view.setUint8(12, shDegree); + this.view.setUint8(13, fractionalBits); + this.view.setUint8(14, flagAntiAlias ? FLAG_ANTIALIASED : 0); + this.view.setUint8(15, 0); // Reserved + + this.numSplats = numSplats; + this.shDegree = shDegree; + this.fractionalBits = fractionalBits; + this.fraction = 1 << fractionalBits; + this.flagAntiAlias = flagAntiAlias; + } + + setCenter(index: number, x: number, y: number, z: number) { + // Divide by this.fraction and round to nearest integer, + // then write as 3-bytes per x then y then z. + const xRounded = Math.round(x * this.fraction); + const xInt = Math.max(-0x7fffff, Math.min(0x7fffff, xRounded)); + const yRounded = Math.round(y * this.fraction); + const yInt = Math.max(-0x7fffff, Math.min(0x7fffff, yRounded)); + const zRounded = Math.round(z * this.fraction); + const zInt = Math.max(-0x7fffff, Math.min(0x7fffff, zRounded)); + const clipped = xRounded !== xInt || yRounded !== yInt || zRounded !== zInt; + if (clipped) { + this.clippedCount += 1; + } + const i9 = index * 9; + const base = 16 + i9; + this.view.setUint8(base, xInt & 0xff); + this.view.setUint8(base + 1, (xInt >> 8) & 0xff); + this.view.setUint8(base + 2, (xInt >> 16) & 0xff); + this.view.setUint8(base + 3, yInt & 0xff); + this.view.setUint8(base + 4, (yInt >> 8) & 0xff); + this.view.setUint8(base + 5, (yInt >> 16) & 0xff); + this.view.setUint8(base + 6, zInt & 0xff); + this.view.setUint8(base + 7, (zInt >> 8) & 0xff); + this.view.setUint8(base + 8, (zInt >> 16) & 0xff); + } + + setAlpha(index: number, alpha: number) { + const base = 16 + this.numSplats * 9 + index; + this.view.setUint8( + base, + Math.max(0, Math.min(255, Math.round(alpha * 255))), + ); + } + + static scaleRgb(r: number) { + const v = ((r - 0.5) / (SH_C0 / 0.15) + 0.5) * 255; + return Math.max(0, Math.min(255, Math.round(v))); + } + + setRgb(index: number, r: number, g: number, b: number) { + const base = 16 + this.numSplats * 10 + index * 3; + this.view.setUint8(base, SpzWriter.scaleRgb(r)); + this.view.setUint8(base + 1, SpzWriter.scaleRgb(g)); + this.view.setUint8(base + 2, SpzWriter.scaleRgb(b)); + } + + setScale(index: number, scaleX: number, scaleY: number, scaleZ: number) { + const base = 16 + this.numSplats * 13 + index * 3; + this.view.setUint8( + base, + Math.max(0, Math.min(255, Math.round((Math.log(scaleX) + 10) * 16))), + ); + this.view.setUint8( + base + 1, + Math.max(0, Math.min(255, Math.round((Math.log(scaleY) + 10) * 16))), + ); + this.view.setUint8( + base + 2, + Math.max(0, Math.min(255, Math.round((Math.log(scaleZ) + 10) * 16))), + ); + } + + setQuat( + index: number, + ...q: [number, number, number, number] // x, y, z, w + ) { + const base = 16 + this.numSplats * 16 + index * 4; + + const quat = normalize(q); + + // Find largest component + let iLargest = 0; + for (let i = 1; i < 4; ++i) { + if (Math.abs(quat[i]) > Math.abs(quat[iLargest])) { + iLargest = i; + } + } + + // Since -quat represents the same rotation as quat, transform the quaternion so the largest element + // is positive. This avoids having to send its sign bit. + const negate = quat[iLargest] < 0 ? 1 : 0; + + // Do compression using sign bit and 9-bit precision per element. + let comp = iLargest; + for (let i = 0; i < 4; ++i) { + if (i !== iLargest) { + const negbit = (quat[i] < 0 ? 1 : 0) ^ negate; + const mag = Math.floor( + ((1 << 9) - 1) * (Math.abs(quat[i]) / Math.SQRT1_2) + 0.5, + ); + comp = (comp << 10) | (negbit << 9) | mag; + } + } + + this.view.setUint8(base, comp & 0xff); + this.view.setUint8(base + 1, (comp >> 8) & 0xff); + this.view.setUint8(base + 2, (comp >> 16) & 0xff); + this.view.setUint8(base + 3, (comp >>> 24) & 0xff); + } + + static quantizeSh(sh: number, bits: number) { + const value = Math.round(sh * 128) + 128; + const bucketSize = 1 << (8 - bits); + const quantized = + Math.floor((value + bucketSize / 2) / bucketSize) * bucketSize; + return Math.max(0, Math.min(255, quantized)); + } + + setSh(index: number, sh: ArrayLike) { + const base = + 16 + this.numSplats * 20 + index * SH_DEGREE_TO_NUM_COEFF[this.shDegree]; + for (let i = 0; i < SH_DEGREE_TO_NUM_COEFF[this.shDegree]; ++i) { + this.view.setUint8(base + i, SpzWriter.quantizeSh(sh[i], i >= 9 ? 4 : 5)); + } + } + + async finalize(): Promise { + const input = new Uint8Array(this.buffer); + const stream = new ReadableStream({ + async start(controller) { + controller.enqueue(input); + controller.close(); + }, + }); + const compressed = stream.pipeThrough(new CompressionStream("gzip")); + const response = new Response(compressed); + const buffer = await response.arrayBuffer(); + console.log( + "Compressed", + input.length, + "bytes to", + buffer.byteLength, + "bytes", + ); + return new Uint8Array(buffer); + } +} + +const tempQuat = new THREE.Quaternion(); +function normalize( + quat: [number, number, number, number], +): [number, number, number, number] { + return tempQuat.fromArray(quat).normalize().toArray(quat); +} diff --git a/src/generators.ts b/src/generators.ts deleted file mode 100644 index 97318f2..0000000 --- a/src/generators.ts +++ /dev/null @@ -1,2 +0,0 @@ -export * from "./generators/static"; -export * from "./generators/snow"; diff --git a/src/generators/snow.ts b/src/generators/snow.ts deleted file mode 100644 index 9174588..0000000 --- a/src/generators/snow.ts +++ /dev/null @@ -1,276 +0,0 @@ -import * as THREE from "three"; - -import { SplatGenerator, SplatTransformer } from "../SplatGenerator"; -import { - Gsplat, - add, - combine, - combineGsplat, - defineGsplat, - dynoBlock, - dynoConst, - dynoFloat, - dynoLiteral, - fract, - hashVec4, - max, - mix, - mod, - mul, - sin, - split, - sub, - vec3, -} from "../dyno"; -import { dynoVec3 } from "../dyno"; - -// snowBox produces Gsplat trajectories that move in a deterministic fashion over time, -// with high similarity between adjacent frames. See examples/atmospheric/main.js -// for an example that creates a snowBox. - -// A snowBox instance has a collection of properties that can be tuned to achieve -// different particle effects. The below DEFAULT_SNOW and DEFAULT_RAIN are example -// parameter sets that look a lot like snow and rain, and can be used as a starting -// point for further tweaking: `const mySnow = { ...DEFAULT_SNOW, density: 500 };` - -export const DEFAULT_SNOW = { - box: new THREE.Box3( - new THREE.Vector3(-1, -1, -1), - new THREE.Vector3(1, 1, 1), - ), - density: 100, - fallDirection: new THREE.Vector3(-1, -3, 1).normalize(), - fallVelocity: 0.02, - wanderScale: 0.04, - wanderVariance: 2, - color1: new THREE.Color(1, 1, 1), - color2: new THREE.Color(0.5, 0.5, 1), - minScale: 0.001, - maxScale: 0.005, - anisoScale: new THREE.Vector3(1, 1, 1), -}; - -export const DEFAULT_RAIN = { - box: new THREE.Box3( - new THREE.Vector3(-2, -1, -2), - new THREE.Vector3(2, 5, 2), - ), - density: 10, - fallDirection: new THREE.Vector3(0, -1, 0), - fallVelocity: 2, - wanderScale: 0.1, - wanderVariance: 1, - color1: new THREE.Color(1, 1, 1), - color2: new THREE.Color(0.25, 0.25, 0.5), - minScale: 0.005, - maxScale: 0.01, - anisoScale: new THREE.Vector3(0.1, 1, 0.1), -}; - -// Calling snowBox creates a new snowBox instance and returns an object with -// the snowBox itself as well `as` a collection of controls that can be used to -// adjust the snowBox's properties over time: -// -// - snow: the SplatGenerator snowBox instance -// - min: the vec3 uniform of the snowBox minimum position -// - max: the vec3 uniform of the snowBox maximum position -// - minY: the float uniform of the snowBox minimum y-coordinate -// - color1: the vec3 uniform of the snowBox first color -// - color2: the vec3 uniform of the snowBox second color -// - opacity: the float uniform of the snowBox opacity -// - fallVelocity: the float uniform of the snowBox fall velocity -// - wanderVariance: the float uniform of the snowBox wander variance -// - wanderScale: the float uniform of the snowBox wander scale -// - fallDirection: the vec3 uniform of the snowBox fall direction -// - minScale: the float uniform of the snowBox minimum scale -// - maxScale: the float uniform of the snowBox maximum scale -// - anisoScale: the vec3 uniform of the snowBox anisotropic scale - -export function snowBox({ - // min and max box extents of the snowBox - box, - // minimum y-coordinate to clamp particle position, which can be used to - // fake hitting a ground plane and lingering there for a bit - minY, - // number of Gsplats to generate (default: calculated from box and density) - numSplats, - // density of Gsplats per unit volume (default: 100) - density, - // The xyz anisotropic scale of the Gsplat, which can be used for example - // to elongate rain particles (default: (1, 1, 1)) - anisoScale, - // Minimum Gsplat particle scale (default: 0.001) - minScale, - // Maximum Gsplat particle scale (default: 0.005) - maxScale, - // The average direction of fall (default: (0, -1, 0)) - fallDirection, - // The average speed of the fall (multiplied with fallDirection) (default: 0.02) - fallVelocity, - // The world scale of wandering overlay motion (default: 0.01) - wanderScale, - // Controls how uniformly the particles wander in sync, more variance mean - // more randomness in the motion (default: 2) - wanderVariance, - // Color 1 of the two colors interpolated between (default: (1, 1, 1)) - color1, - // Color 2 of the two colors interpolated between (default: (0.5, 0.5, 1)) - color2, - // The base opacity of the Gsplats (default: 1) - opacity, - // Optional callback function to call each frame. - onFrame, -}: { - box?: THREE.Box3; - minY?: number; - numSplats?: number; - density?: number; - anisoScale?: THREE.Vector3; - minScale?: number; - maxScale?: number; - fallDirection?: THREE.Vector3; - fallVelocity?: number; - wanderScale?: number; - wanderVariance?: number; - color1?: THREE.Color; - color2?: THREE.Color; - opacity?: number; - onFrame?: ({ - object, - time, - deltaTime, - }: { object: SplatGenerator; time: number; deltaTime: number }) => void; -}) { - box = - box ?? - new THREE.Box3(new THREE.Vector3(-1, -1, -1), new THREE.Vector3(1, 1, 1)); - const volume = - (box.max.x - box.min.x) * (box.max.y - box.min.y) * (box.max.z - box.min.z); - density = density ?? 100; - numSplats = - numSplats ?? Math.max(1, Math.min(1000000, Math.round(volume * density))); - - const dynoMinScale = dynoFloat(minScale ?? 0.001); - const dynoMaxScale = dynoFloat(maxScale ?? 0.005); - const dynoAnisoScale = dynoVec3( - (anisoScale?.clone() ?? new THREE.Vector3(1, 1, 1)).normalize(), - ); - const dynoFallDirection = dynoVec3( - (fallDirection ?? new THREE.Vector3(0, -1, 0)).normalize(), - ); - const dynoFallVelocity = dynoFloat(fallVelocity ?? 0.02); - const dynoWanderScale = dynoFloat(wanderScale ?? 0.01); - const dynoWanderVariance = dynoFloat(wanderVariance ?? 2); - const dynoColor1 = dynoVec3(color1 ?? new THREE.Color(1, 1, 1)); - const dynoColor2 = dynoVec3(color2 ?? new THREE.Color(0.5, 0.5, 1)); - const dynoOpacity = dynoFloat(opacity ?? 1); - - const dynoTime = dynoFloat(0); - const globalOffset = dynoVec3(new THREE.Vector3(0, 0, 0)); - const dynoMin = dynoVec3(box.min); - const dynoMax = dynoVec3(box.max); - const dynoMinY = dynoFloat(minY ?? Number.NEGATIVE_INFINITY); - const minMax = sub(dynoMax, dynoMin); - const snow = new SplatGenerator({ - numSplats, - generator: dynoBlock( - { index: "int" }, - { gsplat: Gsplat }, - ({ index }) => { - if (!index) { - throw new Error("index not defined"); - } - const random = hashVec4(index); - const randomW = split(random).outputs.w; - let position = vec3(random); - - let size = fract(mul(randomW, dynoConst("float", 100))); - size = sin(mul(dynoLiteral("float", "PI"), size)); - size = add(dynoMinScale, mul(size, sub(dynoMaxScale, dynoMinScale))); - const scales = mul(size, dynoAnisoScale); - - const intensity = fract(mul(randomW, dynoConst("float", 10))); - const hue = fract(randomW); - const color = mix(dynoColor1, dynoColor2, hue); - const rgb = mul(color, intensity); - - const random2 = hashVec4( - combine({ - vectorType: "ivec2", - x: index, - y: dynoConst("int", 0x1ab5), - }), - ); - let perturb = vec3(random2); - let timeOffset = mul(split(random2).outputs.w, dynoWanderVariance); - timeOffset = add(dynoTime, timeOffset); - - position = add(position, globalOffset); - const modulo = mod( - position, - dynoConst("vec3", new THREE.Vector3(1, 1, 1)), - ); - position = add(dynoMin, mul(minMax, modulo)); - - const quaternion = dynoConst("vec4", new THREE.Quaternion(0, 0, 0, 1)); - - perturb = sin(add(vec3(timeOffset), perturb)); - perturb = mul(perturb, dynoWanderScale); - let center = add(position, perturb); - - let centerY = split(center).outputs.y; - centerY = max(dynoMinY, centerY); - center = combine({ vector: center, y: centerY }); - - let gsplat = combineGsplat({ - flags: dynoLiteral("uint", "GSPLAT_FLAG_ACTIVE"), - index: index, - center, - scales, - quaternion, - rgb, - opacity: dynoOpacity, - }); - gsplat = transformer.applyGsplat(gsplat); - return { gsplat }; - }, - { - globals: () => [defineGsplat], - }, - ), - update: ({ object, time, deltaTime }) => { - dynoTime.value = time; - const _updated = transformer.update(snow); - - const fallDelta = dynoFallDirection.value - .clone() - .multiplyScalar(dynoFallVelocity.value * deltaTime); - globalOffset.value.add(fallDelta); - - // Enable/disable splats based on opacity - object.visible = dynoOpacity.value > 0; - - onFrame?.({ object, time, deltaTime }); - snow.updateVersion(); - }, - }); - const transformer: SplatTransformer = new SplatTransformer(); - return { - snow, - min: dynoMin, - max: dynoMax, - minY: dynoMinY, - color1: dynoColor1, - color2: dynoColor2, - opacity: dynoOpacity, - fallVelocity: dynoFallVelocity, - wanderVariance: dynoWanderVariance, - wanderScale: dynoWanderScale, - fallDirection: dynoFallDirection, - minScale: dynoMinScale, - maxScale: dynoMaxScale, - anisoScale: dynoAnisoScale, - }; -} - -export type SNOW_RESULT_TYPE = ReturnType; diff --git a/src/generators/static.ts b/src/generators/static.ts deleted file mode 100644 index cba3c62..0000000 --- a/src/generators/static.ts +++ /dev/null @@ -1,117 +0,0 @@ -import * as THREE from "three"; -import { SplatGenerator, SplatTransformer } from "../SplatGenerator"; -import { - type DynoVal, - Gsplat, - add, - combine, - combineGsplat, - defineGsplat, - div, - dynoBlock, - dynoConst, - dynoFloat, - dynoLiteral, - floatBitsToInt, - hashVec3, - imod, - mul, - split, - sub, - vec3, -} from "../dyno"; - -export function staticBox({ - box, - cells, - dotScale, - color, - opacity, -}: { - box: THREE.Box3; - cells: THREE.Vector3; - dotScale: number; - color?: THREE.Color; - opacity?: number; -}) { - cells.x = Math.max(1, Math.round(cells.x)); - cells.y = Math.max(1, Math.round(cells.y)); - cells.z = Math.max(1, Math.round(cells.z)); - opacity = opacity ?? 1; - const numSplats = cells.x * cells.y * cells.z; - const dynoX = dynoConst("int", cells.x); - const dynoY = dynoConst("int", cells.y); - const dynoZ = dynoConst("int", cells.z); - - const dynoTime = dynoFloat(0); - const generator = new SplatGenerator({ - numSplats, - generator: dynoBlock( - { index: "int" }, - { gsplat: Gsplat }, - ({ index }) => { - if (!index) { - throw new Error("index is undefined"); - } - const cellX = imod(index, dynoX); - const index2 = div(index, dynoX); - const cellY = imod(index2, dynoY); - const cellZ = div(index2, dynoY); - const cell = combine({ - vectorType: "ivec3", - x: cellX, - y: cellY, - z: cellZ, - }); - - const intTime = floatBitsToInt(dynoTime); - const inputs = combine({ vectorType: "ivec2", x: index, y: intTime }); - const random = hashVec3(inputs); - const min = dynoConst("vec3", box.min); - const max = dynoConst("vec3", box.max); - const size = sub(max, min); - const coord = div(add(vec3(cell), random), dynoConst("vec3", cells)); - let r: DynoVal<"float">; - let g: DynoVal<"float">; - let b: DynoVal<"float">; - if (color) { - r = dynoConst("float", color.r); - g = dynoConst("float", color.g); - b = dynoConst("float", color.b); - } else { - ({ r, g, b } = split(coord).outputs); - } - const rgba = combine({ - vectorType: "vec4", - r, - g, - b, - a: dynoConst("float", opacity), - }); - const center = add(min, mul(size, coord)); - const scales = vec3(dynoConst("float", dotScale)); - const quaternion = dynoConst("vec4", new THREE.Quaternion(0, 0, 0, 1)); - let gsplat = combineGsplat({ - flags: dynoLiteral("uint", "GSPLAT_FLAG_ACTIVE"), - index: index, - center, - scales, - quaternion, - rgba, - }); - gsplat = transformer.applyGsplat(gsplat); - return { gsplat }; - }, - { - globals: () => [defineGsplat], - }, - ), - update: ({ time }) => { - dynoTime.value = time; - const _updated = transformer.update(generator); - generator.updateVersion(); - }, - }); - const transformer: SplatTransformer = new SplatTransformer(); - return generator; -} diff --git a/src/hands.ts b/src/hands.ts deleted file mode 100644 index 38f2932..0000000 --- a/src/hands.ts +++ /dev/null @@ -1,472 +0,0 @@ -import { - Color, - Matrix4, - type Object3D, - Quaternion, - Vector3, - type WebXRManager, -} from "three"; -import { SplatMesh } from "./SplatMesh"; - -// Experimental WebXR hand tracking and movement - -const DEFAULT_MOVE_INERTIA = 0.5; -const DEFAULT_ROTATE_INERTIA = 0.5; -const TOUCH_BIAS = 0.0; - -export enum JointEnum { - w = "wrist", - t0 = "thumb-metacarpal", - t1 = "thumb-phalanx-proximal", - t2 = "thumb-phalanx-distal", - t3 = "thumb-tip", - i0 = "index-finger-metacarpal", - i1 = "index-finger-phalanx-proximal", - i2 = "index-finger-phalanx-intermediate", - i3 = "index-finger-phalanx-distal", - i4 = "index-finger-tip", - m0 = "middle-finger-metacarpal", - m1 = "middle-finger-phalanx-proximal", - m2 = "middle-finger-phalanx-intermediate", - m3 = "middle-finger-phalanx-distal", - m4 = "middle-finger-tip", - r0 = "ring-finger-metacarpal", - r1 = "ring-finger-phalanx-proximal", - r2 = "ring-finger-phalanx-intermediate", - r3 = "ring-finger-phalanx-distal", - r4 = "ring-finger-tip", - p0 = "pinky-finger-metacarpal", - p1 = "pinky-finger-phalanx-proximal", - p2 = "pinky-finger-phalanx-intermediate", - p3 = "pinky-finger-phalanx-distal", - p4 = "pinky-finger-tip", -} -export type JointId = keyof typeof JointEnum; -export const JOINT_IDS = Object.keys(JointEnum) as JointId[]; -export const NUM_JOINTS = JOINT_IDS.length; - -export const JOINT_INDEX: { [key in JointId]: number } = { - w: 0, - t0: 1, - t1: 2, - t2: 3, - t3: 4, - i0: 5, - i1: 6, - i2: 7, - i3: 8, - i4: 9, - m0: 10, - m1: 11, - m2: 12, - m3: 13, - m4: 14, - r0: 15, - r1: 16, - r2: 17, - r3: 18, - r4: 19, - p0: 20, - p1: 21, - p2: 22, - p3: 23, - p4: 24, -}; - -export const JOINT_RADIUS: { [key in JointId]: number } = { - w: 0.02, - t0: 0.02, - t1: 0.014, - t2: 0.0115, - t3: 0.0085, - i0: 0.022, - i1: 0.012, - i2: 0.0085, - i3: 0.0075, - i4: 0.0065, - m0: 0.021, - m1: 0.012, - m2: 0.008, - m3: 0.0075, - m4: 0.0065, - r0: 0.019, - r1: 0.011, - r2: 0.0075, - r3: 0.007, - r4: 0.006, - p0: 0.012, - p1: 0.01, - p2: 0.007, - p3: 0.0065, - p4: 0.0055, -}; - -export const JOINT_SEGMENTS: JointId[][] = [ - ["w", "t0", "t1", "t2", "t3"], - ["w", "i0", "i1", "i2", "i3", "i4"], - ["w", "m0", "m1", "m2", "m3", "m4"], - ["w", "r0", "r1", "r2", "r3", "r4"], - ["w", "p0", "p1", "p2", "p3", "p4"], -]; - -export const JOINT_SEGMENT_STEPS: number[][] = [ - [8, 10, 8, 6], - [8, 19, 14, 8, 6], - [8, 19, 14, 8, 6], - [8, 19, 14, 8, 6], - [8, 19, 14, 8, 6], -]; - -export const JOINT_TIPS: JointId[] = ["t3", "i4", "m4", "r4", "p4"]; -export const FINGER_TIPS: JointId[] = ["i4", "m4", "r4", "p4"]; - -export enum Hand { - left = "left", - right = "right", -} -export const HANDS = Object.keys(Hand) as Hand[]; - -export type Joint = { - position: Vector3; - quaternion: Quaternion; - radius: number; -}; - -export type HandJoints = { [key in JointId]?: Joint }; -export type HandsJoints = { [key in Hand]?: HandJoints }; - -export class XrHands { - hands: HandsJoints = {}; - last: HandsJoints = {}; - - values: Record = {}; - tests: Record = {}; - lastTests: Record = {}; - - updated = false; - - update({ xr, xrFrame }: { xr: WebXRManager; xrFrame: XRFrame }) { - const xrSession = xr.getSession(); - if (!xrSession) { - return; - } - const referenceSpace = xr.getReferenceSpace(); - if (!referenceSpace) { - return; - } - if (!xrFrame.getJointPose) { - return; - } - - this.last = this.hands; - this.lastTests = this.tests; - - this.hands = {}; - this.values = {}; - this.tests = {}; - - for (const inputSource of xrSession.inputSources) { - if (!inputSource.hand) { - continue; - } - - const hand = inputSource.handedness as Hand; - this.hands[hand] = {}; - - // Iterate over JointId - for (const jointId of JOINT_IDS) { - const jointSpace = inputSource.hand.get(JointEnum[jointId]); - if (jointSpace) { - const jointPose = xrFrame.getJointPose(jointSpace, referenceSpace); - if (jointPose) { - const { position, orientation } = jointPose.transform; - this.hands[hand][jointId] = { - position: new Vector3(position.x, position.y, position.z), - quaternion: new Quaternion( - orientation.x, - orientation.y, - orientation.z, - orientation.w, - ), - radius: jointPose.radius || 0.001, - }; - } - } - } - } - - for (const hand of HANDS) { - for (const { key, value } of [ - { key: `${hand}AllTips`, value: this.allTipsTouching(hand) }, - { - key: `${hand}IndexThumb`, - value: this.touching(hand, "i4", hand, "t3"), - }, - { - key: `${hand}MiddleThumb`, - value: this.touching(hand, "m4", hand, "t3"), - }, - { - key: `${hand}RingThumb`, - value: this.touching(hand, "r4", hand, "t3"), - }, - { - key: `${hand}PinkyThumb`, - value: this.touching(hand, "p4", hand, "t3"), - }, - { key: `${hand}TriTips`, value: this.triTipsTouching(hand) }, - ]) { - this.values[key] = value; - this.tests[key] = - value === 1.0 - ? true - : value === 0.0 - ? false - : (this.lastTests[key] ?? false); - } - } - } - - makeGhostMesh(): SplatMesh { - const center = new Vector3(); - const scales = new Vector3(0.01, 0.01, 0.01); - const quaternion = new Quaternion(0, 0, 0, 1); - const color = new Color(1, 1, 1); - const CYCLE = Math.PI * 3; - const WHITE = new Color(1, 1, 1); - let opacity = 1.0; - - const mesh = new SplatMesh({ - onFrame: () => { - let splatIndex = 0; - for (const handedness of HANDS) { - const xrHand = this.hands[handedness]; - for (const [index, segment] of JOINT_SEGMENTS.entries()) { - for (let i = 1; i < segment.length; ++i) { - const segmentSplats = JOINT_SEGMENT_STEPS[index][i - 1] * 2; - const lastSegment = i + 1 === segment.length; - const jointA = xrHand?.[segment[i - 1]]; - const jointB = xrHand?.[segment[i]]; - - for (let j = 0; j < segmentSplats; ++j) { - const t = (j + 0.5) / segmentSplats; - opacity = 0.0; - if (jointA && jointB) { - center.copy(jointA.position).lerp(jointB.position, t); - quaternion - .copy(jointA.quaternion) - .slerp(jointB.quaternion, t); - const radiusA = JOINT_RADIUS[segment[i - 1]]; - const radiusB = JOINT_RADIUS[segment[i]]; - let radius = (1 - t) * radiusA + t * radiusB; - if (lastSegment && t > 0.8) { - // Round out finger tips - radius *= Math.sqrt(1 - ((t - 0.8) / 0.2) ** 2); - } - scales.set(0.65 * radius, 0.5 * radius, 0.003); - color.set( - 0.55 + 0.45 * Math.sin(center.x * CYCLE), - 0.55 + 0.45 * Math.sin(center.y * CYCLE), - 0.55 + 0.45 * Math.sin(center.z * CYCLE), - ); - if (handedness === "right") { - color.set(1 - color.r, 1 - color.g, 1 - color.b); - } - opacity = 0.75; - } - mesh.packedSplats.setSplat( - splatIndex, - center, - scales, - quaternion, - opacity, - color, - ); - splatIndex += 1; - } - } - } - } - mesh.packedSplats.numSplats = splatIndex; - mesh.packedSplats.needsUpdate = true; - mesh.numSplats = splatIndex; - mesh.updateVersion(); - }, - }); - return mesh; - } - - distance( - handA: Hand, - jointA: JointId, - handB: Hand, - jointB: JointId, - last = false, - ): number { - const hA = last ? this.last[handA] : this.hands[handA]; - const hB = last ? this.last[handB] : this.hands[handB]; - const jA = hA?.[jointA]; - const jB = hB?.[jointB]; - if (!jA || !jB) { - return Number.POSITIVE_INFINITY; - } - return jA.position.distanceTo(jB.position); - } - - separation( - handA: Hand, - jointA: JointId, - handB: Hand, - jointB: JointId, - last = false, - ): number { - const d = this.distance(handA, jointA, handB, jointB, last); - if (d === Number.POSITIVE_INFINITY) { - return Number.POSITIVE_INFINITY; - } - return d - JOINT_RADIUS[jointA] - JOINT_RADIUS[jointB]; - } - - touching( - handA: Hand, - jointA: JointId, - handB: Hand, - jointB: JointId, - last = false, - ): number { - const d = this.separation(handA, jointA, handB, jointB, last); - if (d === Number.POSITIVE_INFINITY) { - return Number.POSITIVE_INFINITY; - } - return 1 - Math.max(0, Math.min(1, d / 0.01 - TOUCH_BIAS)); - } - - allTipsTouching(hand: Hand, last = false): number { - return Math.min( - this.touching(hand, "t3", hand, "i4", last), - this.touching(hand, "i4", hand, "m4", last), - this.touching(hand, "m4", hand, "r4", last), - this.touching(hand, "r4", hand, "p4", last), - // this.touching(hand, "p4", hand, "t3", last), - ); - } - - triTipsTouching(hand: Hand, last = false): number { - return Math.min( - this.touching(hand, "t3", hand, "i4", last), - this.touching(hand, "i4", hand, "m4", last), - this.touching(hand, "m4", hand, "t3", last), - ); - } -} - -export class HandMovement { - xrHands: XrHands; - control: Object3D; - moveInertia: number; - rotateInertia: number; - - lastGrip: { [key in Hand]?: Vector3 } = {}; - lastPivot: Vector3 = new Vector3(); - rotateVelocity = 0; - velocity: Vector3 = new Vector3(); - - constructor({ - xrHands, - control, - moveInertia, - rotateInertia, - }: { - xrHands: XrHands; - control: Object3D; - moveInertia?: number; - rotateInertia?: number; - }) { - this.xrHands = xrHands; - this.control = control; - this.moveInertia = moveInertia ?? DEFAULT_MOVE_INERTIA; - this.rotateInertia = rotateInertia ?? DEFAULT_ROTATE_INERTIA; - } - - update(deltaTime: number) { - const grip: { [key in Hand]?: Vector3 } = {}; - for (const handedness of HANDS) { - const hand = this.xrHands.hands[handedness]; - if (hand && this.xrHands.tests[`${handedness}MiddleThumb`]) { - grip[handedness] = new Vector3() - .add(hand.t3?.position ?? new Vector3()) - .add(hand.i4?.position ?? new Vector3()) - .add(hand.m4?.position ?? new Vector3()) - .add(hand.r4?.position ?? new Vector3()) - .add(hand.p4?.position ?? new Vector3()) - .multiplyScalar(1 / 5); - } - } - - if (grip.left && grip.right && this.lastGrip.left && this.lastGrip.right) { - const mid = grip.left.clone().add(grip.right).multiplyScalar(0.5); - const lastMid = this.lastGrip.left - .clone() - .add(this.lastGrip.right) - .multiplyScalar(0.5); - this.lastPivot = mid; - - const delta = mid.clone().applyMatrix4(this.control.matrix); - delta.sub(lastMid.clone().applyMatrix4(this.control.matrix)); - delta.multiplyScalar(1 / deltaTime); - this.velocity.lerp(delta, 1 - Math.exp(-20 * deltaTime)); - - const angle = Math.atan2(grip.left.z - mid.z, grip.left.x - mid.x); - const lastAngle = Math.atan2( - this.lastGrip.left.z - lastMid.z, - this.lastGrip.left.x - lastMid.x, - ); - // Find closest rotation over circle between angle and lastAngle - let closestAngle = angle - lastAngle; - if (closestAngle > Math.PI) { - closestAngle -= Math.PI * 2; - } else if (closestAngle < -Math.PI) { - closestAngle += Math.PI * 2; - } - const rotateVelocity = closestAngle / deltaTime; - - const blend = Math.exp(-20 * deltaTime); - this.rotateVelocity = - this.rotateVelocity * blend + rotateVelocity * (1 - blend); - } else { - this.rotateVelocity *= Math.exp(-deltaTime / this.rotateInertia); - - if (grip.left && this.lastGrip.left) { - const delta = grip.left.clone().applyMatrix4(this.control.matrix); - delta.sub(this.lastGrip.left.clone().applyMatrix4(this.control.matrix)); - delta.multiplyScalar(1 / deltaTime); - this.velocity.lerp(delta, 1 - Math.exp(-20 * deltaTime)); - } else if (grip.right && this.lastGrip.right) { - const delta = grip.right.clone().applyMatrix4(this.control.matrix); - delta.sub( - this.lastGrip.right.clone().applyMatrix4(this.control.matrix), - ); - delta.multiplyScalar(1 / deltaTime); - this.velocity.lerp(delta, 1 - Math.exp(-20 * deltaTime)); - } else { - this.velocity.multiplyScalar(Math.exp(-deltaTime / this.moveInertia)); - } - } - - const negPivot = this.lastPivot.clone().negate(); - const rotate = new Matrix4() - .makeTranslation(negPivot) - .premultiply(new Matrix4().makeRotationY(this.rotateVelocity * deltaTime)) - .premultiply(new Matrix4().makeTranslation(this.lastPivot)); - this.control.matrix.multiply(rotate); - this.control.matrix.decompose( - this.control.position, - this.control.quaternion, - this.control.scale, - ); - this.control.updateMatrixWorld(true); - - this.control.position.sub(this.velocity.clone().multiplyScalar(deltaTime)); - this.lastGrip = grip; - } -} diff --git a/src/index.ts b/src/index.ts index 73047f5..24e3aa5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,97 +1,17 @@ -export { SparkRenderer, type SparkRendererOptions } from "./SparkRenderer"; -export { SparkViewpoint, type SparkViewpointOptions } from "./SparkViewpoint"; - -export * as dyno from "./dyno"; - -export { RgbaArray } from "./RgbaArray"; - export { SplatLoader, - unpackSplats, SplatFileType, - getSplatFileType, - isPcSogs, } from "./SplatLoader"; -export { PlyReader } from "./ply"; -export { SpzReader, SpzWriter, transcodeSpz } from "./spz"; -export { PackedSplats, type PackedSplatsOptions } from "./PackedSplats"; -export { - SplatGenerator, - type GsplatGenerator, - SplatModifier, - type GsplatModifier, - SplatTransformer, -} from "./SplatGenerator"; -export { SplatAccumulator, type GeneratorMapping } from "./SplatAccumulator"; -export { Readback, type Rgba8Readback, type ReadbackBuffer } from "./Readback"; +export { Splat } from "./Splat"; +import "./procedural"; +export { BatchedSplat } from "./BatchedSplat"; -export { - SplatMesh, - type SplatMeshOptions, - type SplatMeshContext, -} from "./SplatMesh"; -export { SplatSkinning, type SplatSkinningOptions } from "./SplatSkinning"; -export { - SplatEdit, - type SplatEditOptions, - SplatEditSdf, - type SplatEditSdfOptions, - SplatEditSdfType, - SplatEditRgbaBlendMode, - SplatEdits, -} from "./SplatEdit"; +export * from "./raycast"; +export * as SplatUtils from "./SplatUtils"; -export { - constructGrid, - constructAxes, - constructSpherePoints, - imageSplats, - textSplats, -} from "./splatConstructors"; - -export * as generators from "./generators"; -export * as modifiers from "./modifiers"; - -export { VRButton } from "./vrButton"; -export { - type JointId, - JointEnum, - JOINT_IDS, - NUM_JOINTS, - JOINT_INDEX, - JOINT_RADIUS, - JOINT_SEGMENTS, - JOINT_SEGMENT_STEPS, - JOINT_TIPS, - FINGER_TIPS, - Hand, - HANDS, - type Joint, - type HandJoints, - type HandsJoints, - XrHands, - HandMovement, -} from "./hands"; - -export { SparkControls, FpsMovement, PointerControls } from "./controls"; - -export { - isMobile, - isAndroid, - isOculus, - flipPixels, - pixelsToPngUrl, - toHalf, - fromHalf, - floatToUint8, - floatToSint8, - Uint8ToFloat, - Sint8ToFloat, - setPackedSplat, - unpackSplat, -} from "./utils"; -export * as utils from "./utils"; +export type { SplatEncoder, ResizableSplatEncoder } from "./encoding/encoder"; +export { PackedSplats } from "./encoding/PackedSplats"; +export { ExtendedSplats } from "./encoding/ExtendedSplats"; -export { LN_SCALE_MIN, LN_SCALE_MAX } from "./defines"; -export * as defines from "./defines"; +export { transcodeSpz } from "./transcode"; diff --git a/src/ksplat.ts b/src/ksplat.ts deleted file mode 100644 index 17b1be3..0000000 --- a/src/ksplat.ts +++ /dev/null @@ -1,636 +0,0 @@ -import type { SplatEncoding } from "./PackedSplats"; -import { - computeMaxSplats, - encodeSh1Rgb, - encodeSh2Rgb, - encodeSh3Rgb, - fromHalf, - setPackedSplat, -} from "./utils"; - -type KsplatCompression = { - bytesPerCenter: number; - bytesPerScale: number; - bytesPerRotation: number; - bytesPerColor: number; - bytesPerSphericalHarmonicsComponent: number; - scaleOffsetBytes: number; - rotationOffsetBytes: number; - colorOffsetBytes: number; - sphericalHarmonicsOffsetBytes: number; - scaleRange: number; -}; - -const KSPLAT_COMPRESSION: Record = { - 0: { - bytesPerCenter: 12, - bytesPerScale: 12, - bytesPerRotation: 16, - bytesPerColor: 4, - bytesPerSphericalHarmonicsComponent: 4, - scaleOffsetBytes: 12, - rotationOffsetBytes: 24, - colorOffsetBytes: 40, - sphericalHarmonicsOffsetBytes: 44, - scaleRange: 1, - }, - 1: { - bytesPerCenter: 6, - bytesPerScale: 6, - bytesPerRotation: 8, - bytesPerColor: 4, - bytesPerSphericalHarmonicsComponent: 2, - scaleOffsetBytes: 6, - rotationOffsetBytes: 12, - colorOffsetBytes: 20, - sphericalHarmonicsOffsetBytes: 24, - scaleRange: 32767, - }, - 2: { - bytesPerCenter: 6, - bytesPerScale: 6, - bytesPerRotation: 8, - bytesPerColor: 4, - bytesPerSphericalHarmonicsComponent: 1, - scaleOffsetBytes: 6, - rotationOffsetBytes: 12, - colorOffsetBytes: 20, - sphericalHarmonicsOffsetBytes: 24, - scaleRange: 32767, - }, -}; - -const KSPLAT_SH_DEGREE_TO_COMPONENTS: Record = { - 0: 0, - 1: 9, - 2: 24, - 3: 45, -}; - -export function decodeKsplat( - fileBytes: Uint8Array, - initNumSplats: (numSplats: number) => void, - splatCallback: ( - index: number, - x: number, - y: number, - z: number, - scaleX: number, - scaleY: number, - scaleZ: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - opacity: number, - r: number, - g: number, - b: number, - ) => void, - shCallback?: ( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, - ) => void, -) { - const HEADER_BYTES = 4096; - const SECTION_BYTES = 1024; - - let headerOffset = 0; - const header = new DataView(fileBytes.buffer, headerOffset, HEADER_BYTES); - headerOffset += HEADER_BYTES; - - const versionMajor = header.getUint8(0); - const versionMinor = header.getUint8(1); - if (versionMajor !== 0 || versionMinor < 1) { - throw new Error( - `Unsupported .ksplat version: ${versionMajor}.${versionMinor}`, - ); - } - const maxSectionCount = header.getUint32(4, true); - // const sectionCount = header.getUint32(8, true); - // const maxSplatCount = header.getUint32(12, true); - const splatCount = header.getUint32(16, true); - const compressionLevel = header.getUint16(20, true); - if (compressionLevel < 0 || compressionLevel > 2) { - throw new Error(`Invalid .ksplat compression level: ${compressionLevel}`); - } - // const sceneCenterX = header.getFloat32(24, true); - // const sceneCenterY = header.getFloat32(28, true); - // const sceneCenterZ = header.getFloat32(32, true); - const minSphericalHarmonicsCoeff = header.getFloat32(36, true) || -1.5; - const maxSphericalHarmonicsCoeff = header.getFloat32(40, true) || 1.5; - - const numSplats = splatCount; - initNumSplats(numSplats); - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; - - let sectionBase = HEADER_BYTES + maxSectionCount * SECTION_BYTES; - - for (let section = 0; section < maxSectionCount; ++section) { - const section = new DataView(fileBytes.buffer, headerOffset, SECTION_BYTES); - headerOffset += SECTION_BYTES; - - const sectionSplatCount = section.getUint32(0, true); - const sectionMaxSplatCount = section.getUint32(4, true); - const bucketSize = section.getUint32(8, true); - const bucketCount = section.getUint32(12, true); - const bucketBlockSize = section.getFloat32(16, true); - const bucketStorageSizeBytes = section.getUint16(20, true); - const compressionScaleRange = - (section.getUint32(24, true) || - KSPLAT_COMPRESSION[compressionLevel]?.scaleRange) ?? - 1; - const fullBucketCount = section.getUint32(32, true); - const fullBucketSplats = fullBucketCount * bucketSize; - const partiallyFilledBucketCount = section.getUint32(36, true); - const bucketsMetaDataSizeBytes = partiallyFilledBucketCount * 4; - const bucketsStorageSizeBytes = - bucketStorageSizeBytes * bucketCount + bucketsMetaDataSizeBytes; - const sphericalHarmonicsDegree = section.getUint16(40, true); - const shComponents = - KSPLAT_SH_DEGREE_TO_COMPONENTS[sphericalHarmonicsDegree]; - - const { - bytesPerCenter, - bytesPerScale, - bytesPerRotation, - bytesPerColor, - bytesPerSphericalHarmonicsComponent, - scaleOffsetBytes, - rotationOffsetBytes, - colorOffsetBytes, - sphericalHarmonicsOffsetBytes, - } = KSPLAT_COMPRESSION[compressionLevel]; - const bytesPerSplat = - bytesPerCenter + - bytesPerScale + - bytesPerRotation + - bytesPerColor + - shComponents * bytesPerSphericalHarmonicsComponent; - const splatDataStorageSizeBytes = bytesPerSplat * sectionMaxSplatCount; - const storageSizeBytes = - splatDataStorageSizeBytes + bucketsStorageSizeBytes; - - const sh1Index = [0, 3, 6, 1, 4, 7, 2, 5, 8]; - const sh2Index = [ - 9, 14, 19, 10, 15, 20, 11, 16, 21, 12, 17, 22, 13, 18, 23, - ]; - const sh3Index = [ - 24, 31, 38, 25, 32, 39, 26, 33, 40, 27, 34, 41, 28, 35, 42, 29, 36, 43, - 30, 37, 44, - ]; - const sh1 = - sphericalHarmonicsDegree >= 1 ? new Float32Array(3 * 3) : undefined; - const sh2 = - sphericalHarmonicsDegree >= 2 ? new Float32Array(5 * 3) : undefined; - const sh3 = - sphericalHarmonicsDegree >= 3 ? new Float32Array(7 * 3) : undefined; - - const compressionScaleFactor = bucketBlockSize / 2 / compressionScaleRange; - const bucketsBase = sectionBase + bucketsMetaDataSizeBytes; - const dataBase = sectionBase + bucketsStorageSizeBytes; - const data = new DataView( - fileBytes.buffer, - dataBase, - splatDataStorageSizeBytes, - ); - const bucketArray = new Float32Array( - fileBytes.buffer, - bucketsBase, - bucketCount * 3, - ); - const partiallyFilledBucketLengths = new Uint32Array( - fileBytes.buffer, - sectionBase, - partiallyFilledBucketCount, - ); - - function getSh(splatOffset: number, component: number) { - if (compressionLevel === 0) { - return data.getFloat32( - splatOffset + sphericalHarmonicsOffsetBytes + component * 4, - true, - ); - } - if (compressionLevel === 1) { - return fromHalf( - data.getUint16( - splatOffset + sphericalHarmonicsOffsetBytes + component * 2, - true, - ), - ); - } - const t = - data.getUint8(splatOffset + sphericalHarmonicsOffsetBytes + component) / - 255; - return ( - minSphericalHarmonicsCoeff + - t * (maxSphericalHarmonicsCoeff - minSphericalHarmonicsCoeff) - ); - } - - let partialBucketIndex = fullBucketCount; - let partialBucketBase = fullBucketSplats; - - for (let i = 0; i < sectionSplatCount; ++i) { - const splatOffset = i * bytesPerSplat; - - let bucketIndex: number; - if (i < fullBucketSplats) { - bucketIndex = Math.floor(i / bucketSize); - } else { - const bucketLength = - partiallyFilledBucketLengths[partialBucketIndex - fullBucketCount]; - if (i >= partialBucketBase + bucketLength) { - partialBucketIndex += 1; - partialBucketBase += bucketLength; - } - bucketIndex = partialBucketIndex; - } - - const x = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 0, true) - : (data.getUint16(splatOffset + 0, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 0]; - const y = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 4, true) - : (data.getUint16(splatOffset + 2, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 1]; - const z = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 8, true) - : (data.getUint16(splatOffset + 4, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 2]; - - const scaleX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 0, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 0, true)); - const scaleY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 4, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 2, true)); - const scaleZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 8, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 4, true)); - - const quatW = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 0, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 0, true), - ); - const quatX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 4, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 2, true), - ); - const quatY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 8, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 4, true), - ); - const quatZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 12, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 6, true), - ); - - const r = data.getUint8(splatOffset + colorOffsetBytes + 0) / 255; - const g = data.getUint8(splatOffset + colorOffsetBytes + 1) / 255; - const b = data.getUint8(splatOffset + colorOffsetBytes + 2) / 255; - const opacity = data.getUint8(splatOffset + colorOffsetBytes + 3) / 255; - - splatCallback( - i, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ); - - if (sphericalHarmonicsDegree >= 1 && sh1) { - for (const [i, key] of sh1Index.entries()) { - sh1[i] = getSh(splatOffset, key); - } - if (sh2) { - for (const [i, key] of sh2Index.entries()) { - sh2[i] = getSh(splatOffset, key); - } - } - if (sh3) { - for (const [i, key] of sh3Index.entries()) { - sh3[i] = getSh(splatOffset, key); - } - } - shCallback?.(i, sh1, sh2, sh3); - } - } - sectionBase += storageSizeBytes; - } -} - -export function unpackKsplat( - fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): { - packedArray: Uint32Array; - numSplats: number; - extra: Record; -} { - const HEADER_BYTES = 4096; - const SECTION_BYTES = 1024; - - let headerOffset = 0; - const header = new DataView(fileBytes.buffer, headerOffset, HEADER_BYTES); - headerOffset += HEADER_BYTES; - - const versionMajor = header.getUint8(0); - const versionMinor = header.getUint8(1); - if (versionMajor !== 0 || versionMinor < 1) { - throw new Error( - `Unsupported .ksplat version: ${versionMajor}.${versionMinor}`, - ); - } - const maxSectionCount = header.getUint32(4, true); - // const sectionCount = header.getUint32(8, true); - // const maxSplatCount = header.getUint32(12, true); - const splatCount = header.getUint32(16, true); - const compressionLevel = header.getUint16(20, true); - if (compressionLevel < 0 || compressionLevel > 2) { - throw new Error(`Invalid .ksplat compression level: ${compressionLevel}`); - } - // const sceneCenterX = header.getFloat32(24, true); - // const sceneCenterY = header.getFloat32(28, true); - // const sceneCenterZ = header.getFloat32(32, true); - const minSphericalHarmonicsCoeff = header.getFloat32(36, true) || -1.5; - const maxSphericalHarmonicsCoeff = header.getFloat32(40, true) || 1.5; - - const numSplats = splatCount; - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; - - let sectionBase = HEADER_BYTES + maxSectionCount * SECTION_BYTES; - - for (let section = 0; section < maxSectionCount; ++section) { - const section = new DataView(fileBytes.buffer, headerOffset, SECTION_BYTES); - headerOffset += SECTION_BYTES; - - const sectionSplatCount = section.getUint32(0, true); - const sectionMaxSplatCount = section.getUint32(4, true); - const bucketSize = section.getUint32(8, true); - const bucketCount = section.getUint32(12, true); - const bucketBlockSize = section.getFloat32(16, true); - const bucketStorageSizeBytes = section.getUint16(20, true); - const compressionScaleRange = - (section.getUint32(24, true) || - KSPLAT_COMPRESSION[compressionLevel]?.scaleRange) ?? - 1; - const fullBucketCount = section.getUint32(32, true); - const fullBucketSplats = fullBucketCount * bucketSize; - const partiallyFilledBucketCount = section.getUint32(36, true); - const bucketsMetaDataSizeBytes = partiallyFilledBucketCount * 4; - const bucketsStorageSizeBytes = - bucketStorageSizeBytes * bucketCount + bucketsMetaDataSizeBytes; - const sphericalHarmonicsDegree = section.getUint16(40, true); - const shComponents = - KSPLAT_SH_DEGREE_TO_COMPONENTS[sphericalHarmonicsDegree]; - - const { - bytesPerCenter, - bytesPerScale, - bytesPerRotation, - bytesPerColor, - bytesPerSphericalHarmonicsComponent, - scaleOffsetBytes, - rotationOffsetBytes, - colorOffsetBytes, - sphericalHarmonicsOffsetBytes, - } = KSPLAT_COMPRESSION[compressionLevel]; - const bytesPerSplat = - bytesPerCenter + - bytesPerScale + - bytesPerRotation + - bytesPerColor + - shComponents * bytesPerSphericalHarmonicsComponent; - const splatDataStorageSizeBytes = bytesPerSplat * sectionMaxSplatCount; - const storageSizeBytes = - splatDataStorageSizeBytes + bucketsStorageSizeBytes; - - const sh1Index = [0, 3, 6, 1, 4, 7, 2, 5, 8]; - const sh2Index = [ - 9, 14, 19, 10, 15, 20, 11, 16, 21, 12, 17, 22, 13, 18, 23, - ]; - const sh3Index = [ - 24, 31, 38, 25, 32, 39, 26, 33, 40, 27, 34, 41, 28, 35, 42, 29, 36, 43, - 30, 37, 44, - ]; - const sh1 = - sphericalHarmonicsDegree >= 1 ? new Float32Array(3 * 3) : undefined; - const sh2 = - sphericalHarmonicsDegree >= 2 ? new Float32Array(5 * 3) : undefined; - const sh3 = - sphericalHarmonicsDegree >= 3 ? new Float32Array(7 * 3) : undefined; - - const compressionScaleFactor = bucketBlockSize / 2 / compressionScaleRange; - const bucketsBase = sectionBase + bucketsMetaDataSizeBytes; - const dataBase = sectionBase + bucketsStorageSizeBytes; - const data = new DataView( - fileBytes.buffer, - dataBase, - splatDataStorageSizeBytes, - ); - const bucketArray = new Float32Array( - fileBytes.buffer, - bucketsBase, - bucketCount * 3, - ); - const partiallyFilledBucketLengths = new Uint32Array( - fileBytes.buffer, - sectionBase, - partiallyFilledBucketCount, - ); - - function getSh(splatOffset: number, component: number) { - if (compressionLevel === 0) { - return data.getFloat32( - splatOffset + sphericalHarmonicsOffsetBytes + component * 4, - true, - ); - } - if (compressionLevel === 1) { - return fromHalf( - data.getUint16( - splatOffset + sphericalHarmonicsOffsetBytes + component * 2, - true, - ), - ); - } - const t = - data.getUint8(splatOffset + sphericalHarmonicsOffsetBytes + component) / - 255; - return ( - minSphericalHarmonicsCoeff + - t * (maxSphericalHarmonicsCoeff - minSphericalHarmonicsCoeff) - ); - } - - let partialBucketIndex = fullBucketCount; - let partialBucketBase = fullBucketSplats; - - for (let i = 0; i < sectionSplatCount; ++i) { - const splatOffset = i * bytesPerSplat; - - let bucketIndex: number; - if (i < fullBucketSplats) { - bucketIndex = Math.floor(i / bucketSize); - } else { - const bucketLength = - partiallyFilledBucketLengths[partialBucketIndex - fullBucketCount]; - if (i >= partialBucketBase + bucketLength) { - partialBucketIndex += 1; - partialBucketBase += bucketLength; - } - bucketIndex = partialBucketIndex; - } - - const x = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 0, true) - : (data.getUint16(splatOffset + 0, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 0]; - const y = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 4, true) - : (data.getUint16(splatOffset + 2, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 1]; - const z = - compressionLevel === 0 - ? data.getFloat32(splatOffset + 8, true) - : (data.getUint16(splatOffset + 4, true) - compressionScaleRange) * - compressionScaleFactor + - bucketArray[3 * bucketIndex + 2]; - - const scaleX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 0, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 0, true)); - const scaleY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 4, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 2, true)); - const scaleZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + scaleOffsetBytes + 8, true) - : fromHalf(data.getUint16(splatOffset + scaleOffsetBytes + 4, true)); - - const quatW = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 0, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 0, true), - ); - const quatX = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 4, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 2, true), - ); - const quatY = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 8, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 4, true), - ); - const quatZ = - compressionLevel === 0 - ? data.getFloat32(splatOffset + rotationOffsetBytes + 12, true) - : fromHalf( - data.getUint16(splatOffset + rotationOffsetBytes + 6, true), - ); - - const r = data.getUint8(splatOffset + colorOffsetBytes + 0) / 255; - const g = data.getUint8(splatOffset + colorOffsetBytes + 1) / 255; - const b = data.getUint8(splatOffset + colorOffsetBytes + 2) / 255; - const opacity = data.getUint8(splatOffset + colorOffsetBytes + 3) / 255; - - setPackedSplat( - packedArray, - i, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - splatEncoding, - ); - - if (sphericalHarmonicsDegree >= 1) { - if (sh1) { - if (!extra.sh1) { - extra.sh1 = new Uint32Array(numSplats * 2); - } - for (const [i, key] of sh1Index.entries()) { - sh1[i] = getSh(splatOffset, key); - } - encodeSh1Rgb(extra.sh1 as Uint32Array, i, sh1, splatEncoding); - } - if (sh2) { - if (!extra.sh2) { - extra.sh2 = new Uint32Array(numSplats * 4); - } - for (const [i, key] of sh2Index.entries()) { - sh2[i] = getSh(splatOffset, key); - } - encodeSh2Rgb(extra.sh2 as Uint32Array, i, sh2, splatEncoding); - } - if (sh3) { - if (!extra.sh3) { - extra.sh3 = new Uint32Array(numSplats * 4); - } - for (const [i, key] of sh3Index.entries()) { - sh3[i] = getSh(splatOffset, key); - } - encodeSh3Rgb(extra.sh3 as Uint32Array, i, sh3, splatEncoding); - } - } - } - sectionBase += storageSizeBytes; - } - return { packedArray, numSplats, extra }; -} diff --git a/src/modifiers.ts b/src/modifiers.ts deleted file mode 100644 index cb325e3..0000000 --- a/src/modifiers.ts +++ /dev/null @@ -1,2 +0,0 @@ -export * from "./modifiers/normalColor"; -export * from "./modifiers/depthColor"; diff --git a/src/modifiers/depthColor.ts b/src/modifiers/depthColor.ts deleted file mode 100644 index 9d788dd..0000000 --- a/src/modifiers/depthColor.ts +++ /dev/null @@ -1,60 +0,0 @@ -import type { SplatTransformer } from "../SplatGenerator"; -import type { SplatMesh } from "../SplatMesh"; -import { - type DynoVal, - Gsplat, - combineGsplat, - dynoBlock, - dynoConst, - neg, - normalizedDepth, - select, - split, - splitGsplat, - sub, -} from "../dyno"; - -export function makeDepthColorModifier( - splatToView: SplatTransformer, - minDepth: DynoVal<"float">, - maxDepth: DynoVal<"float">, - reverse: DynoVal<"bool">, -) { - return dynoBlock({ gsplat: Gsplat }, { gsplat: Gsplat }, ({ gsplat }) => { - if (!gsplat) { - throw new Error("No gsplat input"); - } - let { center } = splitGsplat(gsplat).outputs; - center = splatToView.apply(center); - const { z } = split(center).outputs; - let depth = normalizedDepth(neg(z), minDepth, maxDepth); - depth = select(reverse, sub(dynoConst("float", 1), depth), depth); - - gsplat = combineGsplat({ gsplat, r: depth, g: depth, b: depth }); - return { gsplat }; - }); -} - -export function setDepthColor( - splats: SplatMesh, - minDepth: number, - maxDepth: number, - reverse?: boolean, -) { - splats.enableWorldToView = true; - const dynoMinDepth = dynoConst("float", minDepth); - const dynoMaxDepth = dynoConst("float", maxDepth); - const dynoReverse = dynoConst("bool", reverse ?? false); - splats.worldModifier = makeDepthColorModifier( - splats.context.worldToView, - dynoMinDepth, - dynoMaxDepth, - dynoReverse, - ); - splats.updateGenerator(); - return { - minDepth: dynoMinDepth, - maxDepth: dynoMaxDepth, - reverse: dynoReverse, - }; -} diff --git a/src/modifiers/normalColor.ts b/src/modifiers/normalColor.ts deleted file mode 100644 index ac09541..0000000 --- a/src/modifiers/normalColor.ts +++ /dev/null @@ -1,46 +0,0 @@ -import type { SplatTransformer } from "../SplatGenerator"; -import type { SplatMesh } from "../SplatMesh"; -import { - Gsplat, - add, - combineGsplat, - dot, - dynoBlock, - dynoConst, - greaterThanEqual, - gsplatNormal, - mul, - neg, - select, - splitGsplat, -} from "../dyno"; - -export function makeNormalColorModifier(splatToView: SplatTransformer) { - return dynoBlock({ gsplat: Gsplat }, { gsplat: Gsplat }, ({ gsplat }) => { - if (!gsplat) { - throw new Error("No gsplat input"); - } - let normal = gsplatNormal(gsplat); - - const viewGsplat = splatToView.applyGsplat(gsplat); - const viewCenter = splitGsplat(viewGsplat).outputs.center; - const viewNormal = gsplatNormal(viewGsplat); - const splatDot = dot(viewCenter, viewNormal); - - const sameDir = greaterThanEqual(splatDot, dynoConst("float", 0)); - normal = select(sameDir, neg(normal), normal); - const rgb = add( - mul(normal, dynoConst("float", 0.5)), - dynoConst("float", 0.5), - ); - - gsplat = combineGsplat({ gsplat, rgb }); - return { gsplat }; - }); -} - -export function setWorldNormalColor(splats: SplatMesh) { - splats.enableWorldToView = true; - splats.worldModifier = makeNormalColorModifier(splats.context.worldToView); - splats.updateGenerator(); -} diff --git a/src/procedural.ts b/src/procedural.ts new file mode 100644 index 0000000..c4f8b4f --- /dev/null +++ b/src/procedural.ts @@ -0,0 +1,566 @@ +import * as THREE from "three"; +import { Splat } from "./Splat"; +import { + DefaultSplatEncoding, + type ResizableSplatEncoder, + type SplatEncoder, +} from "./encoding/encoder"; + +/** + * Method for constructing a Splat instance from a factory function. + * The number of splats is not fixed up-front. + * @param factory The factory to use to generate the splat properties + * @param options + * @returns The Splat instance + */ +function construct( + factory: (splatEncoder: ResizableSplatEncoder) => void, + options?: { + /** + * The splat encoder factory to use + */ + splatEncoder?: ResizableSplatEncoder | (() => ResizableSplatEncoder); + }, +): Splat { + const splatEncoderFactory = + options?.splatEncoder ?? DefaultSplatEncoding.createSplatEncoder; + const splatEncoder = + typeof splatEncoderFactory === "function" + ? splatEncoderFactory() + : splatEncoderFactory; + + factory(splatEncoder as ResizableSplatEncoder); + + return new Splat(splatEncoder.close()); +} + +/** + * Method for constructing a Splat instance with a fixed amount of splats using a factory function. + * @param numSplats The number of splats the resulting Splat instance should have + * @param factory The factory to use to generate the splat properties + * @param options + * @returns The Splat instance + */ +function constructFixed( + numSplats: number, + factory: (splatEncoder: SplatEncoder, numSplats: number) => void, + options?: { + /** + * The splat encoder factory to use. + */ + splatEncoder?: SplatEncoder | (() => SplatEncoder); + /** + * The number of spherical harmonics to allocate for each splat. + * @default 0 + */ + numSh?: number; + }, +): Splat { + const splatEncoderFactory = + options?.splatEncoder ?? DefaultSplatEncoding.createSplatEncoder; + const splatEncoder = + typeof splatEncoderFactory === "function" + ? splatEncoderFactory() + : splatEncoderFactory; + + const numSh = options?.numSh ?? 0; + splatEncoder.allocate(numSplats, numSh); + factory(splatEncoder as ResizableSplatEncoder, numSplats); + + return new Splat(splatEncoder.close()); +} + +export function constructGrid({ + // PackedSplats object to add splats to + splatEncoder, + // min and max box extents of the grid + extents, + // step size along each grid axis + stepSize = 1, + // spherical radius of each Gsplat + pointRadius = 0.01, + // relative size of the "shadow copy" of each Gsplat placed behind it + pointShadowScale = 2.0, + // Gsplat opacity + opacity = 1.0, + // Gsplat color (THREE.Color) or function to set color for position: + // ((THREE.Color, THREE.Vector3) => void) (default: RGB-modulated grid) + color, +}: { + splatEncoder: SplatEncoder; + extents: THREE.Box3; + stepSize?: number; + pointRadius?: number; + pointShadowScale?: number; + opacity?: number; + color?: THREE.Color | ((color: THREE.Color, point: THREE.Vector3) => void); +}) { + const EPSILON = 1.0e-6; + const center = new THREE.Vector3(); + if (color == null) { + color = (color, point) => + color.set( + 0.55 + 0.45 * Math.cos(point.x * 1), + 0.55 + 0.45 * Math.cos(point.y * 1), + 0.55 + 0.45 * Math.cos(point.z * 1), + ); + } + const pointColor = new THREE.Color(); + let i = 0; + for (let z = extents.min.z; z < extents.max.z + EPSILON; z += stepSize) { + for (let y = extents.min.y; y < extents.max.y + EPSILON; y += stepSize) { + for (let x = extents.min.x; x < extents.max.x + EPSILON; x += stepSize) { + center.set(x, y, z); + for (let layer = 0; layer < 2; ++layer) { + const scale = pointRadius * (layer ? 1 : pointShadowScale); + if (!layer) { + pointColor.setScalar(0.0); + } else if (typeof color === "function") { + color(pointColor, center); + } else { + pointColor.copy(color); + } + splatEncoder.setSplat( + i++, + x, + y, + z, + scale, + scale, + scale, + 0, + 0, + 0, + 1, + opacity, + pointColor.r, + pointColor.g, + pointColor.b, + ); + } + } + } + } +} + +export function constructAxes({ + // PackedSplats object to add splats to + splatEncoder, + // scale (Gsplat scale along axis) + scale = 0.25, + // radius of the axes (Gsplat scale orthogonal to axis) + axisRadius = 0.0075, + // relative size of the "shadow copy" of each Gsplat placed behind it + axisShadowScale = 2.0, + // origins of the axes (default single axis at origin) + origins = [new THREE.Vector3()], +}: { + splatEncoder: ResizableSplatEncoder; + scale?: number; + axisRadius?: number; + axisShadowScale?: number; + origins?: THREE.Vector3[]; +}) { + const center = new THREE.Vector3(); + const scales = new THREE.Vector3(); + const quaternion = new THREE.Quaternion(0, 0, 0, 1); + const color = new THREE.Color(); + const opacity = 1.0; + for (const origin of origins) { + for (let axis = 0; axis < 3; ++axis) { + center.set( + origin.x + (axis === 0 ? scale : 0), + origin.y + (axis === 1 ? scale : 0), + origin.z + (axis === 2 ? scale : 0), + ); + for (let layer = 0; layer < 2; ++layer) { + scales.set( + (axis === 0 ? scale : axisRadius) * (layer ? 1 : axisShadowScale), + (axis === 1 ? scale : axisRadius) * (layer ? 1 : axisShadowScale), + (axis === 2 ? scale : axisRadius) * (layer ? 1 : axisShadowScale), + ); + color.setRGB( + layer === 0 ? 0.0 : axis === 0 ? 1.0 : 0.0, + layer === 0 ? 0.0 : axis === 1 ? 1.0 : 0.0, + layer === 0 ? 0.0 : axis === 2 ? 1.0 : 0.0, + ); + splatEncoder.pushSplat( + center.x, + center.y, + center.z, + scales.x, + scales.y, + scales.z, + quaternion.x, + quaternion.y, + quaternion.z, + quaternion.w, + opacity, + color.r, + color.g, + color.b, + ); + } + } + } +} + +export function constructSpherePoints({ + // PackedSplats object to add splats to + splatEncoder, + // center of the sphere (default: origin) + origin = new THREE.Vector3(), + // radius of the sphere + radius = 1.0, + // maximum depth of recursion for subdividing the sphere + // Warning: Gsplat count grows exponentially with depth + maxDepth = 3, + // filter function to apply to each point, for example to select + // points in a certain direction or other function ((THREE.Vector3) => boolean) + // (default: null) + filter = null, + // radius of each oriented Gsplat + pointRadius = 0.02, + // flatness of each oriented Gsplat + pointThickness = 0.001, + // color of each Gsplat (THREE.Color) or function to set color for point: + // ((THREE.Color, THREE.Vector3) => void) (default: white) + color = new THREE.Color(1, 1, 1), +}: { + splatEncoder: ResizableSplatEncoder; + origin?: THREE.Vector3; + radius?: number; + maxDepth?: number; + filter?: ((point: THREE.Vector3) => boolean) | null; + pointRadius?: number; + pointThickness?: number; + color?: THREE.Color | ((color: THREE.Color, point: THREE.Vector3) => void); +}) { + const pointsHash: { [key: string]: THREE.Vector3 } = {}; + + function addPoint(p: THREE.Vector3) { + if (filter && !filter(p)) { + return; + } + const key = `${p.x},${p.y},${p.z}`; + if (!pointsHash[key]) { + pointsHash[key] = p; + } + } + + function recurse( + depth: number, + p0: THREE.Vector3, + p1: THREE.Vector3, + p2: THREE.Vector3, + ) { + addPoint(p0); + addPoint(p1); + addPoint(p2); + if (depth >= maxDepth) { + return; + } + const p01 = new THREE.Vector3().addVectors(p0, p1).normalize(); + const p12 = new THREE.Vector3().addVectors(p1, p2).normalize(); + const p20 = new THREE.Vector3().addVectors(p2, p0).normalize(); + recurse(depth + 1, p0, p01, p20); + recurse(depth + 1, p01, p1, p12); + recurse(depth + 1, p20, p12, p2); + recurse(depth + 1, p01, p12, p20); + } + + for (const x of [-1, 1]) { + for (const y of [-1, 1]) { + for (const z of [-1, 1]) { + const p0 = new THREE.Vector3(x, 0, 0); + const p1 = new THREE.Vector3(0, y, 0); + const p2 = new THREE.Vector3(0, 0, z); + recurse(0, p0, p1, p2); + } + } + } + + const points = Object.values(pointsHash); + const scales = new THREE.Vector3(pointRadius, pointRadius, pointThickness); + const quaternion = new THREE.Quaternion(); + const pointColor = typeof color === "function" ? new THREE.Color() : color; + for (const point of points) { + quaternion.setFromUnitVectors(new THREE.Vector3(0, 0, -1), point); + if (typeof color === "function") { + color(pointColor, point); + } + point.multiplyScalar(radius); + point.add(origin); + splatEncoder.pushSplat( + point.x, + point.y, + point.z, + scales.x, + scales.y, + scales.z, + quaternion.x, + quaternion.y, + quaternion.z, + quaternion.w, + 1.0, + pointColor.r, + pointColor.g, + pointColor.b, + ); + } +} + +function fromText( + text: string, + options?: { + /** + * browser font to render text with + * @default Arial + */ + font?: string; + /** + * font size in pixels/Gsplats + * @default 32 + */ + fontSize?: number; + /** + * Individual Gsplat color (default: white) + * @default white + */ + color?: THREE.Color; + /** + * Gsplat radius + * @default 0.8 covers 1-unit spacing well + */ + dotRadius?: number; + /** + * text alignment, one of "left", "center", "right", "start", "end" + * @default start + */ + textAlign?: "left" | "center" | "right" | "start" | "end"; + /** + * line spacing multiplier, lines delimited by "\n" + * @default 1.0 + */ + lineHeight?: number; + /** + * Coordinate scale in object-space + * @default 1.0 + */ + objectScale?: number; + /** + * The splat encoder factory to use + */ + splatEncoder?: ResizableSplatEncoder | (() => ResizableSplatEncoder); + }, +): Splat { + const font = options?.font ?? "Arial"; + const fontSize = options?.fontSize ?? 32; + const color = options?.color ?? new THREE.Color(1, 1, 1); + const dotRadius = options?.dotRadius ?? 0.8; + const textAlign = options?.textAlign ?? "start"; + const lineHeight = options?.lineHeight ?? 1; + const objectScale = options?.objectScale ?? 1; + const lines = text.split("\n"); + const splatEncoderFactory = + options?.splatEncoder ?? DefaultSplatEncoding.createSplatEncoder; + const splatEncoder = + typeof splatEncoderFactory === "function" + ? splatEncoderFactory() + : splatEncoderFactory; + + const canvas = document.createElement("canvas"); + const ctx = canvas.getContext("2d"); + if (!ctx) { + throw new Error("Failed to create canvas context"); + } + + ctx.font = `${fontSize}px ${font}`; + ctx.textAlign = textAlign; + const metrics = ctx.measureText(""); + const fontHeight = + metrics.fontBoundingBoxAscent + metrics.fontBoundingBoxDescent; + + let minLeft = Number.POSITIVE_INFINITY; + let maxRight = Number.NEGATIVE_INFINITY; + let minTop = Number.POSITIVE_INFINITY; + let maxBottom = Number.NEGATIVE_INFINITY; + for (let line = 0; line < lines.length; ++line) { + const metrics = ctx.measureText(lines[line]); + const y = fontHeight * lineHeight * line; + minLeft = Math.min(minLeft, -metrics.actualBoundingBoxLeft); + maxRight = Math.max(maxRight, metrics.actualBoundingBoxRight); + minTop = Math.min(minTop, y - metrics.actualBoundingBoxAscent); + maxBottom = Math.max(maxBottom, y + metrics.actualBoundingBoxDescent); + } + const originLeft = Math.floor(minLeft); + const originTop = Math.floor(minTop); + const width = Math.ceil(maxRight) - originLeft; + const height = Math.ceil(maxBottom) - originTop; + canvas.width = width; + canvas.height = height; + + ctx.font = `${fontSize}px ${font}`; + ctx.textAlign = textAlign; + ctx.textBaseline = "alphabetic"; + ctx.fillStyle = "#FFFFFF"; + for (let i = 0; i < lines.length; ++i) { + const y = fontHeight * lineHeight * i - originTop; + ctx.fillText(lines[i], -originLeft, y); + } + + const imageData = ctx.getImageData(0, 0, width, height); + const rgba = new Uint8Array(imageData.data.buffer); + const center = new THREE.Vector3(); + const scales = new THREE.Vector3().setScalar(dotRadius * objectScale); + const quaternion = new THREE.Quaternion(0, 0, 0, 1); + + let offset = 0; + for (let y = 0; y < height; ++y) { + for (let x = 0; x < width; ++x) { + const a = rgba[offset + 3]; + if (a > 0) { + const opacity = a / 255; + center.set(x - 0.5 * (width - 1), 0.5 * (height - 1) - y, 0); + center.multiplyScalar(objectScale); + splatEncoder.pushSplat( + center.x, + center.y, + center.z, + scales.x, + scales.y, + scales.z, + quaternion.x, + quaternion.y, + quaternion.z, + quaternion.w, + opacity, + color.r, + color.g, + color.b, + ); + } + offset += 4; + } + } + + return new Splat(splatEncoder.close()); +} + +type FromImageOptions = { + /** + * Radius of each Gsplat, default covers 1-unit spacing well + * @default 0.8 + */ + dotRadius?: number; + /** + * Subsampling factor for the image. Higher values reduce resolution, + * for example 2 will halve the width and height by averaging + * @default 1 + */ + subXY?: number; + /** + * The splat encoder factory to use + */ + splatEncoder?: ResizableSplatEncoder | (() => ResizableSplatEncoder); +}; + +function fromImage( + img: HTMLImageElement, + options?: FromImageOptions, +): Splat { + const dotRadius = options?.dotRadius ?? 0.8; + const subXY = Math.max(1, Math.floor(options?.subXY ?? 1)); + const splatEncoderFactory = + options?.splatEncoder ?? DefaultSplatEncoding.createSplatEncoder; + const splatEncoder = + typeof splatEncoderFactory === "function" + ? splatEncoderFactory() + : splatEncoderFactory; + + const { width, height } = img; + const canvas = document.createElement("canvas"); + canvas.width = width; + canvas.height = height; + const ctx = canvas.getContext("2d"); + if (!ctx) { + throw new Error("Failed to create canvas context"); + } + ctx.imageSmoothingEnabled = true; + ctx.imageSmoothingQuality = "high"; + const destWidth = Math.round(width / subXY); + const destHeight = Math.round(height / subXY); + ctx.drawImage(img, 0, 0, destWidth, destHeight); + + const imageData = ctx.getImageData(0, 0, destWidth, destHeight); + const rgba = new Uint8Array(imageData.data.buffer); + + const center = new THREE.Vector3(); + const scales = new THREE.Vector3().setScalar(dotRadius); + const quaternion = new THREE.Quaternion(0, 0, 0, 1); + const rgb = new THREE.Color(); + + let index = 0; + for (let y = 0; y < destHeight; ++y) { + for (let x = 0; x < destWidth; ++x) { + const offset = index * 4; + const a = rgba[offset + 3]; + if (a > 0) { + const opacity = a / 255; + rgb.set( + rgba[offset + 0] / 255, + rgba[offset + 1] / 255, + rgba[offset + 2] / 255, + ); + center.set(x - 0.5 * (destWidth - 1), 0.5 * (destHeight - 1) - y, 0); + scales.setScalar(dotRadius); + quaternion.set(0, 0, 0, 1); + splatEncoder.pushSplat( + center.x, + center.y, + center.z, + scales.x, + scales.y, + scales.z, + quaternion.x, + quaternion.y, + quaternion.z, + quaternion.w, + opacity, + rgb.r, + rgb.g, + rgb.b, + ); + } + index += 1; + } + } + + return new Splat(splatEncoder.close()); +} + +async function fromImageUrl( + url: string, + options?: FromImageOptions, +): Promise { + const img = new Image(); + img.crossOrigin = "anonymous"; + const loadPromise = new Promise((resolve, reject) => { + img.onerror = reject; + img.onload = resolve; + }); + img.src = url; + + await loadPromise; + + return fromImage(img, options); +} + +// @ts-ignore +const SplatClass = Splat as Record; + +SplatClass.fromText = fromText; +SplatClass.fromImage = fromImage; +SplatClass.fromImageUrl = fromImageUrl; +SplatClass.construct = construct; +SplatClass.constructFixed = constructFixed; diff --git a/src/raycast.ts b/src/raycast.ts new file mode 100644 index 0000000..fa27fcd --- /dev/null +++ b/src/raycast.ts @@ -0,0 +1,84 @@ +import init_wasm, { raycast_splats } from "spark-internal-rs"; +import * as THREE from "three"; +import type { Splat } from "./Splat"; +import { LN_SCALE_MAX, LN_SCALE_MIN } from "./defines"; +import { PackedSplats } from "./encoding/PackedSplats"; + +export function simpleRaycastMethod( + splat: Splat, + raycaster: THREE.Raycaster, + intersects: THREE.Intersection[], +) { + // At this point the ray intersects the bounding sphere. + // Simply return the center of the Splat. + const point = splat.getWorldPosition(new THREE.Vector3()); + intersects.push({ + distance: point.distanceTo(raycaster.ray.origin), + point, + object: splat, + }); +} + +function preciseRaycastMethod( + splat: Splat, + raycaster: THREE.Raycaster, + intersects: THREE.Intersection[], +) { + const splatData = splat.splatData; + if (!(splatData instanceof PackedSplats)) { + throw new Error("Precise raycasting requires PackedSplats encoding"); + } + + const packedSplats = splatData as PackedSplats; + + const { near, far, ray } = raycaster; + const worldToMesh = splat.matrixWorld.clone().invert(); + const worldToMeshRot = new THREE.Matrix3().setFromMatrix4(worldToMesh); + const origin = ray.origin.clone().applyMatrix4(worldToMesh); + const direction = ray.direction.clone().applyMatrix3(worldToMeshRot); + + const RAYCAST_ELLIPSOID = true; + const distances = raycast_splats( + origin.x, + origin.y, + origin.z, + direction.x, + direction.y, + direction.z, + near, + far, + packedSplats.numSplats, + packedSplats.packedArray, + RAYCAST_ELLIPSOID, + packedSplats.splatEncoding?.lnScaleMin ?? LN_SCALE_MIN, + packedSplats.splatEncoding?.lnScaleMax ?? LN_SCALE_MAX, + ); + + for (const distance of distances) { + const point = ray.direction + .clone() + .multiplyScalar(distance) + .add(ray.origin); + intersects.push({ + distance, + point, + object: splat, + }); + } +} + +let wasmInitialized = false; +let wasmInitializing: ReturnType | null = null; + +export async function createPreciseRaycastMethod() { + // Lazy-init wasm + if (!wasmInitialized) { + if (!wasmInitializing) { + wasmInitializing = init_wasm(); + } + await wasmInitializing; + wasmInitialized = true; + } + + return preciseRaycastMethod; +} diff --git a/src/shaders.ts b/src/shaders.ts index 8e03d7c..8cea823 100644 --- a/src/shaders.ts +++ b/src/shaders.ts @@ -1,19 +1,28 @@ import * as THREE from "three"; +import extendedSplat from "./shaders/extendedSplat.glsl"; +import identityVertex from "./shaders/identityVertex.glsl"; +import packedSplat from "./shaders/packedSplat.glsl"; import splatDefines from "./shaders/splatDefines.glsl"; +import splatDistanceFragment from "./shaders/splatDistanceFragment.glsl"; import splatFragment from "./shaders/splatFragment.glsl"; import splatVertex from "./shaders/splatVertex.glsl"; -let shaders: Record | null = null; +let shaderChunksInitialized = false; +const shaders = { + splatVertex, + splatFragment, + identityVertex, + splatDistanceFragment, +} as const; -export function getShaders(): Record { - if (!shaders) { - // @ts-ignore - THREE.ShaderChunk.splatDefines = splatDefines; - shaders = { - splatVertex, - splatFragment, - }; +export function getShaders() { + if (!shaderChunksInitialized) { + const shaderChunks = THREE.ShaderChunk as Record; + shaderChunks.splatDefines = splatDefines; + shaderChunks.packedSplat = packedSplat; + shaderChunks.extendedSplat = extendedSplat; + shaderChunksInitialized = true; } return shaders; } diff --git a/src/shaders/computeUvec4.glsl b/src/shaders/computeUvec4.glsl deleted file mode 100644 index 65d54ef..0000000 --- a/src/shaders/computeUvec4.glsl +++ /dev/null @@ -1,36 +0,0 @@ -precision highp float; -precision highp int; -precision highp sampler2D; -precision highp usampler2D; -precision highp isampler2D; -precision highp sampler2DArray; -precision highp usampler2DArray; -precision highp isampler2DArray; -precision highp sampler3D; -precision highp usampler3D; -precision highp isampler3D; - -#include - -uniform uint targetLayer; -uniform int targetBase; -uniform int targetCount; - -out uvec4 target; - -{{ GLOBALS }} - -void produceSplat(int index) { - {{ STATEMENTS }} -} - -void main() { - int targetIndex = int(targetLayer << SPLAT_TEX_LAYER_BITS) + int(uint(gl_FragCoord.y) << SPLAT_TEX_WIDTH_BITS) + int(gl_FragCoord.x); - int index = targetIndex - targetBase; - - if ((index >= 0) && (index < targetCount)) { - produceSplat(index); - } else { - target = uvec4(0u, 0u, 0u, 0u); - } -} diff --git a/src/shaders/computeVec4.glsl b/src/shaders/computeVec4.glsl deleted file mode 100644 index 471af7c..0000000 --- a/src/shaders/computeVec4.glsl +++ /dev/null @@ -1,36 +0,0 @@ -precision highp float; -precision highp int; -precision highp sampler2D; -precision highp usampler2D; -precision highp isampler2D; -precision highp sampler2DArray; -precision highp usampler2DArray; -precision highp isampler2DArray; -precision highp sampler3D; -precision highp usampler3D; -precision highp isampler3D; - -#include - -uniform uint targetLayer; -uniform int targetBase; -uniform int targetCount; - -out vec4 target; - -{{ GLOBALS }} - -void computeReadback(int index) { - {{ STATEMENTS }} -} - -void main() { - int targetIndex = int(targetLayer << SPLAT_TEX_LAYER_BITS) + int(uint(gl_FragCoord.y) << SPLAT_TEX_WIDTH_BITS) + int(gl_FragCoord.x); - int index = targetIndex - targetBase; - - if ((index >= 0) && (index < targetCount)) { - computeReadback(index); - } else { - target = vec4(0.0, 0.0, 0.0, 0.0); - } -} diff --git a/src/shaders/extendedSplat.glsl b/src/shaders/extendedSplat.glsl new file mode 100644 index 0000000..c4e3195 --- /dev/null +++ b/src/shaders/extendedSplat.glsl @@ -0,0 +1,131 @@ +#ifdef USE_EXTENDED_SPLAT + +uniform usampler2DArray splatTexture1; +uniform usampler2DArray splatTexture2; +uniform usampler2DArray shTexture; + +const float LN_SCALE_MIN = -12.0; +const float LN_SCALE_MAX = 9.0; + +const uint SPLAT_TEX_WIDTH_BITS = 11u; +const uint SPLAT_TEX_HEIGHT_BITS = 11u; +const uint SPLAT_TEX_DEPTH_BITS = 11u; +const uint SPLAT_TEX_LAYER_BITS = SPLAT_TEX_WIDTH_BITS + SPLAT_TEX_HEIGHT_BITS; + +const uint SPLAT_TEX_WIDTH = 1u << SPLAT_TEX_WIDTH_BITS; +const uint SPLAT_TEX_HEIGHT = 1u << SPLAT_TEX_HEIGHT_BITS; +const uint SPLAT_TEX_DEPTH = 1u << SPLAT_TEX_DEPTH_BITS; + +const uint SPLAT_TEX_WIDTH_MASK = SPLAT_TEX_WIDTH - 1u; +const uint SPLAT_TEX_HEIGHT_MASK = SPLAT_TEX_HEIGHT - 1u; +const uint SPLAT_TEX_DEPTH_MASK = SPLAT_TEX_DEPTH - 1u; + +const uint F16_INF = 0x7c00u; + +// Decode a 24‐bit encoded uint into a quaternion (vec4) using the folded octahedral inverse. +vec4 decodeQuatOctXy88R8(uint encoded) { + // Extract the fields. + uint quantU = encoded & uint(0xFFu); // bits 0–7 + uint quantV = (encoded >> 8u) & uint(0xFFu); // bits 8–15 + uint angleInt = encoded >> 16u; // bits 16–23 + + // Recover u and v in [0,1], then map to [-1,1]. + float u_f = float(quantU) / 255.0; + float v_f = float(quantV) / 255.0; + vec2 f = vec2(u_f * 2.0 - 1.0, v_f * 2.0 - 1.0); + + vec3 axis = vec3(f.xy, 1.0 - abs(f.x) - abs(f.y)); + float t = max(-axis.z, 0.0); + axis.x += (axis.x >= 0.0) ? -t : t; + axis.y += (axis.y >= 0.0) ? -t : t; + axis = normalize(axis); + + // Decode the angle θ ∈ [0,π]. + float theta = (float(angleInt) / 255.0) * 3.14159265359; + float halfTheta = theta * 0.5; + float s = sin(halfTheta); + float w = cos(halfTheta); + + return vec4(axis * s, w); +} + + +void decodeExtendedSplat(uvec4 packed1, uvec4 packed2, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba) { + center = uintBitsToFloat(packed1.xyz); + + scales = vec3((uvec3(packed1.w) >> uvec3(0u, 10u, 20u)) & 1023u); + float lnScaleScale = (LN_SCALE_MAX - LN_SCALE_MIN) / 1023.0; + scales = exp(LN_SCALE_MIN + scales * lnScaleScale); + + quaternion = decodeQuatOctXy88R8(packed2.x); + + rgba = vec4((uvec4(packed2.y) >> uvec4(0u, 8u, 16u, 24u)) & 255u) / 255.0; +} + +ivec3 splatTexCoord(uint index) { + uint x = index & SPLAT_TEX_WIDTH_MASK; + uint y = (index >> SPLAT_TEX_WIDTH_BITS) & SPLAT_TEX_HEIGHT_MASK; + uint z = index >> SPLAT_TEX_LAYER_BITS; + return ivec3(x, y, z); +} + +// Unpack a Gsplat from a uvec4 +void decodeExtendedSplatDefault(uint splatIndex, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba) { + ivec3 texCoord = splatTexCoord(splatIndex); + uvec4 packed1 = texelFetch(splatTexture1, texCoord, 0); + uvec4 packed2 = texelFetch(splatTexture2, texCoord, 0); + decodeExtendedSplat(packed1, packed2, center, scales, quaternion, rgba); +} + +#ifdef NUM_SH +vec4 unpackSint8(uint packed) { + return vec4((ivec4(packed) << ivec4(24u, 16u, 8u, 0u)) >> 24u) / 127.0; +} + +void decodePackedSphericalHarmonics(uint splatIndex, out vec3[3] sh1, out vec3[5] sh2, out vec3[7] sh3) { + ivec3 texCoord = splatTexCoord(splatIndex); + texCoord.x *= NUM_PACKED_SH; + + uvec4 packedA = texelFetch(shTexture, texCoord, 0); + vec4 a1 = unpackSint8(packedA.x); + vec4 a2 = unpackSint8(packedA.y); + vec4 a3 = unpackSint8(packedA.z); + vec4 a4 = unpackSint8(packedA.w); + + sh1[0] = a1.xyz; + sh1[1] = vec3(a1.w, a2.xy); + sh1[2] = vec3(a2.zw, a3.x); + +#if NUM_PACKED_SH > 1 + uvec4 packedB = texelFetch(shTexture, texCoord + ivec3(1, 0, 0), 0); + vec4 b1 = unpackSint8(packedB.x); + vec4 b2 = unpackSint8(packedB.y); + vec4 b3 = unpackSint8(packedB.z); + vec4 b4 = unpackSint8(packedB.w); + + sh2[0] = vec3(a3.yzw); + sh2[1] = vec3(a4.xyz); + sh2[2] = vec3(a4.w, b1.xy); + sh2[3] = vec3(b1.zw, b2.x); + sh2[4] = vec3(b2.yzw); + +#if NUM_PACKED_SH > 2 + uvec4 packedC = texelFetch(shTexture, texCoord + ivec3(2, 0, 0), 0); + vec4 c1 = unpackSint8(packedC.x); + vec4 c2 = unpackSint8(packedC.y); + vec4 c3 = unpackSint8(packedC.z); + vec4 c4 = unpackSint8(packedC.w); + + sh3[0] = vec3(b3.xyz); + sh3[1] = vec3(b3.w, b4.xy); + sh3[2] = vec3(b4.zw, c1.x); + sh3[3] = vec3(c1.yzw); + sh3[4] = vec3(c2.xyz); + sh3[5] = vec3(c2.w, c3.xy); + sh3[6] = vec3(c3.zw, c4.x); +#endif +#endif +} +#endif + +#endif \ No newline at end of file diff --git a/src/shaders/identityVertex.glsl b/src/shaders/identityVertex.glsl new file mode 100644 index 0000000..3e221f2 --- /dev/null +++ b/src/shaders/identityVertex.glsl @@ -0,0 +1,7 @@ +precision highp float; + +in vec3 position; + +void main() { + gl_Position = vec4(position.xy, 0.0, 1.0); +} \ No newline at end of file diff --git a/src/shaders/packedSplat.glsl b/src/shaders/packedSplat.glsl new file mode 100644 index 0000000..d1ed1f5 --- /dev/null +++ b/src/shaders/packedSplat.glsl @@ -0,0 +1,220 @@ +#ifdef USE_PACKED_SPLAT + +uniform usampler2DArray packedSplats; +uniform usampler2DArray packedShTexture; +uniform vec4 rgbMinMaxLnScaleMinMax; + +const float LN_SCALE_MIN = -12.0; +const float LN_SCALE_MAX = 9.0; + +const uint SPLAT_TEX_WIDTH_BITS = 11u; +const uint SPLAT_TEX_HEIGHT_BITS = 11u; +const uint SPLAT_TEX_DEPTH_BITS = 11u; +const uint SPLAT_TEX_LAYER_BITS = SPLAT_TEX_WIDTH_BITS + SPLAT_TEX_HEIGHT_BITS; + +const uint SPLAT_TEX_WIDTH = 1u << SPLAT_TEX_WIDTH_BITS; +const uint SPLAT_TEX_HEIGHT = 1u << SPLAT_TEX_HEIGHT_BITS; +const uint SPLAT_TEX_DEPTH = 1u << SPLAT_TEX_DEPTH_BITS; + +const uint SPLAT_TEX_WIDTH_MASK = SPLAT_TEX_WIDTH - 1u; +const uint SPLAT_TEX_HEIGHT_MASK = SPLAT_TEX_HEIGHT - 1u; +const uint SPLAT_TEX_DEPTH_MASK = SPLAT_TEX_DEPTH - 1u; + +const uint F16_INF = 0x7c00u; + +// Encode a quaternion (vec4) into a 24‐bit uint with folded octahedral mapping. +uint encodeQuatOctXy88R8(vec4 q) { + // Ensure minimal representation: flip if q.w is negative. + if (q.w < 0.0) { + q = -q; + } + // Compute rotation angle: θ = 2 * acos(q.w) ∈ [0,π] + float theta = 2.0 * acos(q.w); + float halfTheta = theta * 0.5; + float s = sin(halfTheta); + // Recover the rotation axis; use a default if nearly zero rotation. + vec3 axis = (abs(s) < 1e-6) ? vec3(1.0, 0.0, 0.0) : q.xyz / s; + + // --- Folded Octahedral Mapping (inline) --- + // Compute p = (axis.x, axis.y) / (|axis.x|+|axis.y|+|axis.z|) + float sum = abs(axis.x) + abs(axis.y) + abs(axis.z); + vec2 p = vec2(axis.x, axis.y) / sum; + // If axis.z < 0, fold the mapping. + if (axis.z < 0.0) { + float oldPx = p.x; + p.x = (1.0 - abs(p.y)) * (p.x >= 0.0 ? 1.0 : -1.0); + p.y = (1.0 - abs(oldPx)) * (p.y >= 0.0 ? 1.0 : -1.0); + } + // Remap from [-1,1] to [0,1] + float u_f = p.x * 0.5 + 0.5; + float v_f = p.y * 0.5 + 0.5; + // Quantize to 8 bits (0 to 255) + uint quantU = uint(clamp(round(u_f * 255.0), 0.0, 255.0)); + uint quantV = uint(clamp(round(v_f * 255.0), 0.0, 255.0)); + + // --- Angle Quantization --- + // Quantize θ ∈ [0,π] to 8 bits (0 to 255) + uint angleInt = uint(clamp(round((theta / 3.14159265359) * 255.0), 0.0, 255.0)); + + // Pack bits: bits [0–7]: quantU, [8–15]: quantV, [16–23]: angleInt. + return (angleInt << 16u) | (quantV << 8u) | quantU; +} + +// Decode a 24‐bit encoded uint into a quaternion (vec4) using the folded octahedral inverse. +vec4 decodeQuatOctXy88R8(uint encoded) { + // Extract the fields. + uint quantU = encoded & uint(0xFFu); // bits 0–7 + uint quantV = (encoded >> 8u) & uint(0xFFu); // bits 8–15 + uint angleInt = encoded >> 16u; // bits 16–23 + + // Recover u and v in [0,1], then map to [-1,1]. + float u_f = float(quantU) / 255.0; + float v_f = float(quantV) / 255.0; + vec2 f = vec2(u_f * 2.0 - 1.0, v_f * 2.0 - 1.0); + + vec3 axis = vec3(f.xy, 1.0 - abs(f.x) - abs(f.y)); + float t = max(-axis.z, 0.0); + axis.x += (axis.x >= 0.0) ? -t : t; + axis.y += (axis.y >= 0.0) ? -t : t; + axis = normalize(axis); + + // Decode the angle θ ∈ [0,π]. + float theta = (float(angleInt) / 255.0) * 3.14159265359; + float halfTheta = theta * 0.5; + float s = sin(halfTheta); + float w = cos(halfTheta); + + return vec4(axis * s, w); +} + +// Pack a Gsplat into a uvec4 +uvec4 encodePackedSplat( + vec3 center, vec3 scales, vec4 quaternion, vec4 rgba, vec4 rgbMinMaxLnScaleMinMax +) { + float rgbMin = rgbMinMaxLnScaleMinMax.x; + float rgbMax = rgbMinMaxLnScaleMinMax.y; + vec3 encRgb = (rgba.rgb - vec3(rgbMin)) / (rgbMax - rgbMin); + uvec4 uRgba = uvec4(round(clamp(vec4(encRgb, rgba.a) * 255.0, 0.0, 255.0))); + + uint uQuat = encodeQuatOctXy88R8(quaternion); + uvec3 uQuat3 = uvec3(uQuat & 0xffu, (uQuat >> 8u) & 0xffu, (uQuat >> 16u) & 0xffu); + + // Encode scales in three uint8s, where 0=>0.0 and 1..=255 stores log scale + float lnScaleMin = rgbMinMaxLnScaleMinMax.z; + float lnScaleMax = rgbMinMaxLnScaleMinMax.w; + float lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); + uvec3 uScales = uvec3( + (scales.x == 0.0) ? 0u : uint(round(clamp((log(scales.x) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u, + (scales.y == 0.0) ? 0u : uint(round(clamp((log(scales.y) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u, + (scales.z == 0.0) ? 0u : uint(round(clamp((log(scales.z) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u + ); + + // Pack it all into 4 x uint32 + uint word0 = uRgba.r | (uRgba.g << 8u) | (uRgba.b << 16u) | (uRgba.a << 24u); + uint word1 = packHalf2x16(center.xy); + uint word2 = packHalf2x16(vec2(center.z, 0.0)) | (uQuat3.x << 16u) | (uQuat3.y << 24u); + uint word3 = uScales.x | (uScales.y << 8u) | (uScales.z << 16u) | (uQuat3.z << 24u); + return uvec4(word0, word1, word2, word3); +} + +// Pack a Gsplat into a uvec4 +uvec4 encodePackedSplatDefault(vec3 center, vec3 scales, vec4 quaternion, vec4 rgba) { + return encodePackedSplat(center, scales, quaternion, rgba, vec4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX)); +} + +void decodePackedSplat(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba, vec4 rgbMinMaxLnScaleMinMax) { + uint word0 = packed.x, word1 = packed.y, word2 = packed.z, word3 = packed.w; + + uvec4 uRgba = uvec4(word0 & 0xffu, (word0 >> 8u) & 0xffu, (word0 >> 16u) & 0xffu, (word0 >> 24u) & 0xffu); + float rgbMin = rgbMinMaxLnScaleMinMax.x; + float rgbMax = rgbMinMaxLnScaleMinMax.y; + rgba = (vec4(uRgba) / 255.0); + rgba.rgb = rgba.rgb * (rgbMax - rgbMin) + rgbMin; + + center = vec4( + unpackHalf2x16(word1), + unpackHalf2x16(word2 & 0xffffu) + ).xyz; + + uvec3 uScales = uvec3(word3 & 0xffu, (word3 >> 8u) & 0xffu, (word3 >> 16u) & 0xffu); + float lnScaleMin = rgbMinMaxLnScaleMinMax.z; + float lnScaleMax = rgbMinMaxLnScaleMinMax.w; + float lnScaleScale = (lnScaleMax - lnScaleMin) / 254.0; + scales = vec3( + (uScales.x == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.x - 1u) * lnScaleScale), + (uScales.y == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.y - 1u) * lnScaleScale), + (uScales.z == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.z - 1u) * lnScaleScale) + ); + + + uint uQuat = ((word2 >> 16u) & 0xFFFFu) | ((word3 >> 8u) & 0xFF0000u); + quaternion = decodeQuatOctXy88R8(uQuat); +} + +ivec3 splatTexCoord(uint index) { + uint x = index & SPLAT_TEX_WIDTH_MASK; + uint y = (index >> SPLAT_TEX_WIDTH_BITS) & SPLAT_TEX_HEIGHT_MASK; + uint z = index >> SPLAT_TEX_LAYER_BITS; + return ivec3(x, y, z); +} + +// Unpack a Gsplat from a uvec4 +void decodePackedSplatDefault(uint splatIndex, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba) { + ivec3 texCoord = splatTexCoord(splatIndex); + uvec4 packed = texelFetch(packedSplats, texCoord, 0); + decodePackedSplat(packed, center, scales, quaternion, rgba, vec4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX)); +} + +// Unpack spherical harmonics +#ifdef NUM_SH +vec4 unpackSint8(uint packed) { + return vec4((ivec4(packed) << ivec4(24u, 16u, 8u, 0u)) >> 24u) / 127.0; +} + +void decodePackedSphericalHarmonics(uint splatIndex, out vec3[3] sh1, out vec3[5] sh2, out vec3[7] sh3) { + ivec3 texCoord = splatTexCoord(splatIndex); + texCoord.x *= NUM_PACKED_SH; + + uvec4 packedA = texelFetch(packedShTexture, texCoord, 0); + vec4 a1 = unpackSint8(packedA.x); + vec4 a2 = unpackSint8(packedA.y); + vec4 a3 = unpackSint8(packedA.z); + vec4 a4 = unpackSint8(packedA.w); + + sh1[0] = a1.xyz; + sh1[1] = vec3(a1.w, a2.xy); + sh1[2] = vec3(a2.zw, a3.x); + +#if NUM_PACKED_SH > 1 + uvec4 packedB = texelFetch(packedShTexture, texCoord + ivec3(1, 0, 0), 0); + vec4 b1 = unpackSint8(packedB.x); + vec4 b2 = unpackSint8(packedB.y); + vec4 b3 = unpackSint8(packedB.z); + vec4 b4 = unpackSint8(packedB.w); + + sh2[0] = vec3(a3.yzw); + sh2[1] = vec3(a4.xyz); + sh2[2] = vec3(a4.w, b1.xy); + sh2[3] = vec3(b1.zw, b2.x); + sh2[4] = vec3(b2.yzw); + +#if NUM_PACKED_SH > 2 + uvec4 packedC = texelFetch(packedShTexture, texCoord + ivec3(2, 0, 0), 0); + vec4 c1 = unpackSint8(packedC.x); + vec4 c2 = unpackSint8(packedC.y); + vec4 c3 = unpackSint8(packedC.z); + vec4 c4 = unpackSint8(packedC.w); + + sh3[0] = vec3(b3.xyz); + sh3[1] = vec3(b3.w, b4.xy); + sh3[2] = vec3(b4.zw, c1.x); + sh3[3] = vec3(c1.yzw); + sh3[4] = vec3(c2.xyz); + sh3[5] = vec3(c2.w, c3.xy); + sh3[6] = vec3(c3.zw, c4.x); +#endif +#endif +} +#endif + +#endif \ No newline at end of file diff --git a/src/shaders/splatDefines.glsl b/src/shaders/splatDefines.glsl index 79e10c1..48b35f9 100644 --- a/src/shaders/splatDefines.glsl +++ b/src/shaders/splatDefines.glsl @@ -1,20 +1,3 @@ -const float LN_SCALE_MIN = -12.0; -const float LN_SCALE_MAX = 9.0; - -const uint SPLAT_TEX_WIDTH_BITS = 11u; -const uint SPLAT_TEX_HEIGHT_BITS = 11u; -const uint SPLAT_TEX_DEPTH_BITS = 11u; -const uint SPLAT_TEX_LAYER_BITS = SPLAT_TEX_WIDTH_BITS + SPLAT_TEX_HEIGHT_BITS; - -const uint SPLAT_TEX_WIDTH = 1u << SPLAT_TEX_WIDTH_BITS; -const uint SPLAT_TEX_HEIGHT = 1u << SPLAT_TEX_HEIGHT_BITS; -const uint SPLAT_TEX_DEPTH = 1u << SPLAT_TEX_DEPTH_BITS; - -const uint SPLAT_TEX_WIDTH_MASK = SPLAT_TEX_WIDTH - 1u; -const uint SPLAT_TEX_HEIGHT_MASK = SPLAT_TEX_HEIGHT - 1u; -const uint SPLAT_TEX_DEPTH_MASK = SPLAT_TEX_DEPTH - 1u; - -const uint F16_INF = 0x7c00u; const float PI = 3.1415926535897932384626433832795; const float INFINITY = 1.0 / 0.0; @@ -42,234 +25,6 @@ vec3 linearToSrgb(vec3 rgb) { return pow(rgb, vec3(1.0 / 2.2)); } -// uint encodeQuatXyz888(vec4 q) { -// // Encode quaternion in three int8s, flipping sign to remove ambiguity -// vec3 quat3 = (q.w < 0.0) ? -q.xyz : q.xyz; -// ivec3 iQuat3 = ivec3(round(clamp(quat3 * 127.0, -127.0, 127.0))); -// uvec3 uQuat3 = uvec3(iQuat3) & 0xffu; -// return (uQuat3.x << 16u) | (uQuat3.y << 8u) | uQuat3.z; -// } - -// vec4 decodeQuatXyz888(uint encoded) { -// ivec3 iQuat3 = ivec3( -// int(encoded << 24u) >> 24, -// int(encoded << 16u) >> 24, -// int(encoded << 8u) >> 24 -// ); -// vec4 quat = vec4(vec3(iQuat3) / 127.0, 0.0); -// quat.w = sqrt(max(0.0, 1.0 - dot(quat.xyz, quat.xyz))); -// return quat; -// } - -// Encode a quaternion (vec4) into a 24‐bit uint with folded octahedral mapping. -uint encodeQuatOctXy88R8(vec4 q) { - // Ensure minimal representation: flip if q.w is negative. - if (q.w < 0.0) { - q = -q; - } - // Compute rotation angle: θ = 2 * acos(q.w) ∈ [0,π] - float theta = 2.0 * acos(q.w); - float halfTheta = theta * 0.5; - float s = sin(halfTheta); - // Recover the rotation axis; use a default if nearly zero rotation. - vec3 axis = (abs(s) < 1e-6) ? vec3(1.0, 0.0, 0.0) : q.xyz / s; - - // --- Folded Octahedral Mapping (inline) --- - // Compute p = (axis.x, axis.y) / (|axis.x|+|axis.y|+|axis.z|) - float sum = abs(axis.x) + abs(axis.y) + abs(axis.z); - vec2 p = vec2(axis.x, axis.y) / sum; - // If axis.z < 0, fold the mapping. - if (axis.z < 0.0) { - float oldPx = p.x; - p.x = (1.0 - abs(p.y)) * (p.x >= 0.0 ? 1.0 : -1.0); - p.y = (1.0 - abs(oldPx)) * (p.y >= 0.0 ? 1.0 : -1.0); - } - // Remap from [-1,1] to [0,1] - float u_f = p.x * 0.5 + 0.5; - float v_f = p.y * 0.5 + 0.5; - // Quantize to 8 bits (0 to 255) - uint quantU = uint(clamp(round(u_f * 255.0), 0.0, 255.0)); - uint quantV = uint(clamp(round(v_f * 255.0), 0.0, 255.0)); - - // --- Angle Quantization --- - // Quantize θ ∈ [0,π] to 8 bits (0 to 255) - uint angleInt = uint(clamp(round((theta / 3.14159265359) * 255.0), 0.0, 255.0)); - - // Pack bits: bits [0–7]: quantU, [8–15]: quantV, [16–23]: angleInt. - return (angleInt << 16u) | (quantV << 8u) | quantU; -} - -// Decode a 24‐bit encoded uint into a quaternion (vec4) using the folded octahedral inverse. -vec4 decodeQuatOctXy88R8(uint encoded) { - // Extract the fields. - uint quantU = encoded & uint(0xFFu); // bits 0–7 - uint quantV = (encoded >> 8u) & uint(0xFFu); // bits 8–15 - uint angleInt = encoded >> 16u; // bits 16–23 - - // Recover u and v in [0,1], then map to [-1,1]. - float u_f = float(quantU) / 255.0; - float v_f = float(quantV) / 255.0; - vec2 f = vec2(u_f * 2.0 - 1.0, v_f * 2.0 - 1.0); - - vec3 axis = vec3(f.xy, 1.0 - abs(f.x) - abs(f.y)); - float t = max(-axis.z, 0.0); - axis.x += (axis.x >= 0.0) ? -t : t; - axis.y += (axis.y >= 0.0) ? -t : t; - axis = normalize(axis); - - // Decode the angle θ ∈ [0,π]. - float theta = (float(angleInt) / 255.0) * 3.14159265359; - float halfTheta = theta * 0.5; - float s = sin(halfTheta); - float w = cos(halfTheta); - - return vec4(axis * s, w); -} - -// // Encode a quaternion (vec4) into a 24‐bit uint by converting it to Euler angles. -// // We assume the quaternion is normalized. -// // Euler angles (roll, pitch, yaw) are assumed in radians in the range [-PI, PI]. -// // Each angle is normalized: value = (angle + PI) / (2*PI) and quantized to 8 bits. -// uint encodeQuatEulerXyz888(vec4 q) { -// // Compute roll (x), pitch (y) and yaw (z) using Tait–Bryan angles. -// float sinr_cosp = 2.0 * (q.w * q.x + q.y * q.z); -// float cosr_cosp = 1.0 - 2.0 * (q.x * q.x + q.y * q.y); -// float roll = atan(sinr_cosp, cosr_cosp); - -// float sinp = 2.0 * (q.w * q.y - q.z * q.x); -// float pitch = abs(sinp) >= 1.0 ? (sign(sinp) * 1.57079632679) : asin(sinp); - -// float siny_cosp = 2.0 * (q.w * q.z + q.x * q.y); -// float cosy_cosp = 1.0 - 2.0 * (q.y * q.y + q.z * q.z); -// float yaw = atan(siny_cosp, cosy_cosp); - -// // Normalize each angle from [-PI, PI] to [0, 1] -// float normRoll = (roll + 3.14159265359) / (2.0 * 3.14159265359); -// float normPitch = (pitch + 3.14159265359) / (2.0 * 3.14159265359); -// float normYaw = (yaw + 3.14159265359) / (2.0 * 3.14159265359); - -// // Quantize each normalized angle to 8 bits (0..255) -// uint rollQ = uint(round(normRoll * 255.0)); -// uint pitchQ = uint(round(normPitch * 255.0)); -// uint yawQ = uint(round(normYaw * 255.0)); - -// // Pack into a 24-bit uint: -// // Bits 0..7 : rollQ, -// // Bits 8..15 : pitchQ, -// // Bits 16..23 : yawQ. -// return (yawQ << 16u) | (pitchQ << 8u) | rollQ; -// } - -// // Decode a 24‐bit uint into a quaternion (vec4) by unpacking 8‐bit quantized Euler angles. -// // The Euler angles are assumed to be stored in the order: roll, pitch, yaw (each in [0,255]) corresponding to [-PI, PI]. -// // Convert the Euler angles to a quaternion using the Tait–Bryan (roll, pitch, yaw) formula. -// vec4 decodeQuatEulerXyz888(uint encoded) { -// // Unpack each 8-bit field. -// uint rollQ = encoded & 0xFFu; -// uint pitchQ = (encoded >> 8u) & 0xFFu; -// uint yawQ = (encoded >> 16u) & 0xFFu; - -// // Convert back to the [0,1] range. -// float normRoll = float(rollQ) / 255.0; -// float normPitch = float(pitchQ) / 255.0; -// float normYaw = float(yawQ) / 255.0; - -// // Map from [0,1] back to [-PI, PI]. -// float roll = normRoll * (2.0 * 3.14159265359) - 3.14159265359; -// float pitch = normPitch * (2.0 * 3.14159265359) - 3.14159265359; -// float yaw = normYaw * (2.0 * 3.14159265359) - 3.14159265359; - -// // Convert Euler angles (roll, pitch, yaw) to quaternion. -// float cr = cos(roll * 0.5); -// float sr = sin(roll * 0.5); -// float cp = cos(pitch * 0.5); -// float sp = sin(pitch * 0.5); -// float cy = cos(yaw * 0.5); -// float sy = sin(yaw * 0.5); - -// // Tait-Bryan (roll, pitch, yaw) to quaternion conversion. -// vec4 q; -// q.w = cr * cp * cy + sr * sp * sy; -// q.x = sr * cp * cy - cr * sp * sy; -// q.y = cr * sp * cy + sr * cp * sy; -// q.z = cr * cp * sy - sr * sp * cy; - -// return q; -// } - -// Pack a Gsplat into a uvec4 -uvec4 packSplatEncoding( - vec3 center, vec3 scales, vec4 quaternion, vec4 rgba, vec4 rgbMinMaxLnScaleMinMax -) { - float rgbMin = rgbMinMaxLnScaleMinMax.x; - float rgbMax = rgbMinMaxLnScaleMinMax.y; - vec3 encRgb = (rgba.rgb - vec3(rgbMin)) / (rgbMax - rgbMin); - uvec4 uRgba = uvec4(round(clamp(vec4(encRgb, rgba.a) * 255.0, 0.0, 255.0))); - - uint uQuat = encodeQuatOctXy88R8(quaternion); - // uint uQuat = encodeQuatXyz888(quaternion); - // uint uQuat = encodeQuatEulerXyz888(quaternion); - uvec3 uQuat3 = uvec3(uQuat & 0xffu, (uQuat >> 8u) & 0xffu, (uQuat >> 16u) & 0xffu); - - // Encode scales in three uint8s, where 0=>0.0 and 1..=255 stores log scale - float lnScaleMin = rgbMinMaxLnScaleMinMax.z; - float lnScaleMax = rgbMinMaxLnScaleMinMax.w; - float lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); - uvec3 uScales = uvec3( - (scales.x == 0.0) ? 0u : uint(round(clamp((log(scales.x) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u, - (scales.y == 0.0) ? 0u : uint(round(clamp((log(scales.y) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u, - (scales.z == 0.0) ? 0u : uint(round(clamp((log(scales.z) - lnScaleMin) * lnScaleScale, 0.0, 254.0))) + 1u - ); - - // Pack it all into 4 x uint32 - uint word0 = uRgba.r | (uRgba.g << 8u) | (uRgba.b << 16u) | (uRgba.a << 24u); - uint word1 = packHalf2x16(center.xy); - uint word2 = packHalf2x16(vec2(center.z, 0.0)) | (uQuat3.x << 16u) | (uQuat3.y << 24u); - uint word3 = uScales.x | (uScales.y << 8u) | (uScales.z << 16u) | (uQuat3.z << 24u); - return uvec4(word0, word1, word2, word3); -} - -// Pack a Gsplat into a uvec4 -uvec4 packSplat(vec3 center, vec3 scales, vec4 quaternion, vec4 rgba) { - return packSplatEncoding(center, scales, quaternion, rgba, vec4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX)); -} - -void unpackSplatEncoding(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba, vec4 rgbMinMaxLnScaleMinMax) { - uint word0 = packed.x, word1 = packed.y, word2 = packed.z, word3 = packed.w; - - uvec4 uRgba = uvec4(word0 & 0xffu, (word0 >> 8u) & 0xffu, (word0 >> 16u) & 0xffu, (word0 >> 24u) & 0xffu); - float rgbMin = rgbMinMaxLnScaleMinMax.x; - float rgbMax = rgbMinMaxLnScaleMinMax.y; - rgba = (vec4(uRgba) / 255.0); - rgba.rgb = rgba.rgb * (rgbMax - rgbMin) + rgbMin; - - center = vec4( - unpackHalf2x16(word1), - unpackHalf2x16(word2 & 0xffffu) - ).xyz; - - uvec3 uScales = uvec3(word3 & 0xffu, (word3 >> 8u) & 0xffu, (word3 >> 16u) & 0xffu); - float lnScaleMin = rgbMinMaxLnScaleMinMax.z; - float lnScaleMax = rgbMinMaxLnScaleMinMax.w; - float lnScaleScale = (lnScaleMax - lnScaleMin) / 254.0; - scales = vec3( - (uScales.x == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.x - 1u) * lnScaleScale), - (uScales.y == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.y - 1u) * lnScaleScale), - (uScales.z == 0u) ? 0.0 : exp(lnScaleMin + float(uScales.z - 1u) * lnScaleScale) - ); - - - uint uQuat = ((word2 >> 16u) & 0xFFFFu) | ((word3 >> 8u) & 0xFF0000u); - quaternion = decodeQuatOctXy88R8(uQuat); - // quaternion = decodeQuatXyz888(uQuat); - // quaternion = decodeQuatEulerXyz888(uQuat); -} - -// Unpack a Gsplat from a uvec4 -void unpackSplat(uvec4 packed, out vec3 center, out vec3 scales, out vec4 quaternion, out vec4 rgba) { - unpackSplatEncoding(packed, center, scales, quaternion, rgba, vec4(0.0, 1.0, LN_SCALE_MIN, LN_SCALE_MAX)); -} - // Rotate vector v by quaternion q vec3 quatVec(vec4 q, vec3 v) { // Rotate vector v by quaternion q @@ -287,6 +42,20 @@ vec4 quatQuat(vec4 q1, vec4 q2) { ); } +mat3 quaternionToMatrix(vec4 q) { + return mat3( + (1.0 - 2.0 * (q.y * q.y + q.z * q.z)), + (2.0 * (q.x * q.y + q.w * q.z)), + (2.0 * (q.x * q.z - q.w * q.y)), + (2.0 * (q.x * q.y - q.w * q.z)), + (1.0 - 2.0 * (q.x * q.x + q.z * q.z)), + (2.0 * (q.y * q.z + q.w * q.x)), + (2.0 * (q.x * q.z + q.w * q.y)), + (2.0 * (q.y * q.z - q.w * q.x)), + (1.0 - 2.0 * (q.x * q.x + q.y * q.y)) + ); +} + mat3 scaleQuaternionToMatrix(vec3 s, vec4 q) { // Compute the matrix of scaling by s then rotating by q return mat3( @@ -302,38 +71,42 @@ mat3 scaleQuaternionToMatrix(vec3 s, vec4 q) { ); } -// Spherical lerp between two quaternions -vec4 slerp(vec4 q1, vec4 q2, float t) { - // Compute the cosine of the angle between the two vectors - float cosHalfTheta = dot(q1, q2); - - // If q1=q2 or q1=-q2 then theta = 0 and we can return q1 - if (abs(cosHalfTheta) >= 0.999) { - return q1; - } - - // If q1 and q2 are more than 180 degrees apart, - // we need to negate one to get the shortest path - if (cosHalfTheta < 0.0) { - q2 = -q2; - cosHalfTheta = -cosHalfTheta; - } - - // Calculate temporary values - float halfTheta = acos(cosHalfTheta); - float sinHalfTheta = sqrt(1.0 - cosHalfTheta * cosHalfTheta); - - // Calculate the interpolation factors - float ratioA = sin((1.0 - t) * halfTheta) / sinHalfTheta; - float ratioB = sin(t * halfTheta) / sinHalfTheta; - - // Calculate the interpolated quaternion - return q1 * ratioA + q2 * ratioB; -} - -ivec3 splatTexCoord(int index) { - uint x = uint(index) & SPLAT_TEX_WIDTH_MASK; - uint y = (uint(index) >> SPLAT_TEX_WIDTH_BITS) & SPLAT_TEX_HEIGHT_MASK; - uint z = uint(index) >> SPLAT_TEX_LAYER_BITS; - return ivec3(x, y, z); -} +#ifdef NUM_SH +vec3 evaluateSH(vec3 viewDir, vec3 sh1[3], vec3 sh2[5], vec3 sh3[7]) { + vec3 sh1Rgb = sh1[0] * (-0.4886025 * viewDir.y) + + sh1[1] * (0.4886025 * viewDir.z) + + sh1[2] * (-0.4886025 * viewDir.x); + +#if NUM_SH == 1 + return sh1Rgb; +#else + + float xx = viewDir.x * viewDir.x; + float yy = viewDir.y * viewDir.y; + float zz = viewDir.z * viewDir.z; + float xy = viewDir.x * viewDir.y; + float yz = viewDir.y * viewDir.z; + float zx = viewDir.z * viewDir.x; + + vec3 sh2Rgb = sh2[0] * (1.0925484 * xy) + + sh2[1] * (-1.0925484 * yz) + + sh2[2] * (0.3153915 * (2.0 * zz - xx - yy)) + + sh2[3] * (-1.0925484 * zx) + + sh2[4] * (0.5462742 * (xx - yy)); + +#if NUM_SH == 2 + return sh1Rgb + sh2Rgb; +#else + vec3 sh3Rgb = sh3[0] * (-0.5900436 * viewDir.y * (3.0 * xx - yy)) + + sh3[1] * (2.8906114 * xy * viewDir.z) + + + sh3[2] * (-0.4570458 * viewDir.y * (4.0 * zz - xx - yy)) + + sh3[3] * (0.3731763 * viewDir.z * (2.0 * zz - 3.0 * xx - 3.0 * yy)) + + sh3[4] * (-0.4570458 * viewDir.x * (4.0 * zz - xx - yy)) + + sh3[5] * (1.4453057 * viewDir.z * (xx - yy)) + + sh3[6] * (-0.5900436 * viewDir.x * (xx - 3.0 * yy)); + + return sh1Rgb + sh2Rgb + sh3Rgb; +#endif +#endif +} +#endif \ No newline at end of file diff --git a/src/shaders/splatDistanceFragment.glsl b/src/shaders/splatDistanceFragment.glsl new file mode 100644 index 0000000..8fcf362 --- /dev/null +++ b/src/shaders/splatDistanceFragment.glsl @@ -0,0 +1,72 @@ +precision highp float; +precision highp int; +precision highp sampler2D; +precision highp usampler2D; +precision highp isampler2D; +precision highp sampler2DArray; +precision highp usampler2DArray; +precision highp isampler2DArray; +precision highp sampler3D; +precision highp usampler3D; +precision highp isampler3D; + +#include +#include +#include + +#define decodeSplat SPLAT_DECODE_FN + +uniform uint targetLayer; +uniform int targetBase; +uniform int targetCount; + +uniform bool sortRadial; +uniform float sortDepthBias; +uniform bool sort360; + +uniform mat4 splatModelViewMatrix; + +out vec4 target; + +float computeSort(vec3 splatCenter, bool sortRadial, float sortDepthBias, bool sort360) { + // FIXME: Check active flag? + float biasedDepth = dot(splatCenter, vec3(0, 0, -1)) + sortDepthBias; + if (!sort360 && (biasedDepth <= 0.0)) { + return INFINITY; + } + return sortRadial ? length(splatCenter) : biasedDepth; +} + +void main() { + int targetIndex = int(targetLayer << SPLAT_TEX_LAYER_BITS) + int(uint(gl_FragCoord.y) << SPLAT_TEX_WIDTH_BITS) + int(gl_FragCoord.x); + int index = (targetIndex - targetBase); + + if ((index >= 0) && (index < targetCount)) { + vec3 center, scales; + vec4 quaternion, rgba; + + // Compute distance +#ifdef SORT32 + decodeSplat(uint(index), center, scales, quaternion, rgba); + center = (splatModelViewMatrix * vec4(center, 1.0)).xyz; + float metric = computeSort(center, sortRadial, sortDepthBias, sort360); + + uint packed = floatBitsToUint(metric); +#else + decodeSplat(uint(index * 2), center, scales, quaternion, rgba); + center = (splatModelViewMatrix * vec4(center, 1.0)).xyz; + float metric1 = computeSort(center, sortRadial, sortDepthBias, sort360); + + decodeSplat(uint(index * 2 + 1), center, scales, quaternion, rgba); + center = (splatModelViewMatrix * vec4(center, 1.0)).xyz; + float metric2 = computeSort(center, sortRadial, sortDepthBias, sort360); + + uint packed = packHalf2x16(vec2(metric1, metric2)); +#endif + + uvec4 uTarget = uvec4(packed & 0xffu, (packed >> 8u) & 0xffu, (packed >> 16u) & 0xffu, (packed >> 24u) & 0xffu); + target = vec4(uTarget) / 255.0; + } else { + target = vec4(0); + } +} diff --git a/src/shaders/splatFragment.glsl b/src/shaders/splatFragment.glsl index 0112ef3..3bd4298 100644 --- a/src/shaders/splatFragment.glsl +++ b/src/shaders/splatFragment.glsl @@ -3,65 +3,26 @@ precision highp float; precision highp int; #include -#include -uniform float near; -uniform float far; uniform bool encodeLinear; -uniform float time; -uniform bool debugFlag; uniform float maxStdDev; uniform float minAlpha; uniform bool stochastic; -uniform bool disableFalloff; uniform float falloff; - -uniform bool splatTexEnable; -uniform sampler3D splatTexture; -uniform mat2 splatTexMul; -uniform vec2 splatTexAdd; -uniform float splatTexNear; -uniform float splatTexFar; -uniform float splatTexMid; +uniform float time; out vec4 fragColor; in vec4 vRgba; in vec2 vSplatUv; -in vec3 vNdc; flat in uint vSplatIndex; void main() { vec4 rgba = vRgba; float z = dot(vSplatUv, vSplatUv); - if (!splatTexEnable) { - if (z > (maxStdDev * maxStdDev)) { - discard; - } - } else { - vec2 uv = splatTexMul * vSplatUv + splatTexAdd; - float ndcZ = vNdc.z; - float depth = (2.0 * near * far) / (far + near - ndcZ * (far - near)); - float clampedFar = max(splatTexFar, splatTexNear); - float clampedDepth = clamp(depth, splatTexNear, clampedFar); - float logDepth = log2(clampedDepth + 1.0); - float logNear = log2(splatTexNear + 1.0); - float logFar = log2(clampedFar + 1.0); - - float texZ; - if (splatTexMid > 0.0) { - float clampedMid = clamp(splatTexMid, splatTexNear, clampedFar); - float logMid = log2(clampedMid + 1.0); - texZ = (clampedDepth <= clampedMid) ? - (0.5 * ((logDepth - logNear) / (logMid - logNear))) : - (0.5 * ((logDepth - logMid) / (logFar - logMid)) + 0.5); - } else { - texZ = (logDepth - logNear) / (logFar - logNear); - } - - vec4 modulate = texture(splatTexture, vec3(uv, 1.0 - texZ)); - rgba *= modulate; + if (z > (maxStdDev * maxStdDev)) { + discard; } rgba.a *= mix(1.0, exp(-0.5 * z), falloff); @@ -73,7 +34,7 @@ void main() { rgba.rgb = srgbToLinear(rgba.rgb); } - if (stochastic) { + #ifdef STOCHASTIC const bool STEADY = false; uint uTime = STEADY ? 0u : floatBitsToUint(time); uvec2 coord = uvec2(gl_FragCoord.xy); @@ -87,12 +48,11 @@ void main() { } else { discard; } - } else { + #else #ifdef PREMULTIPLIED_ALPHA fragColor = vec4(rgba.rgb * rgba.a, rgba.a); #else fragColor = rgba; #endif - } - #include + #endif } diff --git a/src/shaders/splatVertex.glsl b/src/shaders/splatVertex.glsl index 61331f2..7374f7c 100644 --- a/src/shaders/splatVertex.glsl +++ b/src/shaders/splatVertex.glsl @@ -4,44 +4,81 @@ precision highp int; precision highp usampler2DArray; #include -#include +#include +#include +#define decodeSplat SPLAT_DECODE_FN +#define decodeSplatSh SPLAT_SH_DECODE_FN + +#ifdef STOCHASTIC +#define splatIndex uint(gl_InstanceID) +#else attribute uint splatIndex; +#endif + +#ifdef USE_BATCHING +uniform highp sampler2D batchingTexture; +mat4 getBatchingMatrix( const in uint i ) { + int size = textureSize( batchingTexture, 0 ).x; + int j = int( i ) * 4; + int x = j % size; + int y = j / size; + vec4 v1 = texelFetch( batchingTexture, ivec2( x, y ), 0 ); + vec4 v2 = texelFetch( batchingTexture, ivec2( x + 1, y ), 0 ); + vec4 v3 = texelFetch( batchingTexture, ivec2( x + 2, y ), 0 ); + vec4 v4 = texelFetch( batchingTexture, ivec2( x + 3, y ), 0 ); + return mat4( v1, v2, v3, v4 ); +} +#endif out vec4 vRgba; out vec2 vSplatUv; -out vec3 vNdc; flat out uint vSplatIndex; +uniform float opacity; + uniform vec2 renderSize; uniform uint numSplats; uniform vec4 renderToViewQuat; -uniform vec3 renderToViewPos; uniform float maxStdDev; uniform float minPixelRadius; uniform float maxPixelRadius; -uniform float time; -uniform float deltaTime; -uniform bool debugFlag; uniform float minAlpha; -uniform bool stochastic; uniform bool enable2DGS; uniform float blurAmount; uniform float preBlurAmount; -uniform float focalDistance; -uniform float apertureAngle; uniform float clipXY; uniform float focalAdjustment; -uniform usampler2DArray packedSplats; -uniform vec4 rgbMinMaxLnScaleMinMax; +// Shader hooks +#ifdef HOOK_GLOBAL +{{HOOK_GLOBAL}} +#endif -#ifdef USE_LOGDEPTHBUF - bool isPerspectiveMatrix( mat4 m ) { - return m[ 2 ][ 3 ] == - 1.0; - } +#ifdef HOOK_UNIFORMS +{{HOOK_UNIFORMS}} +#endif + +#ifdef HOOK_OBJECT_MODIFIER +void _shader_hook_object_modifier(inout vec3 center, inout vec3 scales, inout vec4 quaternion, inout vec4 rgba) { + {{HOOK_OBJECT_MODIFIER}} +} +#endif + +#ifdef HOOK_WORLD_MODIFIER +void _shader_hook_world_modifier(inout vec3 center, inout vec3 scales, inout vec4 quaternion, inout vec4 rgba) { + {{HOOK_WORLD_MODIFIER}} +} #endif +#ifdef HOOK_SPLAT_COLOR +vec4 _shader_hook_splat_color(in vec3 center, in vec3 scales, in vec4 quaternion, inout vec4 rgba, in vec3 viewCenter) { + {{HOOK_SPLAT_COLOR}} + return rgba; +} +#endif + + void main() { // Default to outside the frustum so it's discarded if we return early gl_Position = vec4(0.0, 0.0, 2.0, 1.0); @@ -50,29 +87,36 @@ void main() { return; } - ivec3 texCoord; - if (stochastic) { - texCoord = ivec3( - uint(gl_InstanceID) & SPLAT_TEX_WIDTH_MASK, - (uint(gl_InstanceID) >> SPLAT_TEX_WIDTH_BITS) & SPLAT_TEX_HEIGHT_MASK, - (uint(gl_InstanceID) >> SPLAT_TEX_LAYER_BITS) - ); - } else { - if (splatIndex == 0xffffffffu) { - // Special value reserved for "no splat" - return; - } - texCoord = ivec3( - splatIndex & SPLAT_TEX_WIDTH_MASK, - (splatIndex >> SPLAT_TEX_WIDTH_BITS) & SPLAT_TEX_HEIGHT_MASK, - splatIndex >> SPLAT_TEX_LAYER_BITS - ); - } - uvec4 packed = texelFetch(packedSplats, texCoord, 0); - + // Decode Splat data vec3 center, scales; vec4 quaternion, rgba; - unpackSplatEncoding(packed, center, scales, quaternion, rgba, rgbMinMaxLnScaleMinMax); + uint sIndex = splatIndex & 0x3FFFFFu; + uint objectIndex = splatIndex >> 26u; + decodeSplat(sIndex, center, scales, quaternion, rgba); +#ifdef HOOK_OBJECT_MODIFIER + _shader_hook_object_modifier(center, scales, quaternion, rgba); +#endif + +#ifdef USE_BATCHING + mat4 splatModelMatrix = getBatchingMatrix(objectIndex); +#else + mat4 splatModelMatrix = modelMatrix; +#endif + mat4 splatViewMatrix = viewMatrix * splatModelMatrix; + + // Compute viewDir for sh evaluation + vec3 cameraInObjectSpace = (inverse(splatModelMatrix) * vec4(cameraPosition, 1.0)).xyz; + vec3 viewDir = normalize(center - cameraInObjectSpace); + + // Transform into world space + float modelScale = length(splatModelMatrix[0]); + center = (splatModelMatrix * vec4(center, 1.0)).xyz; + scales *= modelScale; + rgba.a *= opacity; + +#ifdef HOOK_WORLD_MODIFIER + _shader_hook_world_modifier(center, scales, quaternion, rgba); +#endif if (rgba.a < minAlpha) { return; @@ -83,7 +127,7 @@ void main() { } // Compute the view space center of the splat - vec3 viewCenter = quatVec(renderToViewQuat, center) + renderToViewPos; + vec3 viewCenter = (viewMatrix * vec4(center, 1.0)).xyz; // Discard splats behind the camera if (viewCenter.z >= 0.0) { @@ -105,10 +149,10 @@ void main() { } // Record the splat index for entropy - vSplatIndex = splatIndex; + vSplatIndex = sIndex; // Compute view space quaternion of splat - vec4 viewQuaternion = quatQuat(renderToViewQuat, quaternion); + mat3 viewRotation = mat3(splatViewMatrix) * (1.0/modelScale) * quaternionToMatrix(quaternion); if (enable2DGS && any(zeroScales)) { vRgba = rgba; @@ -123,9 +167,8 @@ void main() { offset = vec3(0.0, vSplatUv.xy * scales.yz); } - vec3 viewPos = viewCenter + quatVec(viewQuaternion, offset); + vec3 viewPos = viewCenter + viewRotation * offset; gl_Position = projectionMatrix * vec4(viewPos, 1.0); - vNdc = gl_Position.xyz / gl_Position.w; return; } @@ -133,7 +176,7 @@ void main() { vec3 ndcCenter = clipCenter.xyz / clipCenter.w; // Compute the 3D covariance matrix of the splat - mat3 RS = scaleQuaternionToMatrix(scales, viewQuaternion); + mat3 RS = matrixCompMult(viewRotation, mat3(vec3(scales.x), vec3(scales.y), vec3(scales.z))); mat3 cov3D = RS * transpose(RS); // Compute the Jacobian of the splat's projection at its center @@ -160,11 +203,6 @@ void main() { // Compute the 2D covariance by projecting the 3D covariance // and picking out the XY plane components. - // Keeping below because we may need it in the future - // for skinning deformations. - // mat3 W = transpose(mat3(viewMatrix)); - // mat3 T = W * J; - // mat3 cov2D = transpose(T) * cov3D * T; mat3 cov2D = transpose(J) * cov3D * J; float a = cov2D[0][0]; float d = cov2D[1][1]; @@ -174,21 +212,10 @@ void main() { a += preBlurAmount; d += preBlurAmount; - float fullBlurAmount = blurAmount; - if ((focalDistance > 0.0) && (apertureAngle > 0.0)) { - float focusRadius = maxPixelRadius; - if (viewCenter.z < 0.0) { - float focusBlur = abs((-viewCenter.z - focalDistance) / viewCenter.z); - float apertureRadius = focal.x * tan(0.5 * apertureAngle); - focusRadius = focusBlur * apertureRadius; - } - fullBlurAmount = clamp(sqr(focusRadius), blurAmount, sqr(maxPixelRadius)); - } - // Do convolution with a 0.5-pixel Gaussian for anti-aliasing: sqrt(0.3) ~= 0.5 float detOrig = a * d - b * b; - a += fullBlurAmount; - d += fullBlurAmount; + a += blurAmount; + d += blurAmount; float det = a * d - b * b; // Compute anti-aliasing intensity scaling factor @@ -218,9 +245,20 @@ void main() { vec2 ndcOffset = (2.0 / scaledRenderSize) * pixelOffset; vec3 ndc = vec3(ndcCenter.xy + ndcOffset, ndcCenter.z); + // Evaluate spherical harmonics + #if NUM_SH > 0 + vec3[3] sh1; + vec3[5] sh2; + vec3[7] sh3; + decodeSplatSh(splatIndex, sh1, sh2, sh3); + rgba.rgb += evaluateSH(viewDir, sh1, sh2, sh3); + #endif + +#ifdef HOOK_SPLAT_COLOR + rgba = _shader_hook_splat_color(center, scales, quaternion, rgba, viewCenter); +#endif + vRgba = rgba; vSplatUv = position.xy * maxStdDev; - vNdc = ndc; gl_Position = vec4(ndc.xy * clipCenter.w, clipCenter.zw); - #include } diff --git a/src/splatConstructors.ts b/src/splatConstructors.ts deleted file mode 100644 index ad91e53..0000000 --- a/src/splatConstructors.ts +++ /dev/null @@ -1,419 +0,0 @@ -import * as THREE from "three"; -import { PackedSplats } from "./PackedSplats"; -import { SplatMesh } from "./SplatMesh"; - -export function constructGrid({ - // PackedSplats object to add splats to - splats, - // min and max box extents of the grid - extents, - // step size along each grid axis - stepSize = 1, - // spherical radius of each Gsplat - pointRadius = 0.01, - // relative size of the "shadow copy" of each Gsplat placed behind it - pointShadowScale = 2.0, - // Gsplat opacity - opacity = 1.0, - // Gsplat color (THREE.Color) or function to set color for position: - // ((THREE.Color, THREE.Vector3) => void) (default: RGB-modulated grid) - color, -}: { - splats: PackedSplats; - extents: THREE.Box3; - stepSize?: number; - pointRadius?: number; - pointShadowScale?: number; - opacity?: number; - color?: THREE.Color | ((color: THREE.Color, point: THREE.Vector3) => void); -}) { - const EPSILON = 1.0e-6; - const center = new THREE.Vector3(); - const scales = new THREE.Vector3(); - const quaternion = new THREE.Quaternion(0, 0, 0, 1); - if (color == null) { - color = (color, point) => - color.set( - 0.55 + 0.45 * Math.cos(point.x * 1), - 0.55 + 0.45 * Math.cos(point.y * 1), - 0.55 + 0.45 * Math.cos(point.z * 1), - ); - } - const pointColor = new THREE.Color(); - for (let z = extents.min.z; z < extents.max.z + EPSILON; z += stepSize) { - for (let y = extents.min.y; y < extents.max.y + EPSILON; y += stepSize) { - for (let x = extents.min.x; x < extents.max.x + EPSILON; x += stepSize) { - center.set(x, y, z); - for (let layer = 0; layer < 2; ++layer) { - scales.setScalar(pointRadius * (layer ? 1 : pointShadowScale)); - if (!layer) { - pointColor.setScalar(0.0); - } else if (typeof color === "function") { - color(pointColor, center); - } else { - pointColor.copy(color); - } - splats.pushSplat(center, scales, quaternion, opacity, pointColor); - } - } - } - } -} - -export function constructAxes({ - // PackedSplats object to add splats to - splats, - // scale (Gsplat scale along axis) - scale = 0.25, - // radius of the axes (Gsplat scale orthogonal to axis) - axisRadius = 0.0075, - // relative size of the "shadow copy" of each Gsplat placed behind it - axisShadowScale = 2.0, - // origins of the axes (default single axis at origin) - origins = [new THREE.Vector3()], -}: { - splats: PackedSplats; - scale?: number; - axisRadius?: number; - axisShadowScale?: number; - origins?: THREE.Vector3[]; -}) { - const center = new THREE.Vector3(); - const scales = new THREE.Vector3(); - const quaternion = new THREE.Quaternion(0, 0, 0, 1); - const color = new THREE.Color(); - const opacity = 1.0; - for (const origin of origins) { - for (let axis = 0; axis < 3; ++axis) { - center.set( - origin.x + (axis === 0 ? scale : 0), - origin.y + (axis === 1 ? scale : 0), - origin.z + (axis === 2 ? scale : 0), - ); - for (let layer = 0; layer < 2; ++layer) { - scales.set( - (axis === 0 ? scale : axisRadius) * (layer ? 1 : axisShadowScale), - (axis === 1 ? scale : axisRadius) * (layer ? 1 : axisShadowScale), - (axis === 2 ? scale : axisRadius) * (layer ? 1 : axisShadowScale), - ); - color.setRGB( - layer === 0 ? 0.0 : axis === 0 ? 1.0 : 0.0, - layer === 0 ? 0.0 : axis === 1 ? 1.0 : 0.0, - layer === 0 ? 0.0 : axis === 2 ? 1.0 : 0.0, - ); - splats.pushSplat(center, scales, quaternion, opacity, color); - } - } - } -} - -export function constructSpherePoints({ - // PackedSplats object to add splats to - splats, - // center of the sphere (default: origin) - origin = new THREE.Vector3(), - // radius of the sphere - radius = 1.0, - // maximum depth of recursion for subdividing the sphere - // Warning: Gsplat count grows exponentially with depth - maxDepth = 3, - // filter function to apply to each point, for example to select - // points in a certain direction or other function ((THREE.Vector3) => boolean) - // (default: null) - filter = null, - // radius of each oriented Gsplat - pointRadius = 0.02, - // flatness of each oriented Gsplat - pointThickness = 0.001, - // color of each Gsplat (THREE.Color) or function to set color for point: - // ((THREE.Color, THREE.Vector3) => void) (default: white) - color = new THREE.Color(1, 1, 1), -}: { - splats: PackedSplats; - origin?: THREE.Vector3; - radius?: number; - maxDepth?: number; - filter?: ((point: THREE.Vector3) => boolean) | null; - pointRadius?: number; - pointThickness?: number; - color?: THREE.Color | ((color: THREE.Color, point: THREE.Vector3) => void); -}) { - const pointsHash: { [key: string]: THREE.Vector3 } = {}; - - function addPoint(p: THREE.Vector3) { - if (filter && !filter(p)) { - return; - } - const key = `${p.x},${p.y},${p.z}`; - if (!pointsHash[key]) { - pointsHash[key] = p; - } - } - - function recurse( - depth: number, - p0: THREE.Vector3, - p1: THREE.Vector3, - p2: THREE.Vector3, - ) { - addPoint(p0); - addPoint(p1); - addPoint(p2); - if (depth >= maxDepth) { - return; - } - const p01 = new THREE.Vector3().addVectors(p0, p1).normalize(); - const p12 = new THREE.Vector3().addVectors(p1, p2).normalize(); - const p20 = new THREE.Vector3().addVectors(p2, p0).normalize(); - recurse(depth + 1, p0, p01, p20); - recurse(depth + 1, p01, p1, p12); - recurse(depth + 1, p20, p12, p2); - recurse(depth + 1, p01, p12, p20); - } - - for (const x of [-1, 1]) { - for (const y of [-1, 1]) { - for (const z of [-1, 1]) { - const p0 = new THREE.Vector3(x, 0, 0); - const p1 = new THREE.Vector3(0, y, 0); - const p2 = new THREE.Vector3(0, 0, z); - recurse(0, p0, p1, p2); - } - } - } - - const points = Object.values(pointsHash); - const scales = new THREE.Vector3(pointRadius, pointRadius, pointThickness); - const quaternion = new THREE.Quaternion(); - const pointColor = typeof color === "function" ? new THREE.Color() : color; - for (const point of points) { - quaternion.setFromUnitVectors(new THREE.Vector3(0, 0, -1), point); - if (typeof color === "function") { - color(pointColor, point); - } - point.multiplyScalar(radius); - point.add(origin); - splats.pushSplat(point, scales, quaternion, 1.0, pointColor); - } -} - -export function textSplats({ - // text string to display - text, - // browser font to render text with (default: "Arial") - font, - // font size in pixels/Gsplats (default: 32) - fontSize, - // SplatMesh.recolor tint assuming white Gsplats (default: white) - color, - // Individual Gsplat color (default: white) - rgb, - // Gsplat radius (default: 0.8 covers 1-unit spacing well) - dotRadius, - // text alignment: "left", "center", "right", "start", "end" (default: "start") - textAlign, - // line spacing multiplier, lines delimited by "\n" (default: 1.0) - lineHeight, - // Coordinate scale in object-space (default: 1.0) - objectScale, -}: { - text: string; - font?: string; - fontSize?: number; - color?: THREE.Color; - rgb?: THREE.Color; - dotRadius?: number; - textAlign?: "left" | "center" | "right" | "start" | "end"; - lineHeight?: number; - objectScale?: number; -}) { - font = font ?? "Arial"; - fontSize = fontSize ?? 32; - color = color ?? new THREE.Color(1, 1, 1); - dotRadius = dotRadius ?? 0.8; - textAlign = textAlign ?? "start"; - lineHeight = lineHeight ?? 1; - objectScale = objectScale ?? 1; - const lines = text.split("\n"); - - const canvas = document.createElement("canvas"); - const ctx = canvas.getContext("2d"); - if (!ctx) { - throw new Error("Failed to create canvas context"); - } - - ctx.font = `${fontSize}px ${font}`; - ctx.textAlign = textAlign; - const metrics = ctx.measureText(""); - const fontHeight = - metrics.fontBoundingBoxAscent + metrics.fontBoundingBoxDescent; - - let minLeft = Number.POSITIVE_INFINITY; - let maxRight = Number.NEGATIVE_INFINITY; - let minTop = Number.POSITIVE_INFINITY; - let maxBottom = Number.NEGATIVE_INFINITY; - for (let line = 0; line < lines.length; ++line) { - const metrics = ctx.measureText(lines[line]); - const y = fontHeight * lineHeight * line; - minLeft = Math.min(minLeft, -metrics.actualBoundingBoxLeft); - maxRight = Math.max(maxRight, metrics.actualBoundingBoxRight); - minTop = Math.min(minTop, y - metrics.actualBoundingBoxAscent); - maxBottom = Math.max(maxBottom, y + metrics.actualBoundingBoxDescent); - } - const originLeft = Math.floor(minLeft); - const originTop = Math.floor(minTop); - const width = Math.ceil(maxRight) - originLeft; - const height = Math.ceil(maxBottom) - originTop; - canvas.width = width; - canvas.height = height; - - ctx.font = `${fontSize}px ${font}`; - ctx.textAlign = textAlign; - ctx.textBaseline = "alphabetic"; - ctx.fillStyle = "#FFFFFF"; - for (let i = 0; i < lines.length; ++i) { - const y = fontHeight * lineHeight * i - originTop; - ctx.fillText(lines[i], -originLeft, y); - } - - const imageData = ctx.getImageData(0, 0, width, height); - const rgba = new Uint8Array(imageData.data.buffer); - const splats = new PackedSplats(); - const center = new THREE.Vector3(); - const scales = new THREE.Vector3().setScalar(dotRadius * objectScale); - const quaternion = new THREE.Quaternion(0, 0, 0, 1); - rgb = rgb ?? new THREE.Color(1, 1, 1); - - let offset = 0; - for (let y = 0; y < height; ++y) { - for (let x = 0; x < width; ++x) { - const a = rgba[offset + 3]; - if (a > 0) { - const opacity = a / 255; - center.set(x - 0.5 * (width - 1), 0.5 * (height - 1) - y, 0); - center.multiplyScalar(objectScale); - splats.pushSplat(center, scales, quaternion, opacity, rgb); - } - offset += 4; - } - } - - const mesh = new SplatMesh({ packedSplats: splats }); - mesh.recolor = color; - return mesh; -} - -export function imageSplats({ - // URL of the image to convert to splats (example: `url: "./image.png"`) - url, - // Radius of each Gsplat, default covers 1-unit spacing well (default: 0.8) - dotRadius, - // Subsampling factor for the image. Higher values reduce resolution, - // for example 2 will halve the width and height by averaging (default: 1) - subXY, - // Optional callback function to modify each Gsplat before it's added. - // Return null to skip adding the Gsplat, or a number to set the opacity - // and add the Gsplat with parameter values in the objects center, rgba etc. were - // passed into the forEachSplat callback. Ending the callback in `return opacity;` - // will retain the original opacity. - // ((width: number, height: number, index: number, center: THREE.Vector3, scales: THREE.Vector3, quaternion: THREE.Quaternion, opacity: number, color: THREE.Color) => number | null) - forEachSplat, -}: { - url: string; - dotRadius?: number; - subXY?: number; - forEachSplat?: ( - width: number, - height: number, - index: number, - center: THREE.Vector3, - scales: THREE.Vector3, - quaternion: THREE.Quaternion, - opacity: number, - color: THREE.Color, - ) => number | null; -}): SplatMesh { - dotRadius = dotRadius ?? 0.8; - subXY = Math.max(1, Math.floor(subXY ?? 1)); - - return new SplatMesh({ - constructSplats: async (splats) => { - return new Promise((resolve, reject) => { - const img = new Image(); - img.crossOrigin = "anonymous"; - img.onerror = reject; - img.onload = () => { - const { width, height } = img; - const canvas = document.createElement("canvas"); - canvas.width = width; - canvas.height = height; - const ctx = canvas.getContext("2d"); - if (!ctx) { - reject(new Error("Failed to create canvas context")); - return; - } - ctx.imageSmoothingEnabled = true; - ctx.imageSmoothingQuality = "high"; - const destWidth = Math.round(width / subXY); - const destHeight = Math.round(height / subXY); - ctx.drawImage(img, 0, 0, destWidth, destHeight); - try { - const imageData = ctx.getImageData(0, 0, destWidth, destHeight); - const rgba = new Uint8Array(imageData.data.buffer); - - const center = new THREE.Vector3(); - const scales = new THREE.Vector3().setScalar(dotRadius); - const quaternion = new THREE.Quaternion(0, 0, 0, 1); - const rgb = new THREE.Color(); - - let index = 0; - for (let y = 0; y < destHeight; ++y) { - for (let x = 0; x < destWidth; ++x) { - const offset = index * 4; - const a = rgba[offset + 3]; - if (a > 0) { - let opacity = a / 255; - rgb.set( - rgba[offset + 0] / 255, - rgba[offset + 1] / 255, - rgba[offset + 2] / 255, - ); - center.set( - x - 0.5 * (destWidth - 1), - 0.5 * (destHeight - 1) - y, - 0, - ); - scales.setScalar(dotRadius); - quaternion.set(0, 0, 0, 1); - let push = true; - if (forEachSplat) { - const maybeOpacity = forEachSplat( - destWidth, - destHeight, - index, - center, - scales, - quaternion, - opacity, - rgb, - ); - opacity = maybeOpacity ?? opacity; - push = maybeOpacity !== null; - } - if (push) { - splats.pushSplat(center, scales, quaternion, opacity, rgb); - } - } - index += 1; - } - } - resolve(); - } catch (error) { - reject(error); - } - }; - img.src = url; - }); - }, - }); -} diff --git a/src/splatWorker.ts b/src/splatWorker.ts deleted file mode 100644 index 2c9a430..0000000 --- a/src/splatWorker.ts +++ /dev/null @@ -1,128 +0,0 @@ -import { getArrayBuffers } from "./utils.js"; -import BundledWorker from "./worker?worker&inline"; - -// SplatWorker is an internal class that manages a WebWorker for executing -// longer running CPU tasks such as Gsplat file decoding and sorting. -// Although a SplatWorker can be created and used directly, the utility -// function withWorker() is recommended to allocate from a managed -// pool of SplatWorkers. - -export class SplatWorker { - worker: Worker; - messages: Record< - number, - { resolve: (value: unknown) => void; reject: (reason?: unknown) => void } - > = {}; - messageIdNext = 0; - - constructor() { - // this.worker = new Worker(new URL("./worker", import.meta.url), { type: "module" }); - this.worker = new BundledWorker(); - this.worker.onmessage = (event) => this.onMessage(event); - } - - makeMessageId(): number { - return ++this.messageIdNext; - } - - makeMessagePromiseId(): { id: number; promise: Promise } { - const id = this.makeMessageId(); - const promise = new Promise((resolve, reject) => { - this.messages[id] = { resolve, reject }; - }); - return { id, promise }; - } - - onMessage(event: MessageEvent) { - // console.log("SplatWorker.onMessage:", event); - const { id, result, error } = event.data; - // console.log(`SplatWorker.onMessage(${id}):`, result, error); - const handler = this.messages[id]; - if (handler) { - delete this.messages[id]; - if (error) { - handler.reject(error); - } else { - handler.resolve(result); - } - } - } - - // Invoke an RPC on the worker with the given name and arguments. - // The normal usage of a worker is to run one activity at a time, - // but this function allows for concurrent calls, tagging each request - // with a unique message Id and awaiting a response to that same Id. - // The method will automatically transfer any ArrayBuffers in the - // arguments to the worker. If you'd like to transfer a copy of a - // buffer then you must clone it before passing to this function. - async call(name: string, args: unknown): Promise { - const { id, promise } = this.makeMessagePromiseId(); - // console.log(`SplatWorker.call(${name}):`, args); - this.worker.postMessage( - { name, args, id }, - { transfer: getArrayBuffers(args) }, - ); - return promise; - } -} - -let maxWorkers = 4; - -let numWorkers = 0; -const freeWorkers: SplatWorker[] = []; -const workerQueue: ((worker: SplatWorker) => void)[] = []; - -// Set the maximum number of workers to allocate for the pool. (default: 4) -export function setWorkerPool(count = 4) { - maxWorkers = count; -} - -// Allocate a worker from the pool. If none are available and we are below the -// maximum, create a new one. Otherwise, add the request to a queue and wait -// for it to be fulfilled. -export async function allocWorker(): Promise { - const worker = freeWorkers.shift(); - if (worker) { - return worker; - } - - if (numWorkers < maxWorkers) { - const worker = new SplatWorker(); - numWorkers += 1; - return worker; - } - - return new Promise((resolve) => { - workerQueue.push(resolve); - }); -} - -// Return a worker to the pool. Pass the worker to any pending waiter. -export function freeWorker(worker: SplatWorker) { - if (numWorkers > maxWorkers) { - // Worker no longer needed - numWorkers -= 1; - return; - } - - const waiter = workerQueue.shift(); - if (waiter) { - waiter(worker); - return; - } - - freeWorkers.push(worker); -} - -// Allocate a worker from the pool and invoke the callback with the worker. -// When the callback completes, the worker will be returned to the pool. -export async function withWorker( - callback: (worker: SplatWorker) => Promise, -): Promise { - const worker = await allocWorker(); - try { - return await callback(worker); - } finally { - freeWorker(worker); - } -} diff --git a/src/spz.ts b/src/spz.ts deleted file mode 100644 index 2e1a1f3..0000000 --- a/src/spz.ts +++ /dev/null @@ -1,833 +0,0 @@ -import * as THREE from "three"; -import { - SplatData, - SplatFileType, - type TranscodeSpzInput, - getSplatFileType, - getSplatFileTypeFromPath, -} from "./SplatLoader"; -import { GunzipReader, fromHalf, normalize, unpackSplat } from "./utils"; - -import { decodeAntiSplat } from "./antisplat"; -import { decodeKsplat } from "./ksplat"; -import { PlyReader } from "./ply"; - -// SPZ file format reader - -export class SpzReader { - fileBytes: Uint8Array; - reader: GunzipReader; - - version = -1; - numSplats = 0; - shDegree = 0; - fractionalBits = 0; - flags = 0; - flagAntiAlias = false; - reserved = 0; - headerParsed = false; - parsed = false; - - constructor({ fileBytes }: { fileBytes: Uint8Array | ArrayBuffer }) { - this.fileBytes = - fileBytes instanceof ArrayBuffer ? new Uint8Array(fileBytes) : fileBytes; - this.reader = new GunzipReader({ fileBytes: this.fileBytes }); - } - - async parseHeader() { - if (this.headerParsed) { - throw new Error("SPZ file header already parsed"); - } - - const header = new DataView((await this.reader.read(16)).buffer); - if (header.getUint32(0, true) !== 0x5053474e) { - throw new Error("Invalid SPZ file"); - } - this.version = header.getUint32(4, true); - if (this.version < 1 || this.version > 3) { - throw new Error(`Unsupported SPZ version: ${this.version}`); - } - - this.numSplats = header.getUint32(8, true); - this.shDegree = header.getUint8(12); - this.fractionalBits = header.getUint8(13); - this.flags = header.getUint8(14); - this.flagAntiAlias = (this.flags & 0x01) !== 0; - this.reserved = header.getUint8(15); - this.headerParsed = true; - this.parsed = false; - } - - async parseSplats( - centerCallback?: (index: number, x: number, y: number, z: number) => void, - alphaCallback?: (index: number, alpha: number) => void, - rgbCallback?: (index: number, r: number, g: number, b: number) => void, - scalesCallback?: ( - index: number, - scaleX: number, - scaleY: number, - scaleZ: number, - ) => void, - quatCallback?: ( - index: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - ) => void, - shCallback?: ( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, - ) => void, - ) { - if (!this.headerParsed) { - throw new Error("SPZ file header must be parsed first"); - } - if (this.parsed) { - throw new Error("SPZ file already parsed"); - } - this.parsed = true; - - if (this.version === 1) { - // float16 centers - const centerBytes = await this.reader.read(this.numSplats * 3 * 2); - const centerUint16 = new Uint16Array(centerBytes.buffer); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const x = fromHalf(centerUint16[i3]); - const y = fromHalf(centerUint16[i3 + 1]); - const z = fromHalf(centerUint16[i3 + 2]); - centerCallback?.(i, x, y, z); - } - } else if (this.version === 2 || this.version === 3) { - // 24-bit fixed-point centers - const fixed = 1 << this.fractionalBits; - const centerBytes = await this.reader.read(this.numSplats * 3 * 3); - for (let i = 0; i < this.numSplats; i++) { - const i9 = i * 9; - const x = - (((centerBytes[i9 + 2] << 24) | - (centerBytes[i9 + 1] << 16) | - (centerBytes[i9] << 8)) >> - 8) / - fixed; - const y = - (((centerBytes[i9 + 5] << 24) | - (centerBytes[i9 + 4] << 16) | - (centerBytes[i9 + 3] << 8)) >> - 8) / - fixed; - const z = - (((centerBytes[i9 + 8] << 24) | - (centerBytes[i9 + 7] << 16) | - (centerBytes[i9 + 6] << 8)) >> - 8) / - fixed; - centerCallback?.(i, x, y, z); - } - } else { - throw new Error("Unreachable"); - } - - { - const bytes = await this.reader.read(this.numSplats); - for (let i = 0; i < this.numSplats; i++) { - alphaCallback?.(i, bytes[i] / 255); - } - } - { - const rgbBytes = await this.reader.read(this.numSplats * 3); - const scale = SH_C0 / 0.15; - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const r = (rgbBytes[i3] / 255 - 0.5) * scale + 0.5; - const g = (rgbBytes[i3 + 1] / 255 - 0.5) * scale + 0.5; - const b = (rgbBytes[i3 + 2] / 255 - 0.5) * scale + 0.5; - rgbCallback?.(i, r, g, b); - } - } - { - const scalesBytes = await this.reader.read(this.numSplats * 3); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const scaleX = Math.exp(scalesBytes[i3] / 16 - 10); - const scaleY = Math.exp(scalesBytes[i3 + 1] / 16 - 10); - const scaleZ = Math.exp(scalesBytes[i3 + 2] / 16 - 10); - scalesCallback?.(i, scaleX, scaleY, scaleZ); - } - } - if (this.version === 3) { - // Version 3 uses a trick called "smallest three" to compress the rotation quaternions - // achieving better precision. "Optimizing orientation" section at https://gafferongames.com/post/snapshot_compression/ A quaternion length must be 1: x^2+y^2+z^2+w^2 = 1 - // We can drop one component and reconstruct it with the identity above. - // Largest component is dropped for best numerical precision. - // Quaternion stored in 32 bits - // 10 bits singed integer for each of the 3 components + 2 bits indicating the index of dropped component. - // vs 8 bits for each component uncompressed (spz version < 3) - // Max Value after extracting largest component v is another component v - // (v,v,0,0) - // v^2 + v^2 = 1 - // v = 1 / sqrt(2); - const maxValue = 1 / Math.sqrt(2); // 0.7071 - const quatBytes = await this.reader.read(this.numSplats * 4); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 4; - const quaternion = [0, 0, 0, 0]; - const values = [ - quatBytes[i3], - quatBytes[i3 + 1], - quatBytes[i3 + 2], - quatBytes[i3 + 3], - ]; - // all values are packed in 32 bits (10 per each of 3 components + 2 bits of index of larged value) - const combinedValues = - values[0] + (values[1] << 8) + (values[2] << 16) + (values[3] << 24); - // each component value is 9 bits + sign (1 bit) - const valueMask = (1 << 9) - 1; - // extract index of the largest element. 2 top bits. - const largestIndex = combinedValues >>> 30; - let remainingValues = combinedValues; - let sumSquares = 0; - - for (let i = 3; i >= 0; --i) { - if (i !== largestIndex) { - // extract current value and sign. - const value = remainingValues & valueMask; - const sign = (remainingValues >>> 9) & 0x1; - // each value is represented as 10 bits. Shift to next one. - remainingValues = remainingValues >>> 10; - // convert to range [0,1] and then to [0, 0.7071] - quaternion[i] = maxValue * (value / valueMask); - // apply sign. - quaternion[i] = sign === 0 ? quaternion[i] : -quaternion[i]; - // accumulate the sum of squares - sumSquares += quaternion[i] * quaternion[i]; - } - } - - // quartenion length must be 1 (x^2+y^2+z^2+w^2 = 1) - // so can reconstruct largest component from the other 3. - // w = sqrt(1 - x^2 - y^2 - z^2); - const square = 1 - sumSquares; - quaternion[largestIndex] = Math.sqrt(Math.max(square, 0)); - - quatCallback?.( - i, - quaternion[0], - quaternion[1], - quaternion[2], - quaternion[3], - ); - } - } else { - const quatBytes = await this.reader.read(this.numSplats * 3); - for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 3; - const quatX = quatBytes[i3] / 127.5 - 1; - const quatY = quatBytes[i3 + 1] / 127.5 - 1; - const quatZ = quatBytes[i3 + 2] / 127.5 - 1; - const quatW = Math.sqrt( - Math.max(0, 1 - quatX * quatX - quatY * quatY - quatZ * quatZ), - ); - quatCallback?.(i, quatX, quatY, quatZ, quatW); - } - } - - if (shCallback && this.shDegree >= 1) { - const sh1 = new Float32Array(3 * 3); - const sh2 = this.shDegree >= 2 ? new Float32Array(5 * 3) : undefined; - const sh3 = this.shDegree >= 3 ? new Float32Array(7 * 3) : undefined; - const shBytes = await this.reader.read( - this.numSplats * SH_DEGREE_TO_VECS[this.shDegree] * 3, - ); - - let offset = 0; - for (let i = 0; i < this.numSplats; i++) { - for (let j = 0; j < 9; ++j) { - sh1[j] = (shBytes[offset + j] - 128) / 128; - } - offset += 9; - if (sh2) { - for (let j = 0; j < 15; ++j) { - sh2[j] = (shBytes[offset + j] - 128) / 128; - } - offset += 15; - } - if (sh3) { - for (let j = 0; j < 21; ++j) { - sh3[j] = (shBytes[offset + j] - 128) / 128; - } - offset += 21; - } - shCallback?.(i, sh1, sh2, sh3); - } - } - } -} - -const SH_DEGREE_TO_VECS: Record = { 1: 3, 2: 8, 3: 15 }; -const SH_C0 = 0.28209479177387814; - -export const SPZ_MAGIC = 0x5053474e; // NGSP = Niantic gaussian splat -export const SPZ_VERSION = 3; -export const FLAG_ANTIALIASED = 0x1; - -export class SpzWriter { - buffer: ArrayBuffer; - view: DataView; - numSplats: number; - shDegree: number; - fractionalBits: number; - fraction: number; - flagAntiAlias: boolean; - clippedCount = 0; - - constructor({ - numSplats, - shDegree, - fractionalBits = 12, - flagAntiAlias = true, - }: { - numSplats: number; - shDegree: number; - fractionalBits?: number; - flagAntiAlias?: boolean; - }) { - const splatSize = - 9 + // Position - 1 + // Opacity - 3 + // Scale - 3 + // DC-rgb - 4 + // Rotation - (shDegree >= 1 ? 9 : 0) + - (shDegree >= 2 ? 15 : 0) + - (shDegree >= 3 ? 21 : 0); - const bufferSize = 16 + numSplats * splatSize; - this.buffer = new ArrayBuffer(bufferSize); - this.view = new DataView(this.buffer); - - this.view.setUint32(0, SPZ_MAGIC, true); // NGSP - this.view.setUint32(4, SPZ_VERSION, true); - this.view.setUint32(8, numSplats, true); - this.view.setUint8(12, shDegree); - this.view.setUint8(13, fractionalBits); - this.view.setUint8(14, flagAntiAlias ? FLAG_ANTIALIASED : 0); - this.view.setUint8(15, 0); // Reserved - - this.numSplats = numSplats; - this.shDegree = shDegree; - this.fractionalBits = fractionalBits; - this.fraction = 1 << fractionalBits; - this.flagAntiAlias = flagAntiAlias; - } - - setCenter(index: number, x: number, y: number, z: number) { - // Divide by this.fraction and round to nearest integer, - // then write as 3-bytes per x then y then z. - const xRounded = Math.round(x * this.fraction); - const xInt = Math.max(-0x7fffff, Math.min(0x7fffff, xRounded)); - const yRounded = Math.round(y * this.fraction); - const yInt = Math.max(-0x7fffff, Math.min(0x7fffff, yRounded)); - const zRounded = Math.round(z * this.fraction); - const zInt = Math.max(-0x7fffff, Math.min(0x7fffff, zRounded)); - const clipped = xRounded !== xInt || yRounded !== yInt || zRounded !== zInt; - if (clipped) { - this.clippedCount += 1; - // if (this.clippedCount < 10) { - // // Write x y z also in hex - // console.log(`Clipped ${index}: ${x}, ${y}, ${z} (0x${x.toString(16)}, 0x${y.toString(16)}, 0x${z.toString(16)}) -> ${xRounded}, ${yRounded}, ${zRounded} (0x${xRounded.toString(16)}, 0x${yRounded.toString(16)}, 0x${zRounded.toString(16)}) -> ${xInt}, ${yInt}, ${zInt} (0x${xInt.toString(16)}, 0x${yInt.toString(16)}, 0x${zInt.toString(16)})`); - // } - } - const i9 = index * 9; - const base = 16 + i9; - this.view.setUint8(base, xInt & 0xff); - this.view.setUint8(base + 1, (xInt >> 8) & 0xff); - this.view.setUint8(base + 2, (xInt >> 16) & 0xff); - this.view.setUint8(base + 3, yInt & 0xff); - this.view.setUint8(base + 4, (yInt >> 8) & 0xff); - this.view.setUint8(base + 5, (yInt >> 16) & 0xff); - this.view.setUint8(base + 6, zInt & 0xff); - this.view.setUint8(base + 7, (zInt >> 8) & 0xff); - this.view.setUint8(base + 8, (zInt >> 16) & 0xff); - } - - setAlpha(index: number, alpha: number) { - const base = 16 + this.numSplats * 9 + index; - this.view.setUint8( - base, - Math.max(0, Math.min(255, Math.round(alpha * 255))), - ); - } - - static scaleRgb(r: number) { - const v = ((r - 0.5) / (SH_C0 / 0.15) + 0.5) * 255; - return Math.max(0, Math.min(255, Math.round(v))); - } - - setRgb(index: number, r: number, g: number, b: number) { - const base = 16 + this.numSplats * 10 + index * 3; - this.view.setUint8(base, SpzWriter.scaleRgb(r)); - this.view.setUint8(base + 1, SpzWriter.scaleRgb(g)); - this.view.setUint8(base + 2, SpzWriter.scaleRgb(b)); - } - - setScale(index: number, scaleX: number, scaleY: number, scaleZ: number) { - const base = 16 + this.numSplats * 13 + index * 3; - this.view.setUint8( - base, - Math.max(0, Math.min(255, Math.round((Math.log(scaleX) + 10) * 16))), - ); - this.view.setUint8( - base + 1, - Math.max(0, Math.min(255, Math.round((Math.log(scaleY) + 10) * 16))), - ); - this.view.setUint8( - base + 2, - Math.max(0, Math.min(255, Math.round((Math.log(scaleZ) + 10) * 16))), - ); - } - - setQuat( - index: number, - ...q: [number, number, number, number] // x, y, z, w - ) { - const base = 16 + this.numSplats * 16 + index * 4; - - const quat = normalize(q); - - // Find largest component - let iLargest = 0; - for (let i = 1; i < 4; ++i) { - if (Math.abs(quat[i]) > Math.abs(quat[iLargest])) { - iLargest = i; - } - } - - // Since -quat represents the same rotation as quat, transform the quaternion so the largest element - // is positive. This avoids having to send its sign bit. - const negate = quat[iLargest] < 0 ? 1 : 0; - - // Do compression using sign bit and 9-bit precision per element. - let comp = iLargest; - for (let i = 0; i < 4; ++i) { - if (i !== iLargest) { - const negbit = (quat[i] < 0 ? 1 : 0) ^ negate; - const mag = Math.floor( - ((1 << 9) - 1) * (Math.abs(quat[i]) / Math.SQRT1_2) + 0.5, - ); - comp = (comp << 10) | (negbit << 9) | mag; - } - } - - this.view.setUint8(base, comp & 0xff); - this.view.setUint8(base + 1, (comp >> 8) & 0xff); - this.view.setUint8(base + 2, (comp >> 16) & 0xff); - this.view.setUint8(base + 3, (comp >>> 24) & 0xff); - } - - static quantizeSh(sh: number, bits: number) { - const value = Math.round(sh * 128) + 128; - const bucketSize = 1 << (8 - bits); - const quantized = - Math.floor((value + bucketSize / 2) / bucketSize) * bucketSize; - return Math.max(0, Math.min(255, quantized)); - } - - setSh( - index: number, - sh1: Float32Array, - sh2?: Float32Array, - sh3?: Float32Array, - ) { - const shVecs = SH_DEGREE_TO_VECS[this.shDegree] || 0; - const base1 = 16 + this.numSplats * 20 + index * shVecs * 3; - for (let j = 0; j < 9; ++j) { - this.view.setUint8(base1 + j, SpzWriter.quantizeSh(sh1[j], 5)); - } - if (sh2) { - const base2 = base1 + 9; - for (let j = 0; j < 15; ++j) { - this.view.setUint8(base2 + j, SpzWriter.quantizeSh(sh2[j], 4)); - } - if (sh3) { - const base3 = base2 + 15; - for (let j = 0; j < 21; ++j) { - this.view.setUint8(base3 + j, SpzWriter.quantizeSh(sh3[j], 4)); - } - } - } - } - - async finalize(): Promise { - const input = new Uint8Array(this.buffer); - const stream = new ReadableStream({ - async start(controller) { - controller.enqueue(input); - controller.close(); - }, - }); - const compressed = stream.pipeThrough(new CompressionStream("gzip")); - const response = new Response(compressed); - const buffer = await response.arrayBuffer(); - console.log( - "Compressed", - input.length, - "bytes to", - buffer.byteLength, - "bytes", - ); - return new Uint8Array(buffer); - } -} - -export async function transcodeSpz(input: TranscodeSpzInput) { - const splats = new SplatData(); - const { - inputs, - clipXyz, - maxSh, - fractionalBits = 12, - opacityThreshold, - } = input; - for (const input of inputs) { - const scale = input.transform?.scale ?? 1; - const quaternion = new THREE.Quaternion().fromArray( - input.transform?.quaternion ?? [0, 0, 0, 1], - ); - const translate = new THREE.Vector3().fromArray( - input.transform?.translate ?? [0, 0, 0], - ); - const clip = clipXyz - ? new THREE.Box3( - new THREE.Vector3().fromArray(clipXyz.min), - new THREE.Vector3().fromArray(clipXyz.max), - ) - : undefined; - - function transformPos(pos: THREE.Vector3) { - pos.multiplyScalar(scale); - pos.applyQuaternion(quaternion); - pos.add(translate); - return pos; - } - - function transformScales(scales: THREE.Vector3) { - scales.multiplyScalar(scale); - return scales; - } - - function transformQuaternion(quat: THREE.Quaternion) { - quat.premultiply(quaternion); - return quat; - } - - function withinClip(p: THREE.Vector3) { - return !clip || clip.containsPoint(p); - } - - function withinOpacity(opacity: number) { - return opacityThreshold !== undefined - ? opacity >= opacityThreshold - : true; - } - - let fileType = input.fileType; - if (!fileType) { - fileType = getSplatFileType(input.fileBytes); - if (!fileType && input.pathOrUrl) { - fileType = getSplatFileTypeFromPath(input.pathOrUrl); - } - } - switch (fileType) { - case SplatFileType.PLY: { - const ply = new PlyReader({ fileBytes: input.fileBytes }); - await ply.parseHeader(); - let lastIndex: number | null = null; - ply.parseSplats( - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - if (withinClip(center) && withinOpacity(opacity)) { - lastIndex = splats.pushSplat(); - splats.setCenter(lastIndex, center.x, center.y, center.z); - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(lastIndex, scales.x, scales.y, scales.z); - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - lastIndex, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - splats.setOpacity(lastIndex, opacity); - splats.setColor(lastIndex, r, g, b); - } else { - lastIndex = null; - } - }, - (index, sh1, sh2, sh3) => { - if (sh1 && lastIndex !== null) { - splats.setSh1(lastIndex, sh1); - } - if (sh2 && lastIndex !== null) { - splats.setSh2(lastIndex, sh2); - } - if (sh3 && lastIndex !== null) { - splats.setSh3(lastIndex, sh3); - } - }, - ); - break; - } - case SplatFileType.SPZ: { - const spz = new SpzReader({ fileBytes: input.fileBytes }); - await spz.parseHeader(); - const mapping = new Int32Array(spz.numSplats); - mapping.fill(-1); - const centers = new Float32Array(spz.numSplats * 3); - const center = new THREE.Vector3(); - spz.parseSplats( - (index, x, y, z) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - centers[index * 3] = center.x; - centers[index * 3 + 1] = center.y; - centers[index * 3 + 2] = center.z; - }, - (index, alpha) => { - center.fromArray(centers, index * 3); - if (withinClip(center) && withinOpacity(alpha)) { - mapping[index] = splats.pushSplat(); - splats.setCenter(mapping[index], center.x, center.y, center.z); - splats.setOpacity(mapping[index], alpha); - } - }, - (index, r, g, b) => { - if (mapping[index] >= 0) { - splats.setColor(mapping[index], r, g, b); - } - }, - (index, scaleX, scaleY, scaleZ) => { - if (mapping[index] >= 0) { - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(mapping[index], scales.x, scales.y, scales.z); - } - }, - (index, quatX, quatY, quatZ, quatW) => { - if (mapping[index] >= 0) { - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - mapping[index], - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - } - }, - (index, sh1, sh2, sh3) => { - if (mapping[index] >= 0) { - splats.setSh1(mapping[index], sh1); - if (sh2) { - splats.setSh2(mapping[index], sh2); - } - if (sh3) { - splats.setSh3(mapping[index], sh3); - } - } - }, - ); - break; - } - case SplatFileType.SPLAT: - decodeAntiSplat( - input.fileBytes, - (numSplats) => {}, - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - if (withinClip(center) && withinOpacity(opacity)) { - const index = splats.pushSplat(); - splats.setCenter(index, center.x, center.y, center.z); - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(index, scales.x, scales.y, scales.z); - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - index, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - splats.setOpacity(index, opacity); - splats.setColor(index, r, g, b); - } - }, - ); - break; - case SplatFileType.KSPLAT: { - let lastIndex: number | null = null; - decodeKsplat( - input.fileBytes, - (numSplats) => {}, - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - const center = transformPos(new THREE.Vector3(x, y, z)); - if (withinClip(center) && withinOpacity(opacity)) { - lastIndex = splats.pushSplat(); - splats.setCenter(lastIndex, center.x, center.y, center.z); - const scales = transformScales( - new THREE.Vector3(scaleX, scaleY, scaleZ), - ); - splats.setScale(lastIndex, scales.x, scales.y, scales.z); - const quaternion = transformQuaternion( - new THREE.Quaternion(quatX, quatY, quatZ, quatW), - ); - splats.setQuaternion( - lastIndex, - quaternion.x, - quaternion.y, - quaternion.z, - quaternion.w, - ); - splats.setOpacity(lastIndex, opacity); - splats.setColor(lastIndex, r, g, b); - } else { - lastIndex = null; - } - }, - (index, sh1, sh2, sh3) => { - if (lastIndex !== null) { - splats.setSh1(lastIndex, sh1); - if (sh2) { - splats.setSh2(lastIndex, sh2); - } - if (sh3) { - splats.setSh3(lastIndex, sh3); - } - } - }, - ); - break; - } - default: - throw new Error(`transcodeSpz not implemented for ${fileType}`); - } - } - - const shDegree = Math.min( - maxSh ?? 3, - splats.sh3 ? 3 : splats.sh2 ? 2 : splats.sh1 ? 1 : 0, - ); - const spz = new SpzWriter({ - numSplats: splats.numSplats, - shDegree, - fractionalBits, - flagAntiAlias: true, - }); - - for (let i = 0; i < splats.numSplats; ++i) { - const i3 = i * 3; - const i4 = i * 4; - spz.setCenter( - i, - splats.centers[i3], - splats.centers[i3 + 1], - splats.centers[i3 + 2], - ); - spz.setScale( - i, - splats.scales[i3], - splats.scales[i3 + 1], - splats.scales[i3 + 2], - ); - spz.setQuat( - i, - splats.quaternions[i4], - splats.quaternions[i4 + 1], - splats.quaternions[i4 + 2], - splats.quaternions[i4 + 3], - ); - spz.setAlpha(i, splats.opacities[i]); - spz.setRgb( - i, - splats.colors[i3], - splats.colors[i3 + 1], - splats.colors[i3 + 2], - ); - if (splats.sh1 && shDegree >= 1) { - spz.setSh( - i, - splats.sh1.slice(i * 9, (i + 1) * 9), - shDegree >= 2 && splats.sh2 - ? splats.sh2.slice(i * 15, (i + 1) * 15) - : undefined, - shDegree >= 3 && splats.sh3 - ? splats.sh3.slice(i * 21, (i + 1) * 21) - : undefined, - ); - } - } - - const spzBytes = await spz.finalize(); - return { fileBytes: spzBytes, clippedCount: spz.clippedCount }; -} diff --git a/src/transcode.ts b/src/transcode.ts new file mode 100644 index 0000000..b5e9455 --- /dev/null +++ b/src/transcode.ts @@ -0,0 +1,331 @@ +import * as THREE from "three"; +import { SplatFileType } from "./SplatLoader"; +import { getSplatFileType, getSplatFileTypeFromPath } from "./SplatLoader"; +import { SH_DEGREE_TO_NUM_COEFF } from "./defines"; +import type { SplatEncoder } from "./encoding/encoder"; +import { unpackAntiSplat } from "./formats/antisplat"; +import { unpackKsplat } from "./formats/ksplat"; +import { unpackPcSogsZip } from "./formats/pcsogs"; +import { unpackPly } from "./formats/ply"; +import { SpzWriter, unpackSpz } from "./formats/spz"; + +export type FileInput = { + fileBytes: Uint8Array; + fileType?: SplatFileType; + pathOrUrl?: string; + transform?: { translate?: number[]; quaternion?: number[]; scale?: number }; +}; + +export type TranscodeSpzInput = { + /** + * Collection of input files to transcode. + * Each file can have an optional transform to apply. + */ + inputs: FileInput[]; + /** + * The maximum number of spherical harmonics. + */ + maxSh?: number; + /** + * Optional clip box. Any splats outside of this box after + * apply transformations will be omitted from the output. + */ + clipXyz?: { min: number[]; max: number[] }; + /** + * Number of fractional bits to use. + */ + fractionalBits?: number; + /** + * Optional threshold to filter out splats with opacities below + * this value. + */ + opacityThreshold?: number; +}; + +const MAX_SPLATS = 50_000_000; +const tempV3 = new THREE.Vector3(); +const tempQuat = new THREE.Quaternion(); + +export async function transcodeSpz(input: TranscodeSpzInput) { + const { + inputs, + clipXyz, + maxSh = 3, + fractionalBits = 12, + opacityThreshold, + } = input; + + const numShCoefficients = SH_DEGREE_TO_NUM_COEFF[maxSh]; + const context = { + centers: new Float32Array( + new ArrayBuffer(0, { maxByteLength: MAX_SPLATS * 3 * 4 }), + ), + scales: new Float32Array( + new ArrayBuffer(0, { maxByteLength: MAX_SPLATS * 3 * 4 }), + ), + quats: new Float32Array( + new ArrayBuffer(0, { maxByteLength: MAX_SPLATS * 4 * 4 }), + ), + rgb: new Float32Array( + new ArrayBuffer(0, { maxByteLength: MAX_SPLATS * 3 * 4 }), + ), + opacities: new Float32Array( + new ArrayBuffer(0, { maxByteLength: MAX_SPLATS * 1 * 4 }), + ), + sh: new Float32Array( + new ArrayBuffer(0, { maxByteLength: MAX_SPLATS * numShCoefficients * 4 }), + ), + + head: 0, + indexMapping: {} as Record, + clippedIndices: new Set(), + capacity: 0, + + currentShBands: 0, + translate: new THREE.Vector3(), + rotate: new THREE.Quaternion(), + scale: 1, + clipBox: null as THREE.Box3 | null, + + getSplatIndex(index: number) { + if (!(index in context.indexMapping)) { + context.indexMapping[index] = this.head++; + } + return context.indexMapping[index]; + }, + + transformPos(pos: THREE.Vector3) { + pos.multiplyScalar(this.scale); + pos.applyQuaternion(this.rotate); + pos.add(this.translate); + return pos; + }, + + transformScales(scales: THREE.Vector3) { + scales.multiplyScalar(this.scale); + return scales; + }, + + transformQuaternion(quat: THREE.Quaternion) { + quat.premultiply(this.rotate); + return quat; + }, + + withinClip(p: THREE.Vector3) { + return !this.clipBox || this.clipBox.containsPoint(p); + }, + + withinOpacity(opacity: number) { + return opacityThreshold !== undefined + ? opacity >= opacityThreshold + : true; + }, + }; + + const splatEncoder: SplatEncoder = { + allocate(numSplats, numShBands) { + // Start of a new input, expand the buffers to accommodate at least numSplats + const remainingCapacity = context.capacity - context.head; + if (remainingCapacity < numSplats) { + const newCapacity = numSplats - remainingCapacity; + context.centers.buffer.resize(newCapacity * 3 * 4); + context.scales.buffer.resize(newCapacity * 3 * 4); + context.quats.buffer.resize(newCapacity * 4 * 4); + context.rgb.buffer.resize(newCapacity * 3 * 4); + context.opacities.buffer.resize(newCapacity * 1 * 4); + if (maxSh > 0) { + context.sh.buffer.resize(newCapacity * numShCoefficients * 4); + } + + context.capacity = newCapacity; + + // Keep track of the number of sh bands in the current input + context.currentShBands = numShBands; + } + }, + + setSplat( + i, + x, + y, + z, + scaleX, + scaleY, + scaleZ, + quatX, + quatY, + quatZ, + quatW, + opacity, + r, + g, + b, + ) { + this.setSplatCenter(i, x, y, z); + this.setSplatScales(i, scaleX, scaleY, scaleZ); + this.setSplatQuat(i, quatX, quatY, quatZ, quatW); + this.setSplatRgba(i, r, g, b, opacity); + }, + + setSplatAlpha(i, a) { + const index = context.getSplatIndex(i); + if (!context.withinOpacity(a)) { + context.clippedIndices.add(index); + } + context.opacities[index] = a; + }, + + setSplatCenter(i, x, y, z) { + const index = context.getSplatIndex(i); + const center = context.transformPos(tempV3.set(x, y, z)); + if (!context.withinClip(center)) { + context.clippedIndices.add(index); + } + context.centers[index * 3 + 0] = center.x; + context.centers[index * 3 + 1] = center.y; + context.centers[index * 3 + 2] = center.z; + }, + + setSplatQuat(i, quatX, quatY, quatZ, quatW) { + const index = context.getSplatIndex(i); + const quat = context.transformQuaternion( + tempQuat.set(quatX, quatY, quatZ, quatW), + ); + context.quats[index * 4 + 0] = quat.x; + context.quats[index * 4 + 1] = quat.y; + context.quats[index * 4 + 2] = quat.z; + context.quats[index * 4 + 3] = quat.w; + }, + + setSplatRgb(i, r, g, b) { + const index = context.getSplatIndex(i); + context.rgb[index * 3 + 0] = r; + context.rgb[index * 3 + 1] = g; + context.rgb[index * 3 + 2] = b; + }, + + setSplatRgba(i, r, g, b, a) { + this.setSplatRgb(i, r, g, b); + this.setSplatAlpha(i, a); + }, + + setSplatScales(i, scaleX, scaleY, scaleZ) { + const index = context.getSplatIndex(i); + const scales = context.transformScales( + tempV3.set(scaleX, scaleY, scaleZ), + ); + context.scales[index * 3 + 0] = scales.x; + context.scales[index * 3 + 1] = scales.y; + context.scales[index * 3 + 2] = scales.z; + }, + + setSplatSh(i, sh) { + const index = context.getSplatIndex(i); + const shStride = numShCoefficients; + const effectiveShBands = Math.min(context.currentShBands, maxSh); + const shCoefficients = SH_DEGREE_TO_NUM_COEFF[effectiveShBands]; + + for (let j = 0; j < shCoefficients; j++) { + context.sh[index * shStride + j] = sh[j]; + } + }, + + closeTransferable() {}, + + close() { + throw new Error("Not supported"); + }, + }; + + context.clipBox = clipXyz + ? new THREE.Box3( + new THREE.Vector3().fromArray(clipXyz.min), + new THREE.Vector3().fromArray(clipXyz.max), + ) + : null; + for (const input of inputs) { + context.translate.fromArray(input.transform?.translate ?? [0, 0, 0]); + context.rotate.fromArray(input.transform?.quaternion ?? [0, 0, 0, 1]); + context.scale = input.transform?.scale ?? 1; + + let fileType = input.fileType; + if (!fileType) { + fileType = getSplatFileType(input.fileBytes); + if (!fileType && input.pathOrUrl) { + fileType = getSplatFileTypeFromPath(input.pathOrUrl); + } + } + + const fileBytes = input.fileBytes; + switch (fileType) { + case SplatFileType.PLY: + await unpackPly(fileBytes, splatEncoder); + break; + case SplatFileType.SPZ: + await unpackSpz(fileBytes, splatEncoder); + break; + case SplatFileType.SPLAT: + unpackAntiSplat(fileBytes, splatEncoder); + break; + case SplatFileType.KSPLAT: + unpackKsplat(fileBytes, splatEncoder); + break; + case SplatFileType.PCSOGSZIP: + await unpackPcSogsZip(fileBytes, splatEncoder); + break; + default: + throw new Error(`transcodeSpz not implemented for: ${fileType}`); + } + } + + const numSplats = context.head - context.clippedIndices.size; + const spz = new SpzWriter({ + numSplats, + shDegree: maxSh, + fractionalBits, + flagAntiAlias: true, + }); + let i = 0; + // Go over all collected splats + for (let splat = 0; splat < context.head; splat++) { + // Skip splats that have been clipped + if (context.clippedIndices.has(splat)) { + continue; + } + + // Write splat properties to the SpzWriter + const i3 = i * 3; + const i4 = i * 4; + spz.setCenter( + i, + context.centers[i3], + context.centers[i3 + 1], + context.centers[i3 + 2], + ); + spz.setScale( + i, + context.scales[i3], + context.scales[i3 + 1], + context.scales[i3 + 2], + ); + spz.setQuat( + i, + context.quats[i4], + context.quats[i4 + 1], + context.quats[i4 + 2], + context.quats[i4 + 3], + ); + spz.setRgb(i, context.rgb[i3], context.rgb[i3 + 1], context.rgb[i3 + 2]); + spz.setAlpha(i, context.opacities[i]); + + if (numShCoefficients > 0) { + const shStride = numShCoefficients; + spz.setSh(i, context.sh.slice(shStride * i, shStride * (i + 1))); + } + + i++; + } + + const spzBytes = await spz.finalize(); + return { fileBytes: spzBytes, clippedCount: spz.clippedCount }; +} diff --git a/src/utils.ts b/src/utils.ts index 9dc7144..adff045 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,17 +1,11 @@ import { Gunzip } from "fflate"; import * as THREE from "three"; -// Miscellaneous utility functions for Spark - import { - LN_SCALE_MAX, - LN_SCALE_MIN, - SCALE_ZERO, SPLAT_TEX_HEIGHT, SPLAT_TEX_MIN_HEIGHT, SPLAT_TEX_WIDTH, } from "./defines.js"; -import { unindent } from "./dyno/base.js"; const f32buffer = new Float32Array(1); const u32buffer = new Uint32Array(f32buffer.buffer); @@ -21,34 +15,44 @@ const f16buffer = supportsFloat16Array : null; const u16buffer = new Uint16Array(f16buffer?.buffer); -// Returns a normalized array of numbers -export function normalize(vec: number[]) { - const norm = Math.sqrt(vec.reduce((acc, v) => acc + v * v, 0)); - return vec.map((v) => v / norm); -} - -// Reinterpret the bits of a float32 as a uint32 +/** + * Reinterpret the bits of a float32 as a uint32 + * @param f The float32 to reinterpret + * @returns The resulting uint32 + */ export function floatBitsToUint(f: number): number { f32buffer[0] = f; return u32buffer[0]; } -// Reinterpret the bits of a uint32 as a float32 +/** + * Reinterpret the bits of a uint32 as a float32 + * @param u The uint32 to reinterpret + * @returns The resulting float32 + */ export function uintBitsToFloat(u: number): number { u32buffer[0] = u; return f32buffer[0]; } +/** + * Reinterpret the bits of a float16 as a uint16 + * @param f The float16 to reinterpret + * @returns The resulting uint16 + */ export const toHalf = supportsFloat16Array ? toHalfNative : toHalfJS; +/** + * Reinterpret the bits of a uint16 as a float16 + * @param u The uint16 to reinterpret + * @returns The resulting float16 + */ export const fromHalf = supportsFloat16Array ? fromHalfNative : fromHalfJS; -// Encode a number as a float16, stored as a uint16 number. function toHalfNative(f: number): number { f16buffer[0] = f; return u16buffer[0]; } -// Encode a number as a float16, stored as a uint16 number. function toHalfJS(f: number): number { // Store the value into the shared Float32 array. f32buffer[0] = f; @@ -94,13 +98,11 @@ function toHalfJS(f: number): number { return halfSign | (newExp << 10) | halfFrac; } -// Convert a float16 stored as a uint16 number back to a float32. function fromHalfNative(u: number): number { u16buffer[0] = u; return f16buffer[0]; } -// Convert a float16 stored as a uint16 number back to a float32. function fromHalfJS(h: number): number { // Extract the sign (1 bit), exponent (5 bits), and fraction (10 bits) const sign = (h >> 15) & 0x1; @@ -154,105 +156,16 @@ function fromHalfJS(h: number): number { return f32buffer[0]; } -// Convert a number 0..1 to a 0..255 uint +/** + * Convert a float from 0..1 to a 0..255 uint + * @param v The number to convert + * @returns Uint8 representation + */ export function floatToUint8(v: number): number { // Converts from 0..1 float to 0..255 uint8 return Math.max(0, Math.min(255, Math.round(v * 255))); } -// Convert a number -1..1 to a -127..127 int -export function floatToSint8(v: number): number { - // Converts from -1..1 float to -127..127 int8 - return Math.max(-127, Math.min(127, Math.round(v * 127))); -} - -// Convert a 0..255 uint to a 0..1 float -export function Uint8ToFloat(v: number): number { - // Converts from 0..255 uint8 to 0..1 float - return v / 255; -} - -// Convert a -127..127 int to a -1..1 float -export function Sint8ToFloat(v: number): number { - // Converts from -127..127 int8 to -1..1 float - return v / 127; -} - -// A simple utility class for caching a fixed number of items -export class DataCache { - // Maximum number of items to cache - maxItems: number; - - // Function to fetch data for a key - asyncFetch: (key: string) => Promise; - - // Array of cached items - items: { key: string; data: unknown }[]; - - // Create a DataCache with a given function that fetches data not in the cache. - constructor({ - asyncFetch, - maxItems = 5, - }: { asyncFetch: (key: string) => Promise; maxItems?: number }) { - this.asyncFetch = asyncFetch; - this.maxItems = maxItems; - this.items = []; - } - - // Fetch data for the key, returning cached data if available. - async getFetch(key: string): Promise { - // Fetches data for a key and caches it, returns cached data if available. - const index = this.items.findIndex((item) => item.key === key); - if (index >= 0) { - // Data exists in our cache, move it to the end of the array - const item = this.items.splice(index, 1)[0]; - this.items.push(item); - // Return the cached data - return item.data; - } - - // Fetch the data from the asyncFetch function - const data = await this.asyncFetch(key); - // Add the data to the cache - this.items.push({ key, data }); - // If the cache is too large, remove the oldest accessed item - while (this.items.length > this.maxItems) { - this.items.shift(); - } - // Return the fetched data - return data; - } -} - -// Like Array.map but for objects -export function mapObject( - obj: Record, - fn: (value: unknown, key: string) => unknown, -): Record { - // Maps over an object, applying a function to each value and key - const entries = Object.entries(obj).map(([key, value]) => [ - key, - fn(value, key), - ]); - // Returns a new object with the mapped values - return Object.fromEntries(entries); -} - -// Like Array.map().filter() but for objects. -// The callback fn() should return undefined to filter out the key. -export function mapFilterObject( - obj: Record, - fn: (value: unknown, key: string) => unknown, -): Record { - // Maps over an object, applying a function to each value and key - // If no return (or return undefined), the key is not included in the result - const entries = Object.entries(obj) - .map(([key, value]) => [key, fn(value, key)]) - .filter(([_, value]) => value !== undefined); - // Returns a new object with the filtered values - return Object.fromEntries(entries); -} - // Recursively finds all ArrayBuffers in an object and returns them as an array // to use as transferable objects to send between workers. export function getArrayBuffers(ctx: unknown): Transferable[] { @@ -267,7 +180,7 @@ export function getArrayBuffers(ctx: unknown): Transferable[] { buffers.push(obj); } else if (ArrayBuffer.isView(obj)) { // Handles TypedArrays and DataView - buffers.push(obj.buffer); + buffers.push(obj.buffer as ArrayBuffer); } else if (Array.isArray(obj)) { obj.forEach(traverse); } else { @@ -280,415 +193,6 @@ export function getArrayBuffers(ctx: unknown): Transferable[] { return buffers; } -// Create an array of the given size and initialize element with initFunction() -export function newArray( - n: number, - initFunction: (index: number) => T, -): T[] { - // Creates a new array and calls a constructor function for each element with index - return new Array(n).fill(null).map((_, i) => initFunction(i)); -} - -// A free list that has a pool of items of type T, with callbacks -// for constructing, disposing, and checking if an item is valid for the given args. -export class FreeList { - items: T[]; - allocate: (args: Args) => T; - dispose?: (item: T) => void; - valid: (item: T, args: Args) => boolean; - - constructor({ - // Allocate a new item with the given args - allocate, - // Dispose of an item (optional, if GC is enough) - dispose, - // Check if an existing item in the list is valid for the given args, - // allowing you to store heterogeneous items in the list. - valid, - }: { - allocate: (args: Args) => T; - dispose?: (item: T) => void; - valid: (item: T, args: Args) => boolean; - }) { - this.items = []; - this.allocate = allocate; - this.dispose = dispose; - this.valid = valid; - } - - // Allocate a new item from the free list, first checking if a existing item - // on the freelist is valid for the given args. - alloc(args: Args): T { - while (true) { - const item = this.items.pop(); - if (!item) { - // No items in the free list, allocate a new one - break; - } - if (this.valid(item, args)) { - // Found a valid item, return it - // console.log(`FreeList.alloc(${JSON.stringify(args)}): found valid item. Reusing...`); - return item; - } - // Item isn't valid for our args, dispose of it and try again - if (this.dispose) { - // console.log(`FreeList.alloc(${JSON.stringify(args)}): disposing invalid item.`); - this.dispose(item); - } - } - // console.log(`FreeList.alloc(${JSON.stringify(args)}): allocating new item`); - return this.allocate(args); - } - - free(item: T) { - // Return item to the free list - this.items.push(item); - } - - disposeAll() { - // Disposes of all items in the free list - let item: T | undefined; - item = this.items.pop(); - while (item) { - if (this.dispose) { - this.dispose(item); - } - item = this.items.pop(); - } - } -} - -// Encode a PackedSplat as 4 consecutive Uint32 elements in the packedSplats array. -// The center coordinates x,y,z are encoded as float16, the scales x,y,z as a -// logarithmic uint8, rotation as three uint8s representing rotation axis and angle, -// and RGBA as 4xuint8. -export function setPackedSplat( - packedSplats: Uint32Array, - index: number, - x: number, - y: number, - z: number, - scaleX: number, - scaleY: number, - scaleZ: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, - opacity: number, - r: number, - g: number, - b: number, - encoding?: { - rgbMin?: number; - rgbMax?: number; - lnScaleMin?: number; - lnScaleMax?: number; - }, -) { - const rgbMin = encoding?.rgbMin ?? 0.0; - const rgbMax = encoding?.rgbMax ?? 1.0; - const rgbRange = rgbMax - rgbMin; - const uR = floatToUint8((r - rgbMin) / rgbRange); - const uG = floatToUint8((g - rgbMin) / rgbRange); - const uB = floatToUint8((b - rgbMin) / rgbRange); - const uA = floatToUint8(opacity); - - // Alternate internal encodings commented out below. - const uQuat = encodeQuatOctXy88R8( - tempQuaternion.set(quatX, quatY, quatZ, quatW), - ); - // const uQuat = encodeQuatXyz888(new THREE.Quaternion(quatX, quatY, quatZ, quatW)); - // const uQuat = encodeQuatEulerXyz888(new THREE.Quaternion(quatX, quatY, quatZ, quatW)); - const uQuatX = uQuat & 0xff; - const uQuatY = (uQuat >>> 8) & 0xff; - const uQuatZ = (uQuat >>> 16) & 0xff; - - // Allow scales below LN_SCALE_MIN to be encoded as 0, which signifies a 2DGS - const lnScaleMin = encoding?.lnScaleMin ?? LN_SCALE_MIN; - const lnScaleMax = encoding?.lnScaleMax ?? LN_SCALE_MAX; - const lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); - const uScaleX = - scaleX < SCALE_ZERO - ? 0 - : Math.min( - 255, - Math.max( - 1, - Math.round((Math.log(scaleX) - lnScaleMin) * lnScaleScale) + 1, - ), - ); - const uScaleY = - scaleY < SCALE_ZERO - ? 0 - : Math.min( - 255, - Math.max( - 1, - Math.round((Math.log(scaleY) - lnScaleMin) * lnScaleScale) + 1, - ), - ); - const uScaleZ = - scaleZ < SCALE_ZERO - ? 0 - : Math.min( - 255, - Math.max( - 1, - Math.round((Math.log(scaleZ) - lnScaleMin) * lnScaleScale) + 1, - ), - ); - - const uCenterX = toHalf(x); - const uCenterY = toHalf(y); - const uCenterZ = toHalf(z); - - // Encode the splat as 4 consecutive Uint32 elements - const i4 = index * 4; - packedSplats[i4] = uR | (uG << 8) | (uB << 16) | (uA << 24); - packedSplats[i4 + 1] = uCenterX | (uCenterY << 16); - packedSplats[i4 + 2] = uCenterZ | (uQuatX << 16) | (uQuatY << 24); - packedSplats[i4 + 3] = - uScaleX | (uScaleY << 8) | (uScaleZ << 16) | (uQuatZ << 24); -} - -// Encode the center coordinates x,y,z in the packedSplats Uint32Array, -// leaving all other fields as is. -export function setPackedSplatCenter( - packedSplats: Uint32Array, - index: number, - x: number, - y: number, - z: number, -) { - const uCenterX = toHalf(x); - const uCenterY = toHalf(y); - const uCenterZ = toHalf(z); - - const i4 = index * 4; - packedSplats[i4 + 1] = uCenterX | (uCenterY << 16); - packedSplats[i4 + 2] = uCenterZ | (packedSplats[i4 + 2] & 0xffff0000); -} - -// Encode the scales x,y,z in the packedSplats Uint32Array, leaving all other fields as is. -export function setPackedSplatScales( - packedSplats: Uint32Array, - index: number, - scaleX: number, - scaleY: number, - scaleZ: number, - encoding?: { - lnScaleMin?: number; - lnScaleMax?: number; - }, -) { - // Allow scales below LN_SCALE_MIN to be encoded as 0, which signifies a 2DGS - const lnScaleMin = encoding?.lnScaleMin ?? LN_SCALE_MIN; - const lnScaleMax = encoding?.lnScaleMax ?? LN_SCALE_MAX; - const lnScaleScale = 254.0 / (lnScaleMax - lnScaleMin); - const uScaleX = - scaleX < SCALE_ZERO - ? 0 - : Math.min( - 255, - Math.max( - 1, - Math.round((Math.log(scaleX) - lnScaleMin) * lnScaleScale) + 1, - ), - ); - const uScaleY = - scaleY < SCALE_ZERO - ? 0 - : Math.min( - 255, - Math.max( - 1, - Math.round((Math.log(scaleY) - lnScaleMin) * lnScaleScale) + 1, - ), - ); - const uScaleZ = - scaleZ < SCALE_ZERO - ? 0 - : Math.min( - 255, - Math.max( - 1, - Math.round((Math.log(scaleZ) - lnScaleMin) * lnScaleScale) + 1, - ), - ); - - const i4 = index * 4; - packedSplats[i4 + 3] = - uScaleX | - (uScaleY << 8) | - (uScaleZ << 16) | - (packedSplats[i4 + 3] & 0xff000000); -} - -// Temporary storage used in `encodeQuatOCtXy88R8` and `decodeQuatOctXy88R8` to -// avoid allocation new Quaternions and Vector3 instances. -const tempQuaternion = new THREE.Quaternion(); - -// Encode the rotation quatX, quatY, quatZ, quatW in the packedSplats Uint32Array, -// leaving all other fields as is. -export function setPackedSplatQuat( - packedSplats: Uint32Array, - index: number, - quatX: number, - quatY: number, - quatZ: number, - quatW: number, -) { - const uQuat = encodeQuatOctXy88R8( - tempQuaternion.set(quatX, quatY, quatZ, quatW), - ); - // const uQuat = encodeQuatXyz888(new THREE.Quaternion(quatX, quatY, quatZ, quatW)); - // const uQuat = encodeQuatEulerXyz888(new THREE.Quaternion(quatX, quatY, quatZ, quatW)); - const uQuatX = uQuat & 0xff; - const uQuatY = (uQuat >>> 8) & 0xff; - const uQuatZ = (uQuat >>> 16) & 0xff; - - const i4 = index * 4; - packedSplats[i4 + 2] = - (packedSplats[i4 + 2] & 0x0000ffff) | (uQuatX << 16) | (uQuatY << 24); - packedSplats[i4 + 3] = (packedSplats[i4 + 3] & 0x00ffffff) | (uQuatZ << 24); -} - -// Encode the RGBA color in the packedSplats Uint32Array, leaving other fields alone. -export function setPackedSplatRgba( - packedSplats: Uint32Array, - index: number, - r: number, - g: number, - b: number, - a: number, - encoding?: { - rgbMin?: number; - rgbMax?: number; - }, -) { - const rgbMin = encoding?.rgbMin ?? 0.0; - const rgbMax = encoding?.rgbMax ?? 1.0; - const rgbRange = rgbMax - rgbMin; - const uR = floatToUint8((r - rgbMin) / rgbRange); - const uG = floatToUint8((g - rgbMin) / rgbRange); - const uB = floatToUint8((b - rgbMin) / rgbRange); - const uA = floatToUint8(a); - const i4 = index * 4; - packedSplats[i4] = uR | (uG << 8) | (uB << 16) | (uA << 24); -} - -// Encode the RGB color in the packedSplats Uint32Array, leaving other fields alone. -export function setPackedSplatRgb( - packedSplats: Uint32Array, - index: number, - r: number, - g: number, - b: number, - encoding?: { - rgbMin?: number; - rgbMax?: number; - }, -) { - const rgbMin = encoding?.rgbMin ?? 0.0; - const rgbMax = encoding?.rgbMax ?? 1.0; - const rgbRange = rgbMax - rgbMin; - const uR = floatToUint8((r - rgbMin) / rgbRange); - const uG = floatToUint8((g - rgbMin) / rgbRange); - const uB = floatToUint8((b - rgbMin) / rgbRange); - - const i4 = index * 4; - packedSplats[i4] = - uR | (uG << 8) | (uB << 16) | (packedSplats[i4] & 0xff000000); -} - -// Encode the opacity in the packedSplats Uint32Array, leaving other fields alone. -export function setPackedSplatOpacity( - packedSplats: Uint32Array, - index: number, - opacity: number, -) { - const uA = floatToUint8(opacity); - - const i4 = index * 4; - packedSplats[i4] = (packedSplats[i4] & 0x00ffffff) | (uA << 24); -} - -const packedCenter = new THREE.Vector3(); -const packedScales = new THREE.Vector3(); -const packedQuaternion = new THREE.Quaternion(); -const packedColor = new THREE.Color(); -const packedFields = { - center: packedCenter, - scales: packedScales, - quaternion: packedQuaternion, - color: packedColor, - opacity: 0.0, -}; - -// Unpack all components of a PackedSplat from the packedSplats Uint32Array into -// THREE.js vector objects. The returned objects will be reused each call. -export function unpackSplat( - packedSplats: Uint32Array, - index: number, - encoding?: { - rgbMin?: number; - rgbMax?: number; - lnScaleMin?: number; - lnScaleMax?: number; - }, -): { - center: THREE.Vector3; - scales: THREE.Vector3; - quaternion: THREE.Quaternion; - color: THREE.Color; - opacity: number; -} { - // Returns a static object which is reused each time - const result = packedFields; - - const i4 = index * 4; - const word0 = packedSplats[i4]; - const word1 = packedSplats[i4 + 1]; - const word2 = packedSplats[i4 + 2]; - const word3 = packedSplats[i4 + 3]; - - const rgbMin = encoding?.rgbMin ?? 0.0; - const rgbMax = encoding?.rgbMax ?? 1.0; - const rgbRange = rgbMax - rgbMin; - result.color.set( - rgbMin + ((word0 & 0xff) / 255) * rgbRange, - rgbMin + (((word0 >>> 8) & 0xff) / 255) * rgbRange, - rgbMin + (((word0 >>> 16) & 0xff) / 255) * rgbRange, - ); - result.opacity = ((word0 >>> 24) & 0xff) / 255; - result.center.set( - fromHalf(word1 & 0xffff), - fromHalf((word1 >>> 16) & 0xffff), - fromHalf(word2 & 0xffff), - ); - - const lnScaleMin = encoding?.lnScaleMin ?? LN_SCALE_MIN; - const lnScaleMax = encoding?.lnScaleMax ?? LN_SCALE_MAX; - const lnScaleScale = (lnScaleMax - lnScaleMin) / 254.0; - const uScalesX = word3 & 0xff; - result.scales.x = - uScalesX === 0 ? 0.0 : Math.exp(lnScaleMin + (uScalesX - 1) * lnScaleScale); - const uScalesY = (word3 >>> 8) & 0xff; - result.scales.y = - uScalesY === 0 ? 0.0 : Math.exp(lnScaleMin + (uScalesY - 1) * lnScaleScale); - const uScalesZ = (word3 >>> 16) & 0xff; - result.scales.z = - uScalesZ === 0 ? 0.0 : Math.exp(lnScaleMin + (uScalesZ - 1) * lnScaleScale); - - const uQuat = ((word2 >>> 16) & 0xffff) | ((word3 >>> 8) & 0xff0000); - decodeQuatOctXy88R8(uQuat, result.quaternion); - // decodeQuatXyz888(uQuat, result.quaternion); - // decodeQuatEulerXyz888(uQuat, result.quaternion); - - return result; -} - // Compute a texture array size that is large enough to fit numSplats. The most // common 2D texture size in WebGL2 is 4096x4096 which only allows for 16M splats, // so Spark stores Gsplat data in a 2D texture array, which most platforms support @@ -730,208 +234,27 @@ export function computeMaxSplats(numSplats: number): number { return width * height * depth; } -// Heuristic function to determine if we are running on a mobile device. -export function isMobile(): boolean { - if (navigator.maxTouchPoints > 0) { - // Touch-enabled device, assume it's mobile - return true; - } - return /Mobi|Android|iPhone|iPad|iPod|Opera Mini|IEMobile/.test( - navigator.userAgent, - ); -} - -// Heuristic function to determine if we are running on an Android device. -// (does not include Oculus Quest) -export function isAndroid(): boolean { - return /Android/.test(navigator.userAgent); -} - -// Heuristic function to determine if we are running on an Oculus Quest device. -export function isOculus(): boolean { - return /Oculus/.test(navigator.userAgent); -} - -// Take an array of RGBA8 encoded pixels and flip them vertically in-place. -// This is useful for converting between top-left and bottom-left coordinate systems -// in standard 2D images vs WebGL2. -export function flipPixels( - pixels: Uint8Array, - width: number, - height: number, -): Uint8Array { - // Flips pixels vertically in-place, returns original array. - const tempLine = new Uint8Array(width * 4); - - // Only need to process half the height since we're swapping - for (let y = 0; y < height / 2; y++) { - const topOffset = y * width * 4; - const bottomOffset = (height - 1 - y) * width * 4; - - // Save top line to temp buffer - tempLine.set(pixels.subarray(topOffset, topOffset + width * 4)); - // Move bottom line to top - pixels.set( - pixels.subarray(bottomOffset, bottomOffset + width * 4), - topOffset, - ); - // Move saved top line to bottom - pixels.set(tempLine, bottomOffset); - } - return pixels; -} - -// Utility to take an array of RGBA8 encoded pixels and convert them to a -// PNG-encoded image data URL that can be downloaded to the client. -export function pixelsToPngUrl( - pixels: Uint8Array, - width: number, - height: number, -): string { - const canvas = document.createElement("canvas"); - canvas.width = width; - canvas.height = height; - const ctx = canvas.getContext("2d"); - if (!ctx) { - throw new Error("Can't get 2d context"); - } - const imageData = ctx.createImageData(width, height); - imageData.data.set(pixels); - ctx.putImageData(imageData, 0, 0); - return canvas.toDataURL("image/png"); -} - -// Manually clone a THREE.Clock object. -export function cloneClock(clock: THREE.Clock): THREE.Clock { - const newClock = new THREE.Clock(clock.autoStart); - newClock.startTime = clock.startTime; - newClock.oldTime = clock.oldTime; - newClock.elapsedTime = clock.elapsedTime; - newClock.running = clock.running; - return newClock; -} - -// Utility to filter out an undefined values from an object. -export function omitUndefined(obj: T): Partial { - return Object.fromEntries( - Object.entries(obj).filter(([_, value]) => value !== undefined), - ) as Partial; -} - -// "Identity" vertex shader that just passes through the position. -export const IDENT_VERTEX_SHADER = unindent(` - precision highp float; - - in vec3 position; - - void main() { - gl_Position = vec4(position.xy, 0.0, 1.0); - } -`); - -// Returns the average position of an array of THREE.Vector3. -export function averagePositions(positions: THREE.Vector3[]): THREE.Vector3 { - const sum = new THREE.Vector3(); - for (const position of positions) { - sum.add(position); - } - return sum.divideScalar(positions.length); -} - -// Returns an "average" of an array of THREE.Quaternion objects. -// Note that this is not a spherical lerp between quaternions but -// rather an arithmetic mean that is normalized to unit length. -export function averageQuaternions( - quaternions: THREE.Quaternion[], -): THREE.Quaternion { - if (quaternions.length === 0) { - return new THREE.Quaternion(); - } - const sum = quaternions[0].clone(); - for (let i = 1; i < quaternions.length; i++) { - if (quaternions[i].dot(quaternions[0]) < 0.0) { - sum.x -= quaternions[i].x; - sum.y -= quaternions[i].y; - sum.z -= quaternions[i].z; - sum.w -= quaternions[i].w; - } else { - sum.x += quaternions[i].x; - sum.y += quaternions[i].y; - sum.z += quaternions[i].z; - sum.w += quaternions[i].w; - } - } - return sum.normalize(); -} - -// Compare two coordinates given by matrix1 and matrix2, returning the distance -// between their origins and the "coincidence" of their orientations, defined -// as the dot product of their "-z" axes. -export function coinciDist(matrix1: THREE.Matrix4, matrix2: THREE.Matrix4) { - const origin1 = new THREE.Vector3(0, 0, 0).applyMatrix4(matrix1); - const origin2 = new THREE.Vector3(0, 0, 0).applyMatrix4(matrix2); - const direction1 = new THREE.Vector3(0, 0, -1) - .applyMatrix4(matrix1) - .sub(origin1) - .normalize(); - const direction2 = new THREE.Vector3(0, 0, -1) - .applyMatrix4(matrix2) - .sub(origin2) - .normalize(); - - const distance = origin1.distanceTo(origin2); - const coincidence = direction1.dot(direction2); - return { distance, coincidence }; -} - -// Utility function that returns whether two coordinate system origins -// given by matrix1 and matrix2 are within a certain maxDistance of each other. -export function withinDist({ - matrix1, - matrix2, - maxDistance, -}: { - matrix1: THREE.Matrix4; - matrix2: THREE.Matrix4; - maxDistance: number; -}): boolean { - const origin1 = new THREE.Vector3(0, 0, 0).applyMatrix4(matrix1); - const origin2 = new THREE.Vector3(0, 0, 0).applyMatrix4(matrix2); - return origin1.distanceTo(origin2) <= maxDistance; -} - -// Utility function that returns whether two coordinate systems are "close" -// to each other, defined by a maxDistance and a minCoincidence. -export function withinCoinciDist({ - matrix1, - matrix2, - maxDistance, - minCoincidence, -}: { - matrix1: THREE.Matrix4; - matrix2: THREE.Matrix4; - maxDistance: number; - minCoincidence?: number; -}): boolean { - const { distance, coincidence } = coinciDist(matrix1, matrix2); - return ( - distance <= maxDistance && - (minCoincidence == null || coincidence >= minCoincidence) - ); -} +const tempTRS1 = { + position: new THREE.Vector3(), + rotation: new THREE.Quaternion(), + scale: new THREE.Vector3(), +}; +const tempTRS2 = { + position: new THREE.Vector3(), + rotation: new THREE.Quaternion(), + scale: new THREE.Vector3(), +}; // Compare two coordinate systems given by matrix1 and matrix2, returning the // distance between their origins and the "coorientation" of their orientations, // define as the dot product of their quaternion transforms (flipping their // orientation to be on the same hemisphere if necessary). -export function coorientDist(matrix1: THREE.Matrix4, matrix2: THREE.Matrix4) { - const [origin1, rotate1] = [new THREE.Vector3(), new THREE.Quaternion()]; - const [origin2, rotate2] = [new THREE.Vector3(), new THREE.Quaternion()]; - matrix1.decompose(origin1, rotate1, new THREE.Vector3()); - matrix2.decompose(origin2, rotate2, new THREE.Vector3()); +function coorientDist(matrix1: THREE.Matrix4, matrix2: THREE.Matrix4) { + matrix1.decompose(tempTRS1.position, tempTRS1.rotation, tempTRS1.scale); + matrix2.decompose(tempTRS2.position, tempTRS2.rotation, tempTRS2.scale); - const distance = origin1.distanceTo(origin2); - const coorient = Math.abs(rotate1.dot(rotate2)); + const distance = tempTRS1.position.distanceTo(tempTRS2.position); + const coorient = Math.abs(tempTRS1.rotation.dot(tempTRS2.rotation)); return { distance, coorient }; } @@ -954,42 +277,6 @@ export function withinCoorientDist({ ); } -// Like Math.sign but with a custom epsilon value. -export function epsilonSign(value: number, epsilon = 0.001): number { - if (Math.abs(value) < epsilon) { - return 0; - } - return Math.sign(value); -} - -// Encode a THREE.Quaternion into a 24-bit integer, converting the xyz coordinates -// to signed 8-bit integers (w can be derived from xyz), and flipping the sign -// of the quaternion if necessary to make this possible (q == -q for quaternions). -export function encodeQuatXyz888(q: THREE.Quaternion): number { - const negQuat = q.w < 0.0; - const iQuatX = floatToSint8(negQuat ? -q.x : q.x); - const iQuatY = floatToSint8(negQuat ? -q.y : q.y); - const iQuatZ = floatToSint8(negQuat ? -q.z : q.z); - const uQuatX = iQuatX & 0xff; - const uQuatY = iQuatY & 0xff; - const uQuatZ = iQuatZ & 0xff; - return uQuatX | (uQuatY << 8) | (uQuatZ << 16); -} - -// Decode a 24-bit integer of the quaternion's xyz coordinates into a THREE.Quaternion. -export function decodeQuatXyz888( - encoded: number, - out: THREE.Quaternion, -): THREE.Quaternion { - const iQuatX = (encoded << 24) >> 24; - const iQuatY = (encoded << 16) >> 24; - const iQuatZ = (encoded << 8) >> 24; - out.set(iQuatX / 127.0, iQuatY / 127.0, iQuatZ / 127.0, 0.0); - const dotSelf = out.x * out.x + out.y * out.y + out.z * out.z; - out.w = Math.sqrt(Math.max(0.0, 1.0 - dotSelf)); - return out; -} - // Temporary storage used in `encodeQuatOCtXy88R8` and `decodeQuatOctXy88R8` to // avoid allocation new Quaternions and Vector3 instances. const tempNormalizedQuaternion = new THREE.Quaternion(); @@ -1083,222 +370,6 @@ export function decodeQuatOctXy88R8( return out; } -/** - * Encodes a THREE.Quaternion into a 24‑bit unsigned integer - * by converting it to Euler angles (roll, pitch, yaw). - * The Euler angles are assumed to be in radians in the range [-π, π]. - * Each angle is normalized to [0,1] and quantized to 8 bits. - * Bit layout (LSB→MSB): - * - Bits 0–7: roll (quantized) - * - Bits 8–15: pitch (quantized) - * - Bits 16–23: yaw (quantized) - */ -export function encodeQuatEulerXyz888(q: THREE.Quaternion): number { - // Normalize quaternion to ensure a proper rotation. - const qNorm = q.clone().normalize(); - - // Tait–Bryan angles (roll, pitch, yaw) - const sinr_cosp = 2.0 * (qNorm.w * qNorm.x + qNorm.y * qNorm.z); - const cosr_cosp = 1.0 - 2.0 * (qNorm.x * qNorm.x + qNorm.y * qNorm.y); - const roll = Math.atan2(sinr_cosp, cosr_cosp); - - const sinp = 2.0 * (qNorm.w * qNorm.y - qNorm.z * qNorm.x); - const pitch = - Math.abs(sinp) >= 1.0 ? Math.sign(sinp) * (Math.PI / 2) : Math.asin(sinp); - - const siny_cosp = 2.0 * (qNorm.w * qNorm.z + qNorm.x * qNorm.y); - const cosy_cosp = 1.0 - 2.0 * (qNorm.y * qNorm.y + qNorm.z * qNorm.z); - const yaw = Math.atan2(siny_cosp, cosy_cosp); - - // Map each angle from [-π, π] to [0, 1] - const normRoll = (roll + Math.PI) / (2 * Math.PI); - const normPitch = (pitch + Math.PI) / (2 * Math.PI); - const normYaw = (yaw + Math.PI) / (2 * Math.PI); - - // Quantize to 8 bits (0 to 255) - const rollQ = Math.round(normRoll * 255); - const pitchQ = Math.round(normPitch * 255); - const yawQ = Math.round(normYaw * 255); - - // Pack into a 24-bit unsigned integer: - // Bits 0–7: rollQ, Bits 8–15: pitchQ, Bits 16–23: yawQ. - return (yawQ << 16) | (pitchQ << 8) | rollQ; -} - -/** - * Decodes a 24‑bit unsigned integer into a THREE.Quaternion - * by unpacking three 8‑bit values (roll, pitch, yaw) in the range [0,255] - * and then converting them back to Euler angles in [-π, π] and to a quaternion. - */ -export function decodeQuatEulerXyz888( - encoded: number, - out: THREE.Quaternion, -): THREE.Quaternion { - // Unpack 8‑bit values. - const rollQ = encoded & 0xff; - const pitchQ = (encoded >>> 8) & 0xff; - const yawQ = (encoded >>> 16) & 0xff; - - // Convert quantized values back to normalized [0,1] values. - const normRoll = rollQ / 255; - const normPitch = pitchQ / 255; - const normYaw = yawQ / 255; - - // Map from [0,1] to [-π, π] - const roll = normRoll * (2 * Math.PI) - Math.PI; - const pitch = normPitch * (2 * Math.PI) - Math.PI; - const yaw = normYaw * (2 * Math.PI) - Math.PI; - - // Convert Euler angles to quaternion (Tait–Bryan: roll, pitch, yaw). - const cr = Math.cos(roll * 0.5); - const sr = Math.sin(roll * 0.5); - const cp = Math.cos(pitch * 0.5); - const sp = Math.sin(pitch * 0.5); - const cy = Math.cos(yaw * 0.5); - const sy = Math.sin(yaw * 0.5); - - out.w = cr * cp * cy + sr * sp * sy; - out.x = sr * cp * cy - cr * sp * sy; - out.y = cr * sp * cy + sr * cp * sy; - out.z = cr * cp * sy - sr * sp * cy; - out.normalize(); - return out; -} - -// Pack four signed 8-bit values into a single uint32. -function packSint8Bytes( - b0: number, - b1: number, - b2: number, - b3: number, -): number { - const clampedB0 = Math.max(-127, Math.min(127, b0 * 127)); - const clampedB1 = Math.max(-127, Math.min(127, b1 * 127)); - const clampedB2 = Math.max(-127, Math.min(127, b2 * 127)); - const clampedB3 = Math.max(-127, Math.min(127, b3 * 127)); - return ( - (clampedB0 & 0xff) | - ((clampedB1 & 0xff) << 8) | - ((clampedB2 & 0xff) << 16) | - ((clampedB3 & 0xff) << 24) - ); -} - -// Encode an array of 9 signed RGB SH1 coefficients (clamped to [-1,1]) into -// a pair of uint32 values, where each coefficient is stored as a sint7 -export function encodeSh1Rgb( - sh1Array: Uint32Array, - index: number, - sh1Rgb: Float32Array, - encoding?: { - sh1Min?: number; - sh1Max?: number; - }, -) { - const sh1Min = encoding?.sh1Min ?? -1; - const sh1Max = encoding?.sh1Max ?? 1; - const sh1Mid = 0.5 * (sh1Min + sh1Max); - const sh1Scale = 126 / (sh1Max - sh1Min); - - // Pack sint7 values into 2 x uint32 - const base = index * 2; - for (let i = 0; i < 9; ++i) { - const s = (sh1Rgb[i] - sh1Mid) * sh1Scale; - const value = Math.round(Math.max(-63, Math.min(63, s))) & 0x7f; - const bitStart = i * 7; - const bitEnd = bitStart + 7; - - const wordStart = Math.floor(bitStart / 32); - const bitOffset = bitStart - wordStart * 32; - const firstWord = (value << bitOffset) & 0xffffffff; - sh1Array[base + wordStart] |= firstWord; - - if (bitEnd > wordStart * 32 + 32) { - const secondWord = (value >>> (32 - bitOffset)) & 0xffffffff; - sh1Array[base + wordStart + 1] |= secondWord; - } - } -} - -// Encode an array of 15 signed RGB SH2 coefficients (clamped to [-1,1]) into -// an array of 4 uint32 values, where each coefficient is stored as a sint8. -export function encodeSh2Rgb( - sh2Array: Uint32Array, - index: number, - sh2Rgb: Float32Array, - encoding?: { - sh2Min?: number; - sh2Max?: number; - }, -) { - const sh2Min = encoding?.sh2Min ?? -1; - const sh2Max = encoding?.sh2Max ?? 1; - const sh2Mid = 0.5 * (sh2Min + sh2Max); - const sh2Scale = 2 / (sh2Max - sh2Min); - - // Pack sint8 values into 4 x uint32 - sh2Array[index * 4 + 0] = packSint8Bytes( - (sh2Rgb[0] - sh2Mid) * sh2Scale, - (sh2Rgb[1] - sh2Mid) * sh2Scale, - (sh2Rgb[2] - sh2Mid) * sh2Scale, - (sh2Rgb[3] - sh2Mid) * sh2Scale, - ); - sh2Array[index * 4 + 1] = packSint8Bytes( - (sh2Rgb[4] - sh2Mid) * sh2Scale, - (sh2Rgb[5] - sh2Mid) * sh2Scale, - (sh2Rgb[6] - sh2Mid) * sh2Scale, - (sh2Rgb[7] - sh2Mid) * sh2Scale, - ); - sh2Array[index * 4 + 2] = packSint8Bytes( - (sh2Rgb[8] - sh2Mid) * sh2Scale, - (sh2Rgb[9] - sh2Mid) * sh2Scale, - (sh2Rgb[10] - sh2Mid) * sh2Scale, - (sh2Rgb[11] - sh2Mid) * sh2Scale, - ); - sh2Array[index * 4 + 3] = packSint8Bytes( - (sh2Rgb[12] - sh2Mid) * sh2Scale, - (sh2Rgb[13] - sh2Mid) * sh2Scale, - (sh2Rgb[14] - sh2Mid) * sh2Scale, - 0, - ); -} - -// Encode an array of 21 signed RGB SH3 coefficients (clamped to [-1,1]) into -// an array of 4 uint32 values, where each coefficient is stored as a sint6. -export function encodeSh3Rgb( - sh3Array: Uint32Array, - index: number, - sh3Rgb: Float32Array, - encoding?: { - sh3Min?: number; - sh3Max?: number; - }, -) { - const sh3Min = encoding?.sh3Min ?? -1; - const sh3Max = encoding?.sh3Max ?? 1; - const sh3Mid = 0.5 * (sh3Min + sh3Max); - const sh3Scale = 62 / (sh3Max - sh3Min); - - // Pack sint6 values into 4 x uint32 - const base = index * 4; - for (let i = 0; i < 21; ++i) { - const s = (sh3Rgb[i] - sh3Mid) * sh3Scale; - const value = Math.round(Math.max(-31, Math.min(31, s))) & 0x3f; - const bitStart = i * 6; - const bitEnd = bitStart + 6; - - const wordStart = Math.floor(bitStart / 32); - const bitOffset = bitStart - wordStart * 32; - const firstWord = (value << bitOffset) & 0xffffffff; - sh3Array[base + wordStart] |= firstWord; - - if (bitEnd > wordStart * 32 + 32) { - const secondWord = (value >>> (32 - bitOffset)) & 0xffffffff; - sh3Array[base + wordStart + 1] |= secondWord; - } - } -} - // Partially decompress a gzip-encoded Uint8Array, returning a Uint8Array of // the specified numBytes from the start of the file. export function decompressPartialGzip( @@ -1341,19 +412,11 @@ export function decompressPartialGzip( } export class GunzipReader { - fileBytes: Uint8Array; - chunkBytes: number; - - chunks: Uint8Array[]; - totalBytes: number; - reader: ReadableStreamDefaultReader; + private chunks: Uint8Array[]; + private totalBytes: number; + private reader: ReadableStreamDefaultReader; - constructor({ - fileBytes, - chunkBytes = 64 * 1024, - }: { fileBytes: Uint8Array; chunkBytes?: number }) { - this.fileBytes = fileBytes; - this.chunkBytes = chunkBytes; + constructor(fileBytes: Uint8Array) { this.chunks = []; this.totalBytes = 0; diff --git a/src/vrButton.ts b/src/vrButton.ts deleted file mode 100644 index 2bd2d8a..0000000 --- a/src/vrButton.ts +++ /dev/null @@ -1,164 +0,0 @@ -import type * as THREE from "three"; - -export class VRButton { - static createButton( - renderer: THREE.WebGLRenderer, - sessionInit: XRSessionInit = {}, - ): HTMLElement | null { - const navigatorXr = navigator.xr; - if (!navigatorXr) { - // Only allow creation if WebXR is supported - return null; - } - const xr = navigatorXr; - - const button = document.createElement("button"); - renderer.xr.enabled = true; - renderer.xr.setReferenceSpaceType("local"); - - function showEnterVR(/*device*/) { - let currentSession: XRSession | null = null; - - async function onSessionStarted(session: XRSession) { - console.log("onSessionStarted"); - - session.addEventListener("end", onSessionEnded); - - await renderer.xr.setSession(session); - button.textContent = "EXIT VR"; - - currentSession = session; - } - - function onSessionEnded(/*event*/) { - console.log("onSessionEnded"); - currentSession?.removeEventListener("end", onSessionEnded); - - button.textContent = "ENTER VR"; - - currentSession = null; - } - - button.style.display = ""; - button.style.cursor = "pointer"; - button.style.left = "calc(50% - 100px)"; - button.style.width = "200px"; - button.style.height = "100px"; - button.textContent = "ENTER VR"; - - // WebXR's requestReferenceSpace only works if the corresponding feature - // was requested at session creation time. For simplicity, just ask for - // the interesting ones as optional features, but be aware that the - // requestReferenceSpace call will fail if it turns out to be unavailable. - // ('local' is always available for immersive sessions and doesn't need to - // be requested separately.) - - const sessionOptions: XRSessionInit = { - ...sessionInit, - optionalFeatures: [ - // "local-floor", - // "bounded-floor", - // "layers", - ...(sessionInit.optionalFeatures || []), - ], - }; - - button.onmouseenter = () => { - button.style.opacity = "1.0"; - }; - button.onmouseleave = () => { - button.style.opacity = "0.5"; - }; - button.onclick = () => { - if (currentSession === null) { - console.log("requesting session"); - xr.requestSession("immersive-vr", sessionOptions).then( - onSessionStarted, - ); - // xr.requestSession( "immersive-ar", sessionOptions ).then( onSessionStarted ); - } else { - console.log("ending session"); - currentSession.end(); - } - }; - } - - function disableButton() { - button.style.display = "none"; - button.style.cursor = "auto"; - button.style.left = "calc(50% - 75px)"; - button.style.width = "150px"; - - button.onmouseenter = null; - button.onmouseleave = null; - button.onclick = null; - } - - function showWebXRNotFound() { - disableButton(); - button.textContent = "VR NOT SUPPORTED"; - } - - function showVRNotAllowed(exception: any) { - disableButton(); - console.warn( - "Exception when trying to call xr.isSessionSupported", - exception, - ); - button.textContent = "VR NOT ALLOWED"; - } - - function stylizeElement(element: HTMLElement) { - element.style.position = "absolute"; - element.style.bottom = "20px"; - element.style.padding = "12px 6px"; - element.style.border = "1px solid #fff"; - element.style.borderRadius = "4px"; - element.style.background = "rgba(0,0,0,0.1)"; - element.style.color = "#fff"; - element.style.font = "normal 13px sans-serif"; - element.style.textAlign = "center"; - element.style.opacity = "0.5"; - element.style.outline = "none"; - element.style.zIndex = "999"; - } - - button.id = "VRButton"; - button.style.display = "none"; - stylizeElement(button); - - xr.isSessionSupported("immersive-vr") - .then((supported) => { - // xr.isSessionSupported( "immersive-ar" ).then( function ( supported ) { - supported ? showEnterVR() : showWebXRNotFound(); - - if (supported && VRButton.xrSessionIsGranted) { - button.click(); - } - }) - .catch(showVRNotAllowed); - - return button; - } - - static registerSessionGrantedListener() { - const navigatorXr = navigator.xr; - if (!navigatorXr) { - // Only allow creation if WebXR is supported - return null; - } - const xr = navigatorXr; - - // WebXRViewer (based on Firefox) has a bug where addEventListener - // throws a silent exception and aborts execution entirely. - if (/WebXRViewer\//i.test(navigator.userAgent)) return; - - xr.addEventListener("sessiongranted", () => { - VRButton.xrSessionIsGranted = true; - }); - } - - static xrSessionIsGranted = false; -} - -VRButton.registerSessionGrantedListener(); diff --git a/src/worker.ts b/src/worker.ts deleted file mode 100644 index bd1efb0..0000000 --- a/src/worker.ts +++ /dev/null @@ -1,681 +0,0 @@ -import init_wasm, { sort_splats, sort32_splats } from "spark-internal-rs"; -import type { SplatEncoding } from "./PackedSplats"; -import type { PcSogsJson, TranscodeSpzInput } from "./SplatLoader"; -import { unpackAntiSplat } from "./antisplat"; -import { WASM_SPLAT_SORT } from "./defines"; -import { unpackKsplat } from "./ksplat"; -import { unpackPcSogs, unpackPcSogsZip } from "./pcsogs"; -import { PlyReader } from "./ply"; -import { SpzReader, transcodeSpz } from "./spz"; -import { - computeMaxSplats, - encodeSh1Rgb, - encodeSh2Rgb, - encodeSh3Rgb, - getArrayBuffers, - setPackedSplat, - setPackedSplatCenter, - setPackedSplatOpacity, - setPackedSplatQuat, - setPackedSplatRgb, - setPackedSplatScales, - toHalf, -} from "./utils"; - -// WebWorker for Spark's background CPU tasks, such as Gsplat file decoding -// and sorting. - -async function onMessage(event: MessageEvent) { - // Unpack RPC function name, arguments, and ID from the main thread. - const { name, args, id }: { name: string; args: unknown; id: number } = - event.data; - // console.log(`worker.onMessage(${id}, ${name}):`, args); - - // Initialize return result/error, to be filled out below. - let result = undefined; - let error = undefined; - - try { - switch (name) { - case "unpackPly": { - const { packedArray, fileBytes, splatEncoding } = args as { - packedArray: Uint32Array; - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = await unpackPly({ - packedArray, - fileBytes, - splatEncoding, - }); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodeSpz": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = await unpackSpz(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodeAntiSplat": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = unpackAntiSplat(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - }; - break; - } - case "decodeKsplat": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = unpackKsplat(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodePcSogs": { - const { fileBytes, extraFiles, splatEncoding } = args as { - fileBytes: Uint8Array; - extraFiles: Record; - splatEncoding: SplatEncoding; - }; - const json = JSON.parse( - new TextDecoder().decode(fileBytes), - ) as PcSogsJson; - const decoded = await unpackPcSogs(json, extraFiles, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "decodePcSogsZip": { - const { fileBytes, splatEncoding } = args as { - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; - }; - const decoded = await unpackPcSogsZip(fileBytes, splatEncoding); - result = { - id, - numSplats: decoded.numSplats, - packedArray: decoded.packedArray, - extra: decoded.extra, - }; - break; - } - case "sortSplats": { - // Sort maxSplats splats using readback data, which encodes one uint32 per - // Gsplats, with the low bytes encoding a float16 distance sort metric. - const { maxSplats, totalSplats, readback, ordering } = args as { - maxSplats: number; - totalSplats: number; - readback: Uint8Array[]; - ordering: Uint32Array; - }; - // Sort totalSplats splats each with 4 bytes of readback, and outputs Uint32Array ordering of splat indices - result = { - id, - readback, - ...sortSplats({ totalSplats, readback, ordering }), - }; - break; - } - case "sortDoubleSplats": { - // Sort numSplats splats using the readback distance metric, which encodes - // one float16 per splat (no unused high bytes like for sortSplats). - const { numSplats, readback, ordering } = args as { - numSplats: number; - readback: Uint16Array; - ordering: Uint32Array; - }; - if (WASM_SPLAT_SORT) { - result = { - id, - readback, - ordering, - activeSplats: sort_splats(numSplats, readback, ordering), - }; - } else { - result = { - id, - readback, - ...sortDoubleSplats({ numSplats, readback, ordering }), - }; - } - break; - } - case "sort32Splats": { - const { maxSplats, numSplats, readback, ordering } = args as { - maxSplats: number; - numSplats: number; - readback: Uint32Array; - ordering: Uint32Array; - }; - // Benchmark sort - // benchmarkSort(numSplats, readback, ordering); - if (WASM_SPLAT_SORT) { - result = { - id, - readback, - ordering, - activeSplats: sort32_splats(numSplats, readback, ordering), - }; - } else { - result = { - id, - readback, - ...sort32Splats({ maxSplats, numSplats, readback, ordering }), - }; - } - break; - } - case "transcodeSpz": { - const input = args as TranscodeSpzInput; - const spzBytes = await transcodeSpz(input); - result = { - id, - fileBytes: spzBytes, - input, - }; - break; - } - default: { - throw new Error(`Unknown name: ${name}`); - } - } - } catch (e) { - error = e; - console.error(error); - } - - // Send the result or error back to the main thread, making sure to transfer any ArrayBuffers - self.postMessage( - { id, result, error }, - { transfer: getArrayBuffers(result) }, - ); -} - -function benchmarkSort( - numSplats: number, - readback32: Uint32Array, - ordering: Uint32Array, -) { - if (numSplats > 0) { - console.log("Running sort benchmark"); - const readbackF32 = new Float32Array(readback32.buffer); - const readback16 = new Uint16Array(readback32.length); - for (let i = 0; i < numSplats; ++i) { - readback16[i] = toHalf(readbackF32[i]); - } - - const WARMUP = 10; - for (let i = 0; i < WARMUP; ++i) { - const activeSplats = sort_splats(numSplats, readback16, ordering); - const activeSplats32 = sort32_splats(numSplats, readback32, ordering); - const results = sortDoubleSplats({ - numSplats, - readback: readback16, - ordering, - }); - const results32 = sort32Splats({ - maxSplats: numSplats, - numSplats, - readback: readback32, - ordering, - }); - } - - const TIMING_SAMPLES = 1000; - let start: number; - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const activeSplats = sort_splats(numSplats, readback16, ordering); - } - const wasmTime = (performance.now() - start) / TIMING_SAMPLES; - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const results = sortDoubleSplats({ - numSplats, - readback: readback16, - ordering, - }); - } - const jsTime = (performance.now() - start) / TIMING_SAMPLES; - - console.log( - `JS: ${jsTime} ms, WASM: ${wasmTime} ms, numSplats: ${numSplats}`, - ); - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const activeSplats32 = sort32_splats(numSplats, readback32, ordering); - } - const wasm32Time = (performance.now() - start) / TIMING_SAMPLES; - - start = performance.now(); - for (let i = 0; i < TIMING_SAMPLES; ++i) { - const results = sort32Splats({ - maxSplats: numSplats, - numSplats, - readback: readback32, - ordering, - }); - } - const js32Time = (performance.now() - start) / TIMING_SAMPLES; - - console.log( - `JS32: ${js32Time} ms, WASM32: ${wasm32Time} ms, numSplats: ${numSplats}`, - ); - } -} - -async function unpackPly({ - packedArray, - fileBytes, - splatEncoding, -}: { - packedArray: Uint32Array; - fileBytes: Uint8Array; - splatEncoding: SplatEncoding; -}): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { - const ply = new PlyReader({ fileBytes }); - await ply.parseHeader(); - const numSplats = ply.numSplats; - - const extra: Record = {}; - - ply.parseSplats( - ( - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - ) => { - setPackedSplat( - packedArray, - index, - x, - y, - z, - scaleX, - scaleY, - scaleZ, - quatX, - quatY, - quatZ, - quatW, - opacity, - r, - g, - b, - splatEncoding, - ); - }, - (index, sh1, sh2, sh3) => { - if (sh1) { - if (!extra.sh1) { - extra.sh1 = new Uint32Array(numSplats * 2); - } - encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1, splatEncoding); - } - if (sh2) { - if (!extra.sh2) { - extra.sh2 = new Uint32Array(numSplats * 4); - } - encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2, splatEncoding); - } - if (sh3) { - if (!extra.sh3) { - extra.sh3 = new Uint32Array(numSplats * 4); - } - encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3, splatEncoding); - } - }, - ); - - return { packedArray, numSplats, extra }; -} - -async function unpackSpz( - fileBytes: Uint8Array, - splatEncoding: SplatEncoding, -): Promise<{ - packedArray: Uint32Array; - numSplats: number; - extra: Record; -}> { - const spz = new SpzReader({ fileBytes }); - await spz.parseHeader(); - const numSplats = spz.numSplats; - const maxSplats = computeMaxSplats(numSplats); - const packedArray = new Uint32Array(maxSplats * 4); - const extra: Record = {}; - - await spz.parseSplats( - (index, x, y, z) => { - setPackedSplatCenter(packedArray, index, x, y, z); - }, - (index, alpha) => { - setPackedSplatOpacity(packedArray, index, alpha); - }, - (index, r, g, b) => { - setPackedSplatRgb(packedArray, index, r, g, b, splatEncoding); - }, - (index, scaleX, scaleY, scaleZ) => { - setPackedSplatScales( - packedArray, - index, - scaleX, - scaleY, - scaleZ, - splatEncoding, - ); - }, - (index, quatX, quatY, quatZ, quatW) => { - setPackedSplatQuat(packedArray, index, quatX, quatY, quatZ, quatW); - }, - (index, sh1, sh2, sh3) => { - if (sh1) { - if (!extra.sh1) { - extra.sh1 = new Uint32Array(numSplats * 2); - } - encodeSh1Rgb(extra.sh1 as Uint32Array, index, sh1, splatEncoding); - } - if (sh2) { - if (!extra.sh2) { - extra.sh2 = new Uint32Array(numSplats * 4); - } - encodeSh2Rgb(extra.sh2 as Uint32Array, index, sh2, splatEncoding); - } - if (sh3) { - if (!extra.sh3) { - extra.sh3 = new Uint32Array(numSplats * 4); - } - encodeSh3Rgb(extra.sh3 as Uint32Array, index, sh3, splatEncoding); - } - }, - ); - return { packedArray, numSplats, extra }; -} - -// Array of buckets for sorting float16 distances with range [0, DEPTH_INFINITY]. -const DEPTH_INFINITY_F16 = 0x7c00; -const DEPTH_SIZE_16 = DEPTH_INFINITY_F16 + 1; -let depthArray16: Uint32Array | null = null; - -function sortSplats({ - totalSplats, - readback, - ordering, -}: { totalSplats: number; readback: Uint8Array[]; ordering: Uint32Array }): { - activeSplats: number; - ordering: Uint32Array; -} { - // Sort totalSplats Gsplats, each with 4 bytes of readback, and outputs Uint32Array - // of indices from most distant to nearest. Each 4 bytes encode a float16 distance - // and unused high bytes. - if (!depthArray16) { - depthArray16 = new Uint32Array(DEPTH_SIZE_16); - } - depthArray16.fill(0); - - const readbackUint32 = readback.map((layer) => new Uint32Array(layer.buffer)); - const layerSize = readbackUint32[0].length; - const numLayers = Math.ceil(totalSplats / layerSize); - - let layerBase = 0; - for (let layer = 0; layer < numLayers; ++layer) { - const readbackLayer = readbackUint32[layer]; - const layerSplats = Math.min(readbackLayer.length, totalSplats - layerBase); - for (let i = 0; i < layerSplats; ++i) { - const pri = readbackLayer[i] & 0x7fff; - if (pri < DEPTH_INFINITY_F16) { - depthArray16[pri] += 1; - } - } - layerBase += layerSplats; - } - - let activeSplats = 0; - for (let j = 0; j < DEPTH_SIZE_16; ++j) { - const nextIndex = activeSplats + depthArray16[j]; - depthArray16[j] = activeSplats; - activeSplats = nextIndex; - } - - layerBase = 0; - for (let layer = 0; layer < numLayers; ++layer) { - const readbackLayer = readbackUint32[layer]; - const layerSplats = Math.min(readbackLayer.length, totalSplats - layerBase); - for (let i = 0; i < layerSplats; ++i) { - const pri = readbackLayer[i] & 0x7fff; - if (pri < DEPTH_INFINITY_F16) { - ordering[depthArray16[pri]] = layerBase + i; - depthArray16[pri] += 1; - } - } - layerBase += layerSplats; - } - if (depthArray16[DEPTH_SIZE_16 - 1] !== activeSplats) { - throw new Error( - `Expected ${activeSplats} active splats but got ${depthArray16[DEPTH_SIZE_16 - 1]}`, - ); - } - - return { activeSplats, ordering }; -} - -// Sort numSplats splats, each with 2 bytes of float16 readback for distance metric, -// using one bucket sort pass, outputting Uint32Array of indices. -function sortDoubleSplats({ - numSplats, - readback, - ordering, -}: { numSplats: number; readback: Uint16Array; ordering: Uint32Array }): { - activeSplats: number; - ordering: Uint32Array; -} { - // Ensure depthArray is allocated and zeroed out for our buckets. - if (!depthArray16) { - depthArray16 = new Uint32Array(DEPTH_SIZE_16); - } - depthArray16.fill(0); - - // Count the number of splats in each bucket (cull Gsplats at infinity). - for (let i = 0; i < numSplats; ++i) { - const pri = readback[i]; - if (pri < DEPTH_INFINITY_F16) { - depthArray16[pri] += 1; - } - } - - // Compute the beginning index of each bucket in the output array and the - // total number of active (non-infinity) splats, going in reverse order - // because we want most distant Gsplats to be first in the output array. - let activeSplats = 0; - for (let j = DEPTH_INFINITY_F16 - 1; j >= 0; --j) { - const nextIndex = activeSplats + depthArray16[j]; - depthArray16[j] = activeSplats; - activeSplats = nextIndex; - } - - // Write out the sorted indices into the output array according - // bucket order. - for (let i = 0; i < numSplats; ++i) { - const pri = readback[i]; - if (pri < DEPTH_INFINITY_F16) { - ordering[depthArray16[pri]] = i; - depthArray16[pri] += 1; - } - } - // Sanity check that the end of the closest bucket is the same as - // our total count of active splats (not at infinity). - if (depthArray16[0] !== activeSplats) { - throw new Error( - `Expected ${activeSplats} active splats but got ${depthArray16[0]}`, - ); - } - - return { activeSplats, ordering }; -} - -const DEPTH_INFINITY_F32 = 0x7f800000; -let bucket16lo: Uint32Array | null = null; -let bucket16hi: Uint32Array | null = null; -let scratchSplats: Uint32Array | null = null; - -// two-pass radix sort (base 65536) of 32-bit keys in readback, -// but placing largest values first. -function sort32Splats({ - maxSplats, - numSplats, - readback, // Uint32Array of bit‑patterns - ordering, // Uint32Array to fill with sorted indices -}: { - maxSplats: number; - numSplats: number; - readback: Uint32Array; - ordering: Uint32Array; -}): { activeSplats: number; ordering: Uint32Array } { - const BASE = 1 << 16; // 65536 - - // allocate once - if (!bucket16lo) { - bucket16lo = new Uint32Array(BASE); - } - if (!bucket16hi) { - bucket16hi = new Uint32Array(BASE); - } - if (!scratchSplats || scratchSplats.length < maxSplats) { - scratchSplats = new Uint32Array(maxSplats); - } - - // tally low and high buckets - bucket16lo.fill(0); - bucket16hi.fill(0); - for (let i = 0; i < numSplats; ++i) { - const key = readback[i]; - if (key < DEPTH_INFINITY_F32) { - const inv = ~key >>> 0; - bucket16lo[inv & 0xffff] += 1; - bucket16hi[inv >>> 16] += 1; - } - } - - // - // ——— Pass #1: bucket by inv(lo 16 bits) ——— - // - // exclusive prefix‑sum → starting offsets - let total = 0; - for (let b = 0; b < BASE; ++b) { - const c = bucket16lo[b]; - bucket16lo[b] = total; - total += c; - } - const activeSplats = total; - - // scatter into scratch by low bits of inv - for (let i = 0; i < numSplats; ++i) { - const key = readback[i]; - if (key < DEPTH_INFINITY_F32) { - const inv = ~key >>> 0; - scratchSplats[bucket16lo[inv & 0xffff]++] = i; - } - } - - // - // ——— Pass #2: bucket by inv(hi 16 bits) ——— - // - // exclusive prefix‑sum again - let sum = 0; - for (let b = 0; b < BASE; ++b) { - const c = bucket16hi[b]; - bucket16hi[b] = sum; - sum += c; - } - - // scatter into final ordering by high bits of inv - for (let k = 0; k < activeSplats; ++k) { - const idx = scratchSplats[k]; - const inv = ~readback[idx] >>> 0; - ordering[bucket16hi[inv >>> 16]++] = idx; - } - - // sanity‑check: the last bucket should have eaten all entries - if (bucket16hi[BASE - 1] !== activeSplats) { - throw new Error( - `Expected ${activeSplats} active splats but got ${bucket16hi[BASE - 1]}`, - ); - } - - return { activeSplats, ordering }; -} - -// Buffer to queue any messages received while initializing, for example -// early messages to unpack a Gsplat file while still initializing the WASM code. -const messageBuffer: MessageEvent[] = []; - -function bufferMessage(event: MessageEvent) { - messageBuffer.push(event); -} - -async function initialize() { - // Hold any messages received while initializing - self.addEventListener("message", bufferMessage); - - await init_wasm(); - - self.removeEventListener("message", bufferMessage); - self.addEventListener("message", onMessage); - - // Process any buffered messages - for (const event of messageBuffer) { - onMessage(event); - } - messageBuffer.length = 0; -} - -initialize().catch(console.error); diff --git a/src/worker/sort.ts b/src/worker/sort.ts new file mode 100644 index 0000000..c332409 --- /dev/null +++ b/src/worker/sort.ts @@ -0,0 +1,242 @@ +import * as THREE from "three"; +import type { TransformRange } from "../defines"; +import { toHalf } from "../utils"; + +// Array of buckets for sorting float16 distances with range [0, DEPTH_INFINITY]. +const DEPTH_INFINITY_F16 = 0x7c00; +const DEPTH_SIZE_16 = DEPTH_INFINITY_F16 + 1; +let depthArray16: Uint32Array | null = null; + +// Sort numSplats splats, each with 2 bytes of float16 readback for distance metric, +// using one bucket sort pass, outputting Uint32Array of indices. +export function sortDoubleSplats({ + numSplats, + readback, + ordering, +}: { numSplats: number; readback: Uint16Array; ordering: Uint32Array }): { + activeSplats: number; + ordering: Uint32Array; +} { + // Ensure depthArray is allocated and zeroed out for our buckets. + if (!depthArray16) { + depthArray16 = new Uint32Array(DEPTH_SIZE_16); + } + depthArray16.fill(0); + + // Count the number of splats in each bucket (cull Gsplats at infinity). + for (let i = 0; i < numSplats; ++i) { + const pri = readback[i]; + if (pri < DEPTH_INFINITY_F16) { + depthArray16[pri] += 1; + } + } + + // Compute the beginning index of each bucket in the output array and the + // total number of active (non-infinity) splats, going in reverse order + // because we want most distant Gsplats to be first in the output array. + let activeSplats = 0; + for (let j = DEPTH_INFINITY_F16 - 1; j >= 0; --j) { + const nextIndex = activeSplats + depthArray16[j]; + depthArray16[j] = activeSplats; + activeSplats = nextIndex; + } + + // Write out the sorted indices into the output array according + // bucket order. + for (let i = 0; i < numSplats; ++i) { + const pri = readback[i]; + if (pri < DEPTH_INFINITY_F16) { + ordering[depthArray16[pri]] = i; + depthArray16[pri] += 1; + } + } + // Sanity check that the end of the closest bucket is the same as + // our total count of active splats (not at infinity). + if (depthArray16[0] !== activeSplats) { + throw new Error( + `Expected ${activeSplats} active splats but got ${depthArray16[0]}`, + ); + } + + return { activeSplats, ordering }; +} + +const DEPTH_INFINITY_F32 = 0x7f800000; +let bucket16lo: Uint32Array | null = null; +let bucket16hi: Uint32Array | null = null; +let scratchSplats: Uint32Array | null = null; + +// two-pass radix sort (base 65536) of 32-bit keys in readback, +// but placing largest values first. +export function sort32Splats({ + maxSplats, + numSplats, + readback, // Uint32Array of bit‑patterns + ordering, // Uint32Array to fill with sorted indices +}: { + maxSplats: number; + numSplats: number; + readback: Uint32Array; + ordering: Uint32Array; +}): { activeSplats: number; ordering: Uint32Array } { + const BASE = 1 << 16; // 65536 + + // allocate once + if (!bucket16lo) { + bucket16lo = new Uint32Array(BASE); + } + if (!bucket16hi) { + bucket16hi = new Uint32Array(BASE); + } + if (!scratchSplats || scratchSplats.length < maxSplats) { + scratchSplats = new Uint32Array(maxSplats); + } + + // tally low and high buckets + bucket16lo.fill(0); + bucket16hi.fill(0); + for (let i = 0; i < numSplats; ++i) { + const key = readback[i]; + if (key < DEPTH_INFINITY_F32) { + const inv = ~key >>> 0; + bucket16lo[inv & 0xffff] += 1; + bucket16hi[inv >>> 16] += 1; + } + } + + // + // ——— Pass #1: bucket by inv(lo 16 bits) ——— + // + // exclusive prefix‑sum → starting offsets + let total = 0; + for (let b = 0; b < BASE; ++b) { + const c = bucket16lo[b]; + bucket16lo[b] = total; + total += c; + } + const activeSplats = total; + + // scatter into scratch by low bits of inv + for (let i = 0; i < numSplats; ++i) { + const key = readback[i]; + if (key < DEPTH_INFINITY_F32) { + const inv = ~key >>> 0; + scratchSplats[bucket16lo[inv & 0xffff]++] = i; + } + } + + // + // ——— Pass #2: bucket by inv(hi 16 bits) ——— + // + // exclusive prefix‑sum again + let sum = 0; + for (let b = 0; b < BASE; ++b) { + const c = bucket16hi[b]; + bucket16hi[b] = sum; + sum += c; + } + + // scatter into final ordering by high bits of inv + for (let k = 0; k < activeSplats; ++k) { + const idx = scratchSplats[k]; + const inv = ~readback[idx] >>> 0; + ordering[bucket16hi[inv >>> 16]++] = idx; + } + + // sanity‑check: the last bucket should have eaten all entries + if (bucket16hi[BASE - 1] !== activeSplats) { + throw new Error( + `Expected ${activeSplats} active splats but got ${bucket16hi[BASE - 1]}`, + ); + } + + return { activeSplats, ordering }; +} + +// FIXME: Avoid importing THREE classes into worker +let distances = new Uint16Array(); + +const centerVector = new THREE.Vector3(); +const transformMatrix = new THREE.Matrix4(); +const viewOriginVector = new THREE.Vector3(); +const viewDirVector = new THREE.Vector3(); + +export function sortSplatsCpu( + splatCenters: Float32Array, + transforms: Array, + viewOrigin: [number, number, number], + viewDir: [number, number, number], + ordering: Uint32Array, +): { activeSplats: number; ordering: Uint32Array } { + const numSplats = splatCenters.length / 3; + if (distances.length < numSplats) { + distances = new Uint16Array(numSplats); + } + distances.fill(DEPTH_INFINITY_F16); + + viewOriginVector.fromArray(viewOrigin); + viewDirVector.fromArray(viewDir); + + // Ensure depthArray is allocated and zeroed out for our buckets. + if (!depthArray16) { + depthArray16 = new Uint32Array(DEPTH_SIZE_16); + } + depthArray16.fill(0); + + // Compute distance for each splat and count buckets + let transformIndex = 0; + while (transformIndex < transforms.length) { + const transform = transforms[transformIndex]; + transformMatrix.fromArray(transform.matrix); + + for ( + let splatIndex = transform.start; + splatIndex < transform.end; + ++splatIndex + ) { + // Apply transform to center + centerVector.fromArray(splatCenters, splatIndex * 3); + centerVector.applyMatrix4(transformMatrix); + const distance = centerVector.sub(viewOriginVector).dot(viewDirVector); + + if (distance >= 0) { + const distanceU16 = toHalf(distance); + distances[splatIndex] = distanceU16; + depthArray16[distanceU16] += 1; + } else { + distances[splatIndex] = DEPTH_INFINITY_F16; + } + } + + transformIndex++; + } + + // Compute the beginning index of each bucket in the output array and the + // total number of active (non-infinity) splats, going in reverse order + // because we want most distant Gsplats to be first in the output array. + let activeSplats = 0; + for (let j = DEPTH_INFINITY_F16 - 1; j >= 0; --j) { + const nextIndex = activeSplats + depthArray16[j]; + depthArray16[j] = activeSplats; + activeSplats = nextIndex; + } + + // Write out the sorted indices into the output array according + // bucket order. + for (let i = 0; i < numSplats; ++i) { + const pri = distances[i]; + if (pri < DEPTH_INFINITY_F16) { + ordering[depthArray16[pri]] = i; + depthArray16[pri] += 1; + } + } + // Sanity check that the end of the closest bucket is the same as + // our total count of active splats (not at infinity). + if (depthArray16[0] !== activeSplats) { + throw new Error( + `Expected ${activeSplats} active splats but got ${depthArray16[0]}`, + ); + } + + return { activeSplats, ordering }; +} diff --git a/src/worker/worker.ts b/src/worker/worker.ts new file mode 100644 index 0000000..2f8841b --- /dev/null +++ b/src/worker/worker.ts @@ -0,0 +1,222 @@ +import init_wasm, { sort_splats, sort32_splats } from "spark-internal-rs"; +import { type TransformRange, WASM_SPLAT_SORT } from "../defines"; +import { type UnpackResult, createSplatEncoder } from "../encoding/encoder"; +import { unpackAntiSplat } from "../formats/antisplat"; +import { unpackKsplat } from "../formats/ksplat"; +import { + type PcSogsJson, + type PcSogsV2Json, + unpackPcSogs, + unpackPcSogsZip, +} from "../formats/pcsogs"; +import { unpackPly } from "../formats/ply"; +import { unpackSpz } from "../formats/spz"; +import { getArrayBuffers } from "../utils"; +import { sort32Splats, sortDoubleSplats, sortSplatsCpu } from "./sort"; + +type RpcMethod = { args: Args; result: Result }; + +type DecodeArgs = { + fileBytes: Uint8Array; + extraFiles?: Record; + encoder: string; + encoderOptions?: Record; +}; +type SortArgs = { + numSplats: number; + maxSplats: number; + readback: Readback; + ordering: Uint32Array; +}; +type SortResult = { + ordering: Uint32Array; + readback: Readback; + activeSplats: number; +}; + +export type RpcMethods = { + decodePly: RpcMethod; + decodeSpz: RpcMethod; + decodeAntiSplat: RpcMethod; + decodeKsplat: RpcMethod; + decodePcSogs: RpcMethod; + decodePcSogsZip: RpcMethod; + sortDoubleSplats: RpcMethod< + SortArgs>, + SortResult> + >; + sort32Splats: RpcMethod< + SortArgs>, + SortResult> + >; + sortSplatsCpu: RpcMethod< + { + centers?: Float32Array; + transforms: Array; + viewOrigin: [number, number, number]; + viewDir: [number, number, number]; + ordering: Uint32Array; + }, + { + ordering: Uint32Array; + activeSplats: number; + } + >; +}; + +type RpcMessageEvent = MessageEvent< + { + [Method in keyof RpcMethods]: { + name: Method; + args: RpcMethods[Method]["args"]; + id: string; + }; + }[keyof RpcMethods] +>; + +// Worker local storage of splat centers for sorting +let splatCenters = new Float32Array(); + +/** + * WebWorker for Spark's background CPU tasks, such as Gsplat file decoding + * and sorting. + */ +async function onMessage(event: RpcMessageEvent) { + // Unpack RPC function name, arguments, and ID from the main thread. + const { name, args, id } = event.data; + // console.log(`worker.onMessage(${id}, ${name}):`, args); + + // Initialize return result/error, to be filled out below. + let result = undefined; + let error = undefined; + + try { + if (name === "sortSplatsCpu") { + // Check if new centers are provided + if (args.centers) { + splatCenters = args.centers; + } + + result = { + id, + ...sortSplatsCpu( + splatCenters, + args.transforms, + args.viewOrigin, + args.viewDir, + args.ordering, + ), + }; + } else if (name === "sortDoubleSplats") { + // Sort numSplats splats using the readback distance metric, which encodes + // one float16 per splat (no unused high bytes like for sortSplats). + const { numSplats, readback, ordering } = args; + if (WASM_SPLAT_SORT) { + result = { + id, + readback, + ordering, + activeSplats: sort_splats(numSplats, readback, ordering), + }; + } else { + result = { + id, + readback, + ...sortDoubleSplats({ numSplats, readback, ordering }), + }; + } + } else if (name === "sort32Splats") { + const { maxSplats, numSplats, readback, ordering } = args; + // Benchmark sort + // benchmarkSort(numSplats, readback, ordering); + if (WASM_SPLAT_SORT) { + result = { + id, + readback, + ordering, + activeSplats: sort32_splats(numSplats, readback, ordering), + }; + } else { + result = { + id, + readback, + ...sort32Splats({ maxSplats, numSplats, readback, ordering }), + }; + } + } else if (name.startsWith("decode")) { + // All decodeXyz functions follow the same signature + const { fileBytes, extraFiles, encoder, encoderOptions } = args; + const splatEncoder = createSplatEncoder(encoder, encoderOptions); + + let decoded: UnpackResult; + switch (name) { + case "decodePly": + decoded = await unpackPly(fileBytes, splatEncoder); + break; + case "decodeSpz": + decoded = await unpackSpz(fileBytes, splatEncoder); + break; + case "decodeAntiSplat": + decoded = unpackAntiSplat(fileBytes, splatEncoder); + break; + case "decodeKsplat": + decoded = unpackKsplat(fileBytes, splatEncoder); + break; + case "decodePcSogs": { + const json = JSON.parse(new TextDecoder().decode(fileBytes)) as + | PcSogsJson + | PcSogsV2Json; + decoded = await unpackPcSogs(json, extraFiles ?? {}, splatEncoder); + break; + } + case "decodePcSogsZip": + decoded = await unpackPcSogsZip(fileBytes, splatEncoder); + break; + default: + throw new Error(`Unknown decode name: ${name}`); + } + result = { + id, + numSplats: decoded.numSplats, + unpacked: decoded.unpacked, + }; + } else { + throw new Error(`Unknown name: ${name}`); + } + } catch (e) { + error = e; + console.error(error); + } + + // Send the result or error back to the main thread, making sure to transfer any ArrayBuffers + self.postMessage( + { id, result, error }, + { transfer: getArrayBuffers(result) }, + ); +} + +// Buffer to queue any messages received while initializing, for example +// early messages to unpack a Gsplat file while still initializing the WASM code. +const messageBuffer: MessageEvent[] = []; + +function bufferMessage(event: MessageEvent) { + messageBuffer.push(event); +} + +async function initialize() { + // Hold any messages received while initializing + self.addEventListener("message", bufferMessage); + + await init_wasm(); + + self.removeEventListener("message", bufferMessage); + self.addEventListener("message", onMessage); + + // Process any buffered messages + for (const event of messageBuffer) { + onMessage(event); + } + messageBuffer.length = 0; +} + +initialize().catch(console.error); diff --git a/tsconfig.json b/tsconfig.json index a8639b3..1b89feb 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,8 +1,8 @@ { "compilerOptions": { - "target": "es2020", + "target": "es2024", "module": "es2020", - "lib": ["ES2020", "DOM"], + "lib": ["ES2024", "DOM"], "esModuleInterop": true, "forceConsistentCasingInFileNames": true, "strict": true,