Skip to content

Commit

Permalink
[js/webgpu] Add GatherND
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 15, 2024
1 parent 632a36a commit 7ea4c2f
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 0 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Do not modify directly.*
| Gather | ai.onnx(1-10,11-12,13+) | |
| GatherBlockQuantized | com.microsoft(1+) | |
| GatherElements | ai.onnx(11-12,13+) | |
| GatherND | ai.onnx(11,12,13+) | |
| Gelu | ai.onnx(20+); com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { einsum, parseEinsumAttributes } from './ops/einsum';
import { expand } from './ops/expand';
import { fastGelu } from './ops/fast-gelu';
import { gather, parseGatherAttributes } from './ops/gather';
import { gatherND, parseGatherNDAttributes } from './ops/gather-nd';
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
import { gemm, parseGemmAttributes } from './ops/gemm';
Expand Down Expand Up @@ -100,6 +101,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]],
['GatherND', [gatherND, parseGatherNDAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
Expand Down
179 changes: 179 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramUniform } from '../types';

import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';

export interface GatherNDAttributes extends AttributeWithCacheKey {
readonly batchDims: number;
}

const computeSliceOffsets = (
context: ComputeContext,
indicesData: TensorView,
sizesFromSliceDimsData: number[],
batchDims: number,
inputDims: readonly number[],
numSlices: number,
numSlicesPerBatch: number,
inputBatchStride: number,
numSliceDims: number,
) => {
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: numSlices },
{ type: DataType.uint32, data: batchDims },
{ type: DataType.uint32, data: inputDims },
{ type: DataType.uint32, data: sizesFromSliceDimsData },
{ type: DataType.uint32, data: numSlicesPerBatch },
{ type: DataType.uint32, data: inputBatchStride },
{ type: DataType.uint32, data: numSliceDims },
];

const outputShape = [numSlices];
programUniforms.push(...createTensorShapeVariables(indicesData.dims, outputShape));

const getShaderSource = (shaderHelper: ShaderHelper) => {
const indices = inputVariable('indices_data', indicesData.dataType, indicesData.dims.length);
const output = outputVariable('input_slice_offsets_data', indicesData.dataType, 1, 1);
const variables = [indices, output];
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'batch_dims', type: 'u32' },
{ name: 'input_dims', type: 'u32', length: inputDims.length },
{ name: 'sizes_from_slice_dims_data', type: 'u32', length: sizesFromSliceDimsData.length },
{ name: 'num_slices_per_batch', type: 'u32' },
{ name: 'input_batch_stride', type: 'u32' },
{ name: 'num_slice_dims', type: 'u32' },
];
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let batch_idx = global_idx / uniforms.num_slices_per_batch;
let base_offset = batch_idx * uniforms.input_batch_stride;
let slice_indices_base_offset = global_idx * uniforms.num_slice_dims;
var relative_slice_offset = 0;
for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) {
var index = i32(indices_data[dim_idx + slice_indices_base_offset].x);
let input_dim_idx = uniforms.batch_dims + dim_idx;
if (index < 0) {
${
inputDims.length === 1
? 'index += i32(uniforms.input_dims);'
: 'index += i32(uniforms.input_dims[input_dim_idx]);'
}
}
${
sizesFromSliceDimsData.length === 1
? 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);'
: 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);'
}
}
input_slice_offsets_data[global_idx].x = base_offset + u32(relative_slice_offset);
}`;
};

return context.compute(
{
name: 'computeSliceOffsets',
shaderCache: { hint: `${inputDims.length}_${sizesFromSliceDimsData.length}`, inputDependencies: ['rank'] },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: context.inputs[1].dataType }],
dispatchGroup: { x: Math.ceil(numSlices / 64) },
programUniforms,
}),
getShaderSource,
},
{ inputs: [indicesData], outputs: [-1] },
)[0];
};

export const gatherND = (context: ComputeContext, attributes: GatherNDAttributes) => {
const inputs = context.inputs;
const inputShape = inputs[0].dims;
const inputType = inputs[0].dataType;
const indicesShape = inputs[1].dims;
const numSliceDims = indicesShape[indicesShape.length - 1];
const numSlices = ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1);
const sliceSize = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims + numSliceDims);
const numBatches = ShapeUtil.sizeToDimension(inputShape, attributes.batchDims);
const inputBatchStride = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims);
const numSlicesPerBatch = numSlices / numBatches;
const sizesFromSliceDims = new Array(numSliceDims);
let runningProduct = sliceSize;
for (let i = 0; i < numSliceDims; ++i) {
sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct;
runningProduct *= inputShape[attributes.batchDims + numSliceDims - 1 - i];
}

const inputSliceOffsets = computeSliceOffsets(
context,
inputs[1],
sizesFromSliceDims,
attributes.batchDims,
inputShape,
numSlices,
numSlicesPerBatch,
inputBatchStride,
numSliceDims,
);

const lastIndicesDimension = attributes.batchDims + numSliceDims;
if (lastIndicesDimension > inputShape.length) {
throw new Error('last dimension of indices must not be larger than rank of input tensor');
}

const outputShape = indicesShape.slice(0, -1).concat(inputShape.slice(lastIndicesDimension));
const outputSize = ShapeUtil.size(outputShape);

const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: sliceSize },
...createTensorShapeVariables(inputs[0].dims, inputSliceOffsets.dims, outputShape),
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
const input = inputVariable('data', inputs[0].dataType, inputs[0].dims.length);
const indices = inputVariable('slice_offsets', inputSliceOffsets.dataType, inputSliceOffsets.dims.length);

const output = outputVariable('output', inputs[0].dataType, outputShape.length);
return `
${shaderHelper
.registerUniform('output_size', 'u32')
.registerUniform('slice_size', 'u32')
.declareVariables(input, indices, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let slice_offset = slice_offsets[global_idx / uniforms.slice_size].x;
output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size];
}`;
};
context.compute(
{
name: 'GatherND',
shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
getShaderSource,
},
{ inputs: [inputs[0], inputSliceOffsets] },
);
};

export const parseGatherNDAttributes = (attributes: Record<string, unknown>): GatherNDAttributes => {
const batchDims = attributes.batch_dims as number;
return {
batchDims,
cacheKey: ``,
};
};
147 changes: 147 additions & 0 deletions js/web/test/data/ops/gathernd.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
[
{
"name": "GatherND int32",
"operator": "GatherND",
"attributes": [],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [100, 101, 102, 777, 778, 779, 1000, 1001, 1002],
"dims": [9],
"type": "int32"
},
{
"data": [0, 4, 8],
"dims": [3, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [100, 778, 1002],
"dims": [3],
"type": "int32"
}
]
}
]
},
{
"name": "GatherND float32",
"operator": "GatherND",
"attributes": [],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9],
"dims": [9],
"type": "float32"
},
{
"data": [0, 4, 8],
"dims": [3, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [100.0999984741211, 778.5, 1002.9000244140625],
"dims": [3],
"type": "float32"
}
]
}
]
},
{
"name": "GatherND int32 [2 2 2], batch_dims",
"operator": "GatherND",
"attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7],
"dims": [2, 2, 2],
"type": "int32"
},
{
"data": [1, 0],
"dims": [2, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [2, 3, 4, 5],
"dims": [2, 2],
"type": "int32"
}
]
}
]
},
{
"name": "GatherND float16",
"operator": "GatherND",
"attributes": [],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9],
"dims": [9],
"type": "float16"
},
{
"data": [0, 4, 8],
"dims": [3, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [100.0999984741211, 778.5, 1002.9000244140625],
"dims": [3],
"type": "float16"
}
]
}
]
},
{
"name": "GatherND uint32 [2 2 2], batch_dims",
"operator": "GatherND",
"attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7],
"dims": [2, 2, 2],
"type": "uint32"
},
{
"data": [1, 0],
"dims": [2, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [2, 3, 4, 5],
"dims": [2, 2],
"type": "uint32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,7 @@
"gather.jsonc",
"gather-block-quantized.jsonc",
"gather-elements.jsonc",
"gathernd.jsonc",
"gemm.jsonc",
"global-average-pool.jsonc",
"greater.jsonc",
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
Expand Down Expand Up @@ -667,6 +671,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
Expand Down
Loading

0 comments on commit 7ea4c2f

Please sign in to comment.