[Relax] support masked_scatter #17525

Open · wants to merge 2 commits into main
26 changes: 26 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -472,6 +472,31 @@ def _masked_fill(self, node: fx.Node) -> relax.Var:
        values = self.block_builder.emit(relax.op.full_like(x, rx_value))
        return self.block_builder.emit(relax.op.where(mask, values, x))

    def _masked_scatter(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        mask = self.env[node.args[1]]
        source = self.env[node.args[2]]
        ndim = len(mask.struct_info.shape)
        if ndim == 1:
            index = self.block_builder.emit(relax.op.cumsum(mask, 0, dtype="int32"))
            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
            gathered_source = self.block_builder.emit(relax.op.take(source, index, axis=0))
        else:
            f_mask = self.block_builder.emit(relax.op.reshape(mask, [-1]))
            index = self.block_builder.emit(relax.op.cumsum(f_mask, 0, dtype="int32"))
            index = self.block_builder.emit(relax.op.subtract(index, relax.const(1, "int32")))
            source_shape = [-1] + [
                s for idx, s in enumerate(source.struct_info.shape) if idx >= ndim
            ]
            f_source = self.block_builder.emit(relax.op.reshape(source, source_shape))
            gathered_source = self.block_builder.emit(relax.op.take(f_source, index, axis=0))
            gathered_source = self.block_builder.emit(
                relax.op.reshape(gathered_source, x.struct_info.shape)
            )
        if ndim != len(x.struct_info.shape):
            mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
        return self.block_builder.emit(relax.op.where(mask, gathered_source, x))

    def _ones(self, node: fx.Node) -> relax.Var:
        import torch

@@ -695,6 +720,7 @@ def create_convert_map(
"index_select": self._index_select,
"masked_fill_": self._inplace_masked_fill,
"masked_fill": self._masked_fill,
"masked_scatter": self._masked_scatter,
"new_ones": self._new_ones,
"ones": self._ones,
"tensor": self._tensor,
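The new `_masked_scatter` lowers `torch.Tensor.masked_scatter` without a scatter primitive: a running count (`cumsum`) over the boolean mask, minus one, gives each True position the index of the source element that belongs there; `take` gathers those elements and `where` keeps the original values wherever the mask is False. A minimal PyTorch sketch of the same trick for the 1-D case (the helper name and test values below are illustrative, not part of the change):

import torch

def masked_scatter_1d(data, mask, source):
    # cumsum(mask) - 1 maps each True position to the next unused source slot;
    # clamp keeps the dummy -1 indices (before the first True) in bounds, and
    # where() discards the gathered values at those False positions anyway.
    index = (torch.cumsum(mask.to(torch.int32), dim=0) - 1).clamp(min=0)
    gathered = source[index.to(torch.int64)]
    return torch.where(mask, gathered, data)

data = torch.zeros(5)
mask = torch.tensor([True, False, True, False, True])
src = torch.arange(10, dtype=torch.float32)
assert torch.equal(masked_scatter_1d(data, mask, src), data.masked_scatter(mask, src))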
4 changes: 3 additions & 1 deletion src/contrib/msc/core/ir/graph_builder.cc
@@ -704,7 +704,9 @@ const MSCPrim RelaxGraphBuilder::MatchOrCreatePrim(const PrimExpr& prim, const S
}

void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) {
-  AddNode(GetRef<relax::Constant>(op));
+  if (!expr_tensor_map_.count(GetRef<relax::Constant>(op))) {
+    AddNode(GetRef<relax::Constant>(op));
+  }
}

void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding,
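The guard above makes constant registration idempotent: a `relax::Constant` that already has an entry in `expr_tensor_map_` is skipped rather than added to the graph a second time. Schematically, in Python pseudocode of the C++ change (names mirror the C++ members):

def visit_constant(op):
    # skip constants that were already lowered to a graph node
    if op not in expr_tensor_map:
        add_node(op)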
13 changes: 10 additions & 3 deletions src/contrib/msc/core/transform/set_expr_layout.cc
@@ -492,9 +492,16 @@ InferLayoutOutput ForwardInferLayoutTake(const Call& call,
    return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
  }
  if (indices_layout->layout.defined()) {
-    size_t indices_size = indices_layout->layout.ndim();
-    LayoutDecision output_layout =
-        LayoutUtils::ExpandLayout(indices_layout, std::vector<size_t>{indices_size});
+    std::vector<size_t> expand_axes;
+    for (size_t i = indices_layout->layout.ndim(); i < output_shape.size(); i++) {
+      expand_axes.push_back(i);
+    }
+    LayoutDecision output_layout;
+    if (expand_axes.size() == 0) {
+      output_layout = indices_layout;
+    } else {
+      output_layout = LayoutUtils::ExpandLayout(indices_layout, expand_axes);
+    }
    return InferLayoutOutput({input_layout, indices_layout}, {output_layout}, Attrs());
  }
  return InferLayoutOutput();
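For `take` along axis 0, the output rank is the indices rank plus the input's remaining (non-indexed) dimensions, so the old code's fixed one-axis expansion was wrong whenever the ranks already matched — exactly the case the 1-D masked_scatter lowering hits, where input, indices, and output are all rank 1. The new rule pads the indices layout with one fresh axis per trailing output dimension and reuses it unchanged when there is nothing to pad. A small, self-contained sketch of that rule (the helper name is illustrative):

def infer_take_output_layout(indices_layout: str, output_rank: int) -> str:
    # pad the indices layout with a fresh axis name per trailing output dim
    layout = indices_layout
    while len(layout) < output_rank:
        layout += next(c for c in "ABCDEFGH" if c not in layout)
    return layout

assert infer_take_output_layout("A", 1) == "A"   # equal ranks: layout reused as-is
assert infer_take_output_layout("A", 2) == "AB"  # one trailing axis appended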
20 changes: 20 additions & 0 deletions src/contrib/msc/framework/torch/torch_opcode.cc
@@ -224,6 +224,12 @@ class TorchConstantCodeGen : public TorchOpCode {
      } else if (dtype == "float32") {
        stack_.assign(module_ref(), node()->GetTypeAttr<float>("scalar"));
      }
    } else if (dtype == "bool") {
      stack_.func_call("register_buffer", "", "self")
          .call_arg(DocUtils::ToStr(ref_name))
          .inplace_start("torch.BoolTensor")
          .call_arg(DocUtils::ToDocList(node()->OutputAt(0)->shape))
          .inplace_end();
    } else if (dtype == "int32") {
      stack_.func_call("register_buffer", "", "self")
          .call_arg(DocUtils::ToStr(ref_name))
@@ -658,6 +664,18 @@ class TorchStridedSliceCodeGen : public TorchOpCode {
  }
};

class TorchTakeCodeGen : public TorchOpCode {
  TORCH_OP_CODEGEN_METHODS(TorchTakeCodeGen)

 protected:
  void CodeGenForward() final {
    if (node()->InputAt(1)->DTypeName() == "int32") {
      stack_.func_call("to", IdxInput(1), IdxInput(1)).call_arg("torch.int64");
    }
    stack_.assign(IdxNode(), DocUtils::ToIndex(IdxInput(0), IdxInput(1)));
  }
};

class TorchTriCodeGen : public TorchOpCode {
  TORCH_OP_CODEGEN_METHODS(TorchTriCodeGen)

@@ -738,6 +756,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
map->emplace("subtract", std::make_shared<TorchSimpleCodeGen>("", "torch.subtract"));
map->emplace("tan", std::make_shared<TorchSimpleCodeGen>("", "torch.tan"));
map->emplace("tanh", std::make_shared<TorchSimpleCodeGen>("", "torch.tanh"));
map->emplace("where", std::make_shared<TorchSimpleCodeGen>("", "torch.where"));

// reduce ops
map->emplace("max", std::make_shared<TorchReduceAxesCodeGen>("", "torch.max"));
@@ -771,6 +790,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));
map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));

// create ops
map->emplace("constant", std::make_shared<TorchConstantCodeGen>("nn.Parameter", ""));
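`TorchTakeCodeGen` lowers `take` to plain advanced indexing, first casting `int32` indices to `int64` because PyTorch only accepts long (or byte/bool) tensors as indices. Roughly, the forward code it emits looks like this (variable names illustrative):

index = index.to(torch.int64)  # emitted only when the index dtype is int32
out = data[index]              # advanced indexing, i.e. take along axis 0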
85 changes: 85 additions & 0 deletions tests/python/contrib/test_msc/test_graph_build.py
@@ -2472,6 +2472,91 @@ def forward(self, data, index, src):
    )


@pytest.mark.parametrize("dynamic", [True, False])
def test_masked_scatter(dynamic):
    """test graph builder for masked_scatter"""

    dim = "dim" if dynamic else 5

    class MaskedScatter1(Module):
        def forward(self, data, mask, src):
            return data.masked_scatter(mask, src)

    class MaskedScatter2(Module):
        def forward(self, data, mask, src):
            return data.masked_scatter(mask, src)

    expected1 = {
        "inputs": [
            {"name": "inp_0", "shape": [dim], "dtype": "float32", "layout": "A"},
            {"name": "inp_1", "shape": [dim], "dtype": "bool", "layout": "A"},
            {"name": "inp_2", "shape": [10], "dtype": "float32", "layout": "A"},
        ],
        "outputs": [{"name": "where", "shape": [dim], "dtype": "float32", "layout": "A"}],
        "nodes": {
            "total": 8,
            "input": 3,
            "cumsum": 1,
            "constant": 1,
            "subtract": 1,
            "take": 1,
            "where": 1,
        },
    }
    expected2 = {
        "inputs": [
            {
                "name": "inp_0",
                "shape": [2, dim],
                "dtype": "float32",
                "layout": "" if dynamic else "BA",
            },
            {
                "name": "inp_1",
                "shape": [2, dim],
                "dtype": "bool",
                "layout": "" if dynamic else "BA",
            },
            {
                "name": "inp_2",
                "shape": [3, dim],
                "dtype": "float32",
                "layout": "" if dynamic else "BA",
            },
        ],
        "outputs": [
            {
                "name": "where",
                "shape": [2, dim],
                "dtype": "float32",
                "layout": "" if dynamic else "BA",
            }
        ],
        "nodes": {
            "total": 11,
            "input": 3,
            "reshape": 3,
            "cumsum": 1,
            "constant": 1,
            "subtract": 1,
            "take": 1,
            "where": 1,
        },
    }
    if dynamic:
        expected1["prims"] = {"total": 1, "shape": 1}
        expected2["prims"] = {"total": 5, "shape": 1, "Int": 2, "Mul": 2}

    verify_model(
        MaskedScatter1(), [([dim], "float32"), ([dim], "bool"), ([10], "float32")], expected1
    )
    verify_model(
        MaskedScatter2(),
        [([2, dim], "float32"), ([2, dim], "bool"), ([3, dim], "float32")],
        expected2,
    )


def test_put():
"""test graph builder for index_put"""

23 changes: 23 additions & 0 deletions tests/python/contrib/test_msc/test_translate_relax.py
@@ -1193,6 +1193,29 @@ def forward(self, data, index, src):
    verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")])


def test_masked_scatter():
    """test relax translator for masked_scatter"""

    class MaskedScatter1(Module):
        def __init__(self):
            super().__init__()
            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)

        def forward(self, data, src):
            return data.masked_scatter(self.mask, src)

    class MaskedScatter2(Module):
        def __init__(self):
            super().__init__()
            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)

        def forward(self, data, src):
            return data.masked_scatter(self.mask, src)

    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")])
    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")])


def test_put():
"""test relax translator for index_put"""

23 changes: 23 additions & 0 deletions tests/python/contrib/test_msc/test_translate_torch.py
@@ -1173,6 +1173,29 @@ def forward(self, data, index, src):
    )


def test_masked_scatter():
    """test torch translator for masked_scatter"""

    class MaskedScatter1(Module):
        def __init__(self):
            super().__init__()
            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)

        def forward(self, data, src):
            return data.masked_scatter(self.mask, src)

    class MaskedScatter2(Module):
        def __init__(self):
            super().__init__()
            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)

        def forward(self, data, src):
            return data.masked_scatter(self.mask, src)

    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")], True)
    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")], True)


def test_put():
"""test torch translator for index_put"""

61 changes: 61 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -4023,5 +4023,66 @@ def main(
    verify_model(Scatter(), input_info, {}, expected)


def test_masked_scatter():
    class MaskedScatter1(Module):
        def forward(self, data, mask, src):
            return data.masked_scatter(mask, src)

    class MaskedScatter2(Module):
        def forward(self, data, mask, src):
            return data.masked_scatter(mask, src)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            inp_0: R.Tensor((5,), dtype="float32"),
            inp_1: R.Tensor((5,), dtype="bool"),
            inp_2: R.Tensor((10,), dtype="float32"),
        ) -> R.Tensor((5,), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((5,), dtype="int32") = R.cumsum(
                    inp_1, axis=0, dtype="int32", exclusive=False
                )
                lv1: R.Tensor((5,), dtype="int32") = R.subtract(lv, R.const(1, "int32"))
                lv2: R.Tensor((5,), dtype="float32") = R.take(inp_2, lv1, axis=0)
                lv3: R.Tensor((5,), dtype="float32") = R.where(inp_1, lv2, inp_0)
                gv: R.Tensor((5,), dtype="float32") = lv3
                R.output(gv)
            return gv

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            inp_0: R.Tensor((2, 5), dtype="float32"),
            inp_1: R.Tensor((2, 5), dtype="bool"),
            inp_2: R.Tensor((3, 5), dtype="float32"),
        ) -> R.Tensor((2, 5), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((10,), dtype="bool") = R.reshape(inp_1, R.shape([10]))
                lv1: R.Tensor((10,), dtype="int32") = R.cumsum(
                    lv, axis=0, dtype="int32", exclusive=False
                )
                lv2: R.Tensor((10,), dtype="int32") = R.subtract(lv1, R.const(1, "int32"))
                lv3: R.Tensor((15,), dtype="float32") = R.reshape(inp_2, R.shape([15]))
                lv4: R.Tensor((10,), dtype="float32") = R.take(lv3, lv2, axis=0)
                lv5: R.Tensor((2, 5), dtype="float32") = R.reshape(lv4, R.shape([2, 5]))
                lv6: R.Tensor((2, 5), dtype="float32") = R.where(inp_1, lv5, inp_0)
                gv: R.Tensor((2, 5), dtype="float32") = lv6
                R.output(gv)
            return gv

    verify_model(
        MaskedScatter1(), [([5], "float32"), ([5], "bool"), ([10], "float32")], {}, expected1
    )
    verify_model(
        MaskedScatter2(),
        [([2, 5], "float32"), ([2, 5], "bool"), ([3, 5], "float32")],
        {},
        expected2,
    )


if __name__ == "__main__":
    tvm.testing.main()