From 5a88636e2ab71847e0818f098e30fffe6728dd22 Mon Sep 17 00:00:00 2001 From: anbang <343542678@qq.com> Date: Thu, 16 Sep 2021 09:40:43 +0800 Subject: [PATCH] fix(webgl): fix ops bugs --- .../paddlejs-backend-webgl/src/ops/index.ts | 3 +- .../src/ops/shader/greater_than.ts | 9 ++---- .../src/ops/shader/reduce_mean.ts | 28 +++++++++++++------ .../src/ops/shader/reduce_sum.ts | 26 +++++++++++------ .../src/ops/shader/where.ts | 7 ++--- .../src/opFactory/opDataBuilder.ts | 14 ++++++++-- 6 files changed, 56 insertions(+), 31 deletions(-) diff --git a/packages/paddlejs-backend-webgl/src/ops/index.ts b/packages/paddlejs-backend-webgl/src/ops/index.ts index 40db7ba8..b477f073 100644 --- a/packages/paddlejs-backend-webgl/src/ops/index.ts +++ b/packages/paddlejs-backend-webgl/src/ops/index.ts @@ -113,7 +113,8 @@ const ops = { pack_out, nhwc_2_nchw, feedPost, - imgFeed + imgFeed, + 'conv2d-elementwise_add-leaky_relu': conv2d_elementwise_add }; export { ops diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts b/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts index d6a63896..2c2ca25f 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/greater_than.ts @@ -3,19 +3,16 @@ * @file greater_than return x >= y */ -function mainFunc( - {}, - {} -) { +function mainFunc() { return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 - float x = getValueFromTensorPos_input(oPos.r, oPos.g, oPos.b, oPos.a); + float x = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); float y = getValueFromTensorPos_counter(oPos.r, oPos.g, oPos.b, oPos.a); - setOutput(bool(x >= y)); + setOutput(float(bool(x >= y))); } `; } diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts index 17cf2a22..105984f1 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_mean.ts @@ -4,20 +4,32 @@ */ function mainFunc( - {}, - { inputs_dim, dim } + { origin }, + { dim } ) { + const { total_shape, height_shape, width_shape, channel } = origin; + const batch_shape = total_shape / (width_shape * height_shape * channel); + const shape = [batch_shape, channel, height_shape, width_shape]; + let codeStr = ''; + for (let i = 0; i < dim.length; i++) { + for (let j = 0; j < shape[dim[i]]; j++) { + codeStr += ` + oPos[${dim[i]}] = ${j}; + o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); + `; + if (j === shape[dim[i]]) { + codeStr += `o / float(${j});`; + } + } + } + return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 float o = 0.0; - for (int i = 0; i < ${inputs_dim}; i++) { - oPos[${dim}] = i; - o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); - } - o = o / float(${inputs_dim}); + ${codeStr} setOutput(o); } `; @@ -32,6 +44,6 @@ export default { origin: ['getValueFromTensorPos'] }, behaviors: [ - 'normalizeDim' + ] }; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts index 87eed9e8..92c48d72 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/reduce_sum.ts @@ -4,19 +4,29 @@ */ function mainFunc( - {}, - { inputs_dim, dim } + { origin }, + { dim } ) { + const { total_shape, height_shape, width_shape, channel } = origin; + const batch_shape = total_shape / (width_shape * height_shape * channel); + const shape = [batch_shape, channel, height_shape, width_shape]; + let codeStr = ''; + for (let i = 0; i < dim.length; i++) { + for (let j = 0; j < shape[dim[i]]; j++) { + codeStr += ` + oPos[${dim[i]}] = ${j}; + o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); + `; + } + } + return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 float o = 0.0; - for (int i = 0; i < ${inputs_dim}; i++) { - oPos[${dim}] = i; - o += getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);; - } + ${codeStr} setOutput(float(o)); } `; @@ -30,7 +40,5 @@ export default { textureFuncConf: { origin: ['getValueFromTensorPos'] }, - behaviors: [ - 'normalizeDim' - ] + behaviors: [] }; diff --git a/packages/paddlejs-backend-webgl/src/ops/shader/where.ts b/packages/paddlejs-backend-webgl/src/ops/shader/where.ts index da3c26e8..befa0090 100644 --- a/packages/paddlejs-backend-webgl/src/ops/shader/where.ts +++ b/packages/paddlejs-backend-webgl/src/ops/shader/where.ts @@ -3,16 +3,13 @@ * @file where return condition ? x : y */ -function mainFunc( - {}, - {} -) { +function mainFunc() { return ` // start函数 void main(void) { ivec4 oPos = getOutputTensorPos(); // 输出坐标转换为输入坐标 - float x = getValueFromTensorPos_input(oPos.r, oPos.g, oPos.b, oPos.a); + float x = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a); float y = getValueFromTensorPos_counter(oPos.r, oPos.g, oPos.b, oPos.a); float condition = getValueFromTensorPos_condition(oPos.r, oPos.g, oPos.b, oPos.a); float o = 0.0; diff --git a/packages/paddlejs-core/src/opFactory/opDataBuilder.ts b/packages/paddlejs-core/src/opFactory/opDataBuilder.ts index 266ff5da..893cce61 100644 --- a/packages/paddlejs-core/src/opFactory/opDataBuilder.ts +++ b/packages/paddlejs-core/src/opFactory/opDataBuilder.ts @@ -183,14 +183,24 @@ export default class OpData { this.name = 'conv2d_elementwise_add'; } - if (this.name.indexOf('flatten2') > -1) { + else if (this.name.indexOf('flatten2') > -1) { this.name = 'reshape2'; } - if (this.name.indexOf('max_pool2d_with_index') > -1) { + else if (this.name.indexOf('max_pool2d_with_index') > -1) { this.name = 'pool2d_max'; } + else if (this.name.indexOf('instance_norm') > -1 || this.name.indexOf('sync_batch_norm') > -1) { + this.name = 'batchnorm'; + } + else if (this.name.indexOf('bilinear_interp_v2') > -1) { + this.name = 'bilinear_interp'; + } + else if (this.name.indexOf('leaky_relu') > -1) { + this.name = 'conv2d_elementwise_add'; + } + const tensorData: ModelVar[] = this.tensorData; // unique behavior const opKey = `${GLOBALS.backend}_${this.name}`;