From 7ea4c2f38375e2f317e9a6135e8a72c16d726a0d Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 24 Oct 2024 14:28:44 +0800 Subject: [PATCH] [js/webgpu] Add GatherND --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts | 179 ++++++++++++++++++ js/web/test/data/ops/gathernd.jsonc | 147 ++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + .../providers/js/js_execution_provider.cc | 8 + .../core/providers/js/operators/gather_nd.cc | 43 +++++ .../core/providers/js/operators/gather_nd.h | 24 +++ 8 files changed, 405 insertions(+) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts create mode 100644 js/web/test/data/ops/gathernd.jsonc create mode 100644 onnxruntime/core/providers/js/operators/gather_nd.cc create mode 100644 onnxruntime/core/providers/js/operators/gather_nd.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index f63cf17aa4df3..5c8748d75c2bc 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -50,6 +50,7 @@ Do not modify directly.* | Gather | ai.onnx(1-10,11-12,13+) | | | GatherBlockQuantized | com.microsoft(1+) | | | GatherElements | ai.onnx(11-12,13+) | | +| GatherND | ai.onnx(11,12,13+) | | | Gelu | ai.onnx(20+); com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 28af5d461abe0..6c7afbc7365bb 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,6 +16,7 @@ import { einsum, parseEinsumAttributes } from './ops/einsum'; import { expand } from './ops/expand'; import { fastGelu } from './ops/fast-gelu'; import { gather, parseGatherAttributes } from './ops/gather'; +import { gatherND, parseGatherNDAttributes } from './ops/gather-nd'; import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized'; import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; import { gemm, parseGemmAttributes } from './ops/gemm'; @@ -100,6 +101,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]], + ['GatherND', [gatherND, parseGatherNDAttributes]], ['Gelu', [unaryOps.gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts new file mode 100644 index 0000000000000..09b1d67713cf3 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramUniform } from '../types'; + +import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; + +export interface GatherNDAttributes extends AttributeWithCacheKey { + readonly batchDims: number; +} + +const computeSliceOffsets = ( + context: ComputeContext, + indicesData: TensorView, + sizesFromSliceDimsData: number[], + batchDims: number, + inputDims: readonly number[], + numSlices: number, + numSlicesPerBatch: number, + inputBatchStride: number, + numSliceDims: number, +) => { + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: numSlices }, + { type: DataType.uint32, data: batchDims }, + { type: DataType.uint32, data: inputDims }, + { type: DataType.uint32, data: sizesFromSliceDimsData }, + { type: DataType.uint32, data: numSlicesPerBatch }, + { type: DataType.uint32, data: inputBatchStride }, + { type: DataType.uint32, data: numSliceDims }, + ]; + + const outputShape = [numSlices]; + programUniforms.push(...createTensorShapeVariables(indicesData.dims, outputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const indices = inputVariable('indices_data', indicesData.dataType, indicesData.dims.length); + const output = outputVariable('input_slice_offsets_data', indicesData.dataType, 1, 1); + const variables = [indices, output]; + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'batch_dims', type: 'u32' }, + { name: 'input_dims', type: 'u32', length: inputDims.length }, + { name: 'sizes_from_slice_dims_data', type: 'u32', length: sizesFromSliceDimsData.length }, + { name: 'num_slices_per_batch', type: 'u32' }, + { name: 'input_batch_stride', type: 'u32' }, + { name: 'num_slice_dims', type: 'u32' }, + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let batch_idx = global_idx / uniforms.num_slices_per_batch; + let base_offset = batch_idx * uniforms.input_batch_stride; + + let slice_indices_base_offset = global_idx * uniforms.num_slice_dims; + var relative_slice_offset = 0; + for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) { + var index = i32(indices_data[dim_idx + slice_indices_base_offset].x); + let input_dim_idx = uniforms.batch_dims + dim_idx; + if (index < 0) { + ${ + inputDims.length === 1 + ? 'index += i32(uniforms.input_dims);' + : 'index += i32(uniforms.input_dims[input_dim_idx]);' + } + } + ${ + sizesFromSliceDimsData.length === 1 + ? 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);' + : 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);' + } + } + + input_slice_offsets_data[global_idx].x = base_offset + u32(relative_slice_offset); + }`; + }; + + return context.compute( + { + name: 'computeSliceOffsets', + shaderCache: { hint: `${inputDims.length}_${sizesFromSliceDimsData.length}`, inputDependencies: ['rank'] }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: context.inputs[1].dataType }], + dispatchGroup: { x: Math.ceil(numSlices / 64) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [indicesData], outputs: [-1] }, + )[0]; +}; + +export const gatherND = (context: ComputeContext, attributes: GatherNDAttributes) => { + const inputs = context.inputs; + const inputShape = inputs[0].dims; + const inputType = inputs[0].dataType; + const indicesShape = inputs[1].dims; + const numSliceDims = indicesShape[indicesShape.length - 1]; + const numSlices = ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1); + const sliceSize = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims + numSliceDims); + const numBatches = ShapeUtil.sizeToDimension(inputShape, attributes.batchDims); + const inputBatchStride = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims); + const numSlicesPerBatch = numSlices / numBatches; + const sizesFromSliceDims = new Array(numSliceDims); + let runningProduct = sliceSize; + for (let i = 0; i < numSliceDims; ++i) { + sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct; + runningProduct *= inputShape[attributes.batchDims + numSliceDims - 1 - i]; + } + + const inputSliceOffsets = computeSliceOffsets( + context, + inputs[1], + sizesFromSliceDims, + attributes.batchDims, + inputShape, + numSlices, + numSlicesPerBatch, + inputBatchStride, + numSliceDims, + ); + + const lastIndicesDimension = attributes.batchDims + numSliceDims; + if (lastIndicesDimension > inputShape.length) { + throw new Error('last dimension of indices must not be larger than rank of input tensor'); + } + + const outputShape = indicesShape.slice(0, -1).concat(inputShape.slice(lastIndicesDimension)); + const outputSize = ShapeUtil.size(outputShape); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: sliceSize }, + ...createTensorShapeVariables(inputs[0].dims, inputSliceOffsets.dims, outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const input = inputVariable('data', inputs[0].dataType, inputs[0].dims.length); + const indices = inputVariable('slice_offsets', inputSliceOffsets.dataType, inputSliceOffsets.dims.length); + + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + return ` + ${shaderHelper + .registerUniform('output_size', 'u32') + .registerUniform('slice_size', 'u32') + .declareVariables(input, indices, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let slice_offset = slice_offsets[global_idx / uniforms.slice_size].x; + output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size]; + }`; + }; + context.compute( + { + name: 'GatherND', + shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }, + { inputs: [inputs[0], inputSliceOffsets] }, + ); +}; + +export const parseGatherNDAttributes = (attributes: Record): GatherNDAttributes => { + const batchDims = attributes.batch_dims as number; + return { + batchDims, + cacheKey: ``, + }; +}; diff --git a/js/web/test/data/ops/gathernd.jsonc b/js/web/test/data/ops/gathernd.jsonc new file mode 100644 index 0000000000000..209c7d1f74087 --- /dev/null +++ b/js/web/test/data/ops/gathernd.jsonc @@ -0,0 +1,147 @@ +[ + { + "name": "GatherND int32", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100, 101, 102, 777, 778, 779, 1000, 1001, 1002], + "dims": [9], + "type": "int32" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100, 778, 1002], + "dims": [3], + "type": "int32" + } + ] + } + ] + }, + { + "name": "GatherND float32", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9], + "dims": [9], + "type": "float32" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100.0999984741211, 778.5, 1002.9000244140625], + "dims": [3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherND int32 [2 2 2], batch_dims", + "operator": "GatherND", + "attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 2, 2], + "type": "int32" + }, + { + "data": [1, 0], + "dims": [2, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [2, 3, 4, 5], + "dims": [2, 2], + "type": "int32" + } + ] + } + ] + }, + { + "name": "GatherND float16", + "operator": "GatherND", + "attributes": [], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9], + "dims": [9], + "type": "float16" + }, + { + "data": [0, 4, 8], + "dims": [3, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [100.0999984741211, 778.5, 1002.9000244140625], + "dims": [3], + "type": "float16" + } + ] + } + ] + }, + { + "name": "GatherND uint32 [2 2 2], batch_dims", + "operator": "GatherND", + "attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }], + "cases": [ + { + "name": "data[4] indices[]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [2, 2, 2], + "type": "uint32" + }, + { + "data": [1, 0], + "dims": [2, 1], + "type": "int64" + } + ], + "outputs": [ + { + "data": [2, 3, 4, 5], + "dims": [2, 2], + "type": "uint32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 45fb771ee13bb..03dd84fae89c5 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1365,6 +1365,7 @@ "gather.jsonc", "gather-block-quantized.jsonc", "gather-elements.jsonc", + "gathernd.jsonc", "gemm.jsonc", "global-average-pool.jsonc", "greater.jsonc", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c0d62bf47a0dd..3991885d92ce3 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -341,6 +341,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, GatherND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, GatherND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice); @@ -667,6 +671,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/gather_nd.cc b/onnxruntime/core/providers/js/operators/gather_nd.cc new file mode 100644 index 0000000000000..70a7c83c34527 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_nd.cc @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/js/js_data_types.h" +#include "gather_nd.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + GatherND, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("indices", BuildKernelDefConstraintsFromTypeList>()), + GatherND); + +ONNX_OPERATOR_KERNEL_EX( + GatherND, + kOnnxDomain, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("indices", BuildKernelDefConstraintsFromTypeList>()), + GatherND); + +ONNX_OPERATOR_KERNEL_EX( + GatherND, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("indices", BuildKernelDefConstraintsFromTypeList>()), + GatherND); + + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/gather_nd.h b/onnxruntime/core/providers/js/operators/gather_nd.h new file mode 100644 index 0000000000000..cdf7a52630dad --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_nd.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class GatherND : public JsKernel { + public: + GatherND(const OpKernelInfo& info) : JsKernel(info) { + int64_t batchDims = info.GetAttrOrDefault("batch_dims", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(GatherND, ({ + "batch_dims" : Number($1), + }), + static_cast(batchDims)); + } +}; + +} // namespace js +} // namespace onnxruntime