From b63bca64a6d8fa6f1615d411eb90f3d8cc20ad27 Mon Sep 17 00:00:00 2001 From: Mei Date: Wed, 22 Nov 2023 16:00:01 +0800 Subject: [PATCH 1/3] add expand.js and expand_test.js --- src/expand.js | 14 +++++++++++ test/expand_test.js | 58 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 src/expand.js create mode 100644 test/expand_test.js diff --git a/src/expand.js b/src/expand.js new file mode 100644 index 0000000..3a189fb --- /dev/null +++ b/src/expand.js @@ -0,0 +1,14 @@ +'use strict'; + +import {broadcast} from './lib/broadcast.js'; + +/** + * Expand any dimension of size 1 of the input tensor to a + * larger size according to the new shape. + * @param {Tensor} input + * @param {Array} newShape + * @return {Tensor} + */ +export function expand(input, newShape) { + return broadcast(input, newShape); +} diff --git a/test/expand_test.js b/test/expand_test.js new file mode 100644 index 0000000..200eef9 --- /dev/null +++ b/test/expand_test.js @@ -0,0 +1,58 @@ +'use strict'; + +import {expand} from '../src/expand.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test expand', function() { + function testExpand(input, newShape, expected) { + const tensor = new Tensor(input.shape, input.data); + const outputTensor = expand(tensor, newShape); + utils.checkShape(outputTensor, expected.shape); + utils.checkValue(outputTensor, expected.data); + } + + it('expand changed dimensions', function() { + const input = { + shape: [2, 3], + data: [ + 1, 1, 0, + 0, 1, 0, + ], + }; + + const newShape = [2, 2, 3]; + + const expected = { + shape: [2, 2, 3], + data: [ + 1, 1, 0, + 0, 1, 0, + 1, 1, 0, + 0, 1, 0, + ], + }; + testExpand(input, newShape, expected); + }); + + it('expand unchanged dimensions', function() { + const input = { + shape: [3, 1], + data: [ + 1, 2, 3, + ], + }; + + const newShape = [3, 4]; + + const expected = { + shape: [3, 4], + data: [ + 1, 1, 1, 1, + 2, 2, 2, 2, + 3, 3, 3, 3, + ], + }; + testExpand(input, newShape, expected); + }); +}); From 3feda156da756711398c55af6f64847780de8f43 Mon Sep 17 00:00:00 2001 From: Mei Date: Mon, 27 Nov 2023 11:14:49 +0800 Subject: [PATCH 2/3] modified expand.js and add more tests --- src/expand.js | 13 +++++-- test/expand_test.js | 88 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 88 insertions(+), 13 deletions(-) diff --git a/src/expand.js b/src/expand.js index 3a189fb..cd3759a 100644 --- a/src/expand.js +++ b/src/expand.js @@ -1,6 +1,7 @@ 'use strict'; -import {broadcast} from './lib/broadcast.js'; +import {broadcast, getBroadcastShape} from './lib/broadcast.js'; +import {Tensor} from '../src/lib/tensor.js'; /** * Expand any dimension of size 1 of the input tensor to a @@ -10,5 +11,13 @@ import {broadcast} from './lib/broadcast.js'; * @return {Tensor} */ export function expand(input, newShape) { - return broadcast(input, newShape); + if (input.shape.length === 0) { + const inputReshape = new Tensor([1], input.data); + const outputShape = getBroadcastShape(inputReshape.shape, newShape); + return broadcast(inputReshape, outputShape); + } else { + const inputReshape = new Tensor(input.shape, input.data); + const outputShape = getBroadcastShape(inputReshape.shape, newShape); + return broadcast(inputReshape, outputShape); + } } diff --git a/test/expand_test.js b/test/expand_test.js index 200eef9..9e66610 100644 --- a/test/expand_test.js +++ b/test/expand_test.js @@ -12,30 +12,55 @@ describe('test expand', function() { utils.checkValue(outputTensor, expected.data); } - it('expand changed dimensions', function() { + it('expand a 3D input with a 4D newShape to a 4D output.', function() { const input = { - shape: [2, 3], + shape: [2, 1, 4], data: [ - 1, 1, 0, - 0, 1, 0, + 1, 2, 3, 4, 5, 6, 7, 8, ], }; - const newShape = [2, 2, 3]; + const newShape = [5, 1, 3, 4]; const expected = { - shape: [2, 2, 3], + shape: [5, 2, 3, 4], data: [ - 1, 1, 0, - 0, 1, 0, - 1, 1, 0, - 0, 1, 0, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, ], }; testExpand(input, newShape, expected); }); - it('expand unchanged dimensions', function() { + it('expand a 3D input with a 2D newShape to a 3D output', function() { + const input = { + shape: [2, 1, 4], + data: [ + 1, 2, 3, 4, 5, 6, 7, 8, + ], + }; + + const newShape = [3, 1]; + + const expected = { + shape: [2, 3, 4], + data: [ + 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, + 5, 6, 7, 8, 5, 6, 7, 8, 5, 6, 7, 8, + ], + }; + testExpand(input, newShape, expected); + }); + + it('expand a 2D input with a 2D newShape to a 2D output', function() { const input = { shape: [3, 1], data: [ @@ -55,4 +80,45 @@ describe('test expand', function() { }; testExpand(input, newShape, expected); }); + + it('expand a 0D input with a 2D newShape to a 2D output.', function() { + const input = { + shape: [], + data: [ + 6, + ], + }; + + const newShape = [2, 3]; + + const expected = { + shape: [2, 3], + data: [ + 6, 6, 6, + 6, 6, 6, + ], + }; + testExpand(input, newShape, expected); + }); + + it('expand a 2D input with a 0D newShape to a 2D output.', function() { + const input = { + shape: [2, 3], + data: [ + 1, 2, 3, + 4, 5, 6, + ], + }; + + const newShape = []; + + const expected = { + shape: [2, 3], + data: [ + 1, 2, 3, + 4, 5, 6, + ], + }; + testExpand(input, newShape, expected); + }); }); From 6026dba5310ab0a4909b7cc30333f887ed1aa8ea Mon Sep 17 00:00:00 2001 From: Mei Date: Mon, 27 Nov 2023 15:45:51 +0800 Subject: [PATCH 3/3] modified expand.js --- src/expand.js | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/expand.js b/src/expand.js index cd3759a..2319adf 100644 --- a/src/expand.js +++ b/src/expand.js @@ -1,7 +1,7 @@ 'use strict'; import {broadcast, getBroadcastShape} from './lib/broadcast.js'; -import {Tensor} from '../src/lib/tensor.js'; +import {Scalar} from '../src/lib/tensor.js'; /** * Expand any dimension of size 1 of the input tensor to a @@ -10,14 +10,10 @@ import {Tensor} from '../src/lib/tensor.js'; * @param {Array} newShape * @return {Tensor} */ + + export function expand(input, newShape) { - if (input.shape.length === 0) { - const inputReshape = new Tensor([1], input.data); - const outputShape = getBroadcastShape(inputReshape.shape, newShape); - return broadcast(inputReshape, outputShape); - } else { - const inputReshape = new Tensor(input.shape, input.data); - const outputShape = getBroadcastShape(inputReshape.shape, newShape); - return broadcast(inputReshape, outputShape); - } + const inputReshape = input.shape.length === 0 ? new Scalar(input.data) : input; + const outputShape = getBroadcastShape(inputReshape.shape, newShape); + return broadcast(inputReshape, outputShape); }