[TorchToLinalg] Lower torch.gcd to linalg and scf
Add a verify() method to check that the tensors are of integer type, and
that they either have the same shape or the second tensor is a
single-element tensor.

Add e2e tests and put them into the ONNX and StableHLO xfail sets.
bratislavSyrmia committed Sep 25, 2024
1 parent 6773288 commit 7673a8f
Showing 9 changed files with 302 additions and 17 deletions.
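
For context, the elementwise behavior being lowered (and the operand rules the new verifier enforces) can be sketched in plain PyTorch; the values below are illustrative, not taken from the added tests:

import torch

a = torch.tensor([[12, 18], [27, 40]], dtype=torch.int32)
b = torch.tensor([[8, 27], [9, 16]], dtype=torch.int32)
torch.gcd(a, b)                                      # [[4, 9], [9, 8]], still int32

# A single-element 'other' is also accepted and applied elementwise.
torch.gcd(a, torch.tensor([6], dtype=torch.int32))   # [[6, 6], [3, 2]]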
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
@@ -13148,6 +13148,31 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [
}];
}

def Torch_AtenGcdOp : Torch_Op<"aten.gcd", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::gcd : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenGcdOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenGcdOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,
110 changes: 110 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -213,6 +215,112 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
};
} // namespace

namespace {
class ConvertAtenGcdOp : public OpConversionPattern<torch::Torch::AtenGcdOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(torch::Torch::AtenGcdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto self = adaptor.getSelf(); // tensor A
auto other = adaptor.getOther(); // tensor B of the same size
auto loc = op.getLoc();

TensorType resultType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));

auto gcdPayloadBody = [&](OpBuilder &b, Location loc,
ValueRange payloadArgs) {
auto A = payloadArgs[0];
A = b.create<mlir::math::AbsIOp>(loc, A);
auto B = payloadArgs[1];
B = b.create<mlir::math::AbsIOp>(loc, B);
auto two = b.create<mlir::arith::ConstantIntOp>(loc, 2, A.getType());
auto one = b.create<mlir::arith::ConstantIntOp>(loc, 1, A.getType());
auto zero = b.create<mlir::arith::ConstantIntOp>(loc, 0, A.getType());

auto trailingZeroConditionBlock = [&](mlir::OpBuilder &b,
mlir::Location loc,
mlir::ValueRange whileArgs) {
auto current = whileArgs[0];
auto counter = whileArgs[1];
auto currentAndOne = b.create<mlir::arith::AndIOp>(loc, current, one);
auto cmp = b.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq, currentAndOne, zero);
b.create<mlir::scf::ConditionOp>(loc, cmp,
ValueRange{current, counter});
};
auto trailingZerosBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange args) {
auto current = args[0];
auto counter = args[1];
auto divided = b.create<mlir::arith::DivUIOp>(loc, current, two);
auto newCounter = b.create<mlir::arith::AddIOp>(loc, counter, one);
b.create<mlir::scf::YieldOp>(
loc, ValueRange{divided.getResult(), newCounter.getResult()});
};

auto AtrailingZerosOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{A.getType(), zero.getType()}, ValueRange{A, zero},
trailingZeroConditionBlock, trailingZerosBodyBlock);
auto BtrailingZerosOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{B.getType(), zero.getType()}, ValueRange{B, zero},
trailingZeroConditionBlock, trailingZerosBodyBlock);

Value AtrailingZerosCount = AtrailingZerosOp.getResult(1);
Value BtrailingZerosCount = BtrailingZerosOp.getResult(1);
auto smallerZerosCount = b.create<mlir::arith::MinSIOp>(
loc, AtrailingZerosCount, BtrailingZerosCount);
auto shiftedA = b.create<mlir::arith::ShRSIOp>(loc, A, smallerZerosCount);
auto shiftedB = b.create<mlir::arith::ShRSIOp>(loc, B, smallerZerosCount);

auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange args) {
Value min = b.create<mlir::arith::MinSIOp>(loc, args[0], args[1]);
Value max = b.create<mlir::arith::MaxSIOp>(loc, args[0], args[1]);

auto cmp = b.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::ne, min, zero);
b.create<mlir::scf::ConditionOp>(loc, cmp, ValueRange{min, max});
};
auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange args) {
Value min = args[0];
Value max = args[1];
max = b.create<mlir::arith::SubIOp>(loc, max, min);

auto maxTrailingZerosOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{B.getType(), zero.getType()}, ValueRange{max, zero},
trailingZeroConditionBlock, trailingZerosBodyBlock);
Value maxTrailingZerosCount = maxTrailingZerosOp.getResult(1);
max = b.create<mlir::arith::ShRSIOp>(loc, max, maxTrailingZerosCount);
b.create<mlir::scf::YieldOp>(loc, ValueRange{min, max});
};

auto findGcdWhileOp = b.create<mlir::scf::WhileOp>(
loc, TypeRange{shiftedA.getType(), shiftedB.getType()},
ValueRange{shiftedA, shiftedB}, findGcdConditionBlock,
findGcdBodyBlock);

Value gcdResult = findGcdWhileOp.getResult(1);
gcdResult =
b.create<mlir::arith::ShLIOp>(loc, gcdResult, smallerZerosCount);

b.create<linalg::YieldOp>(loc, gcdResult);
};

other = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{self, other},
cast<TensorType>(self.getType()).getElementType(), gcdPayloadBody);

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, other);
return success();
}
};
} // namespace

namespace {
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
public:
@@ -1400,4 +1508,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
target.addIllegalOp<AtenGcdOp>();
patterns.add<ConvertAtenGcdOp>(typeConverter, context);
}
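
For reference, the gcdPayloadBody in ConvertAtenGcdOp above follows the binary (Stein's) GCD algorithm: two scf.while loops count trailing zeros, a third reduces by subtraction, and the shared power of two is restored at the end. A rough Python sketch of the intended per-element computation, meant as a reading aid rather than the committed code (the zero handling here is an assumption of the sketch):

def binary_gcd(a, b):
    a, b = abs(a), abs(b)
    if a == 0 or b == 0:
        return a | b  # gcd(x, 0) == x; not modeled by the scf loops above

    def trailing_zeros(x):
        # One scf.while per call: the condition region tests the low bit,
        # the body region halves the value and bumps the counter.
        count = 0
        while (x & 1) == 0:
            x >>= 1
            count += 1
        return count

    shift = min(trailing_zeros(a), trailing_zeros(b))  # arith.minsi on the two counts
    a >>= shift                                        # arith.shrsi
    b >>= shift
    while min(a, b) != 0:              # the find-gcd scf.while condition
        lo, hi = min(a, b), max(a, b)
        hi -= lo                       # subtract the smaller from the larger
        if hi:
            hi >>= trailing_zeros(hi)  # strip the new trailing zeros
        a, b = lo, hi
    return max(a, b) << shift          # restore the shared power of two (arith.shli)


assert binary_gcd(12, 18) == 6 and binary_gcd(27, 9) == 9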
34 changes: 34 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
@@ -5524,3 +5524,37 @@ LogicalResult AtenRot90Op::verify() {

return success();
}

LogicalResult AtenGcdOp::verify() {

auto selfType = cast<BaseTensorType>(getSelf().getType());
auto otherType = cast<BaseTensorType>(getOther().getType());

if (!selfType.hasDtype() || !selfType.hasSizes() || !otherType.hasDtype() ||
!otherType.hasSizes())
return success();

auto selfShape = selfType.getSizes();
auto otherShape = otherType.getSizes();
int64_t selfRank = selfShape.size();
int64_t otherRank = otherShape.size();
auto selfDtype = selfType.getDtype();

if (!isa<mlir::IntegerType>(selfDtype))
return emitOpError("expected an integer type for input tensor, but got ")
<< selfDtype;

if (otherRank == 1 && otherShape[0] == 1)
return success();

if (selfRank != otherRank)
return emitOpError("Tensors must be of same rank or second tensor must be "
"a single element tensor");

for (int i = 0; i < selfRank; i++) {
if (selfShape[i] != otherShape[i])
return emitOpError("Dimensions od tensors font match in dim ") << i;
}

return success();
}
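
In effect, the verifier admits same-shape integer operands or a single-element 'other'. A hypothetical illustration of inputs it accepts or flags (example shapes only, not taken from the test suite):

import torch

a = torch.randint(1, 100, (4, 4), dtype=torch.int32)

# Accepted: same shape and integer dtype, or a rank-1 single-element 'other'.
torch.gcd(a, torch.randint(1, 100, (4, 4), dtype=torch.int32))
torch.gcd(a, torch.tensor([7], dtype=torch.int32))

# Flagged by the verifier when compiled through torch-mlir:
# torch.gcd(a, torch.randint(1, 100, (4, 3), dtype=torch.int32))  # dimensions do not match in dim 1
# torch.gcd(a.float(), a.float())                                 # expected an integer type for input tensor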
81 changes: 66 additions & 15 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
@@ -6639,6 +6639,72 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %8 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.gcd\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: Shapes must be the same or 'other' must be a single element tensor.\"\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.eq.int_list %arg0, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %6 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.gcd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: aten.gcd works only with integer types\"\n"
" %false = torch.constant.bool false\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -11238,21 +11304,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n"
" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" return %1 : !torch.bool\n"
" }\n"
" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list<int> {\n"
" %int4 = torch.constant.int 4\n"
" %int3 = torch.constant.int 3\n"
" %int2 = torch.constant.int 2\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %int11 = torch.constant.int 11\n"
" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
@@ -923,6 +923,9 @@
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"GCDBatchedModule_I32",
"GCDDynamicModule_I32",
"GCDModule_I32",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3126,6 +3129,9 @@
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"UnfoldModule_basic",
"GCDBatchedModule_I32",
"GCDDynamicModule_I32",
"GCDModule_I32",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
Original file line number Diff line number Diff line change
@@ -265,6 +265,17 @@ def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]:
shape = upstream_shape_functions.zero_dim_tensor(A)
return shape, shape

def aten〇gcd〡shape(self: List[int], other: List[int]) -> List[int]:
assert self == other or (len(other) == 1 and other[0] == 1), "Shapes must be the same or 'other' must be a single element tensor."
return self

def aten〇gcd〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
other_rank, other_dtype = other_rank_dtype
assert is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype), "aten.gcd works only with integer types"
return self_dtype
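
is_integer_dtype compares against torch ScalarType codes; the constants emitted into all_integer_dtypes() in the generated library above are assumed to correspond to the following dtypes:

import torch

# Assumed mapping of the ScalarType codes [11, 0, 1, 2, 3, 4] from all_integer_dtypes():
#   11 -> torch.bool, 0 -> torch.uint8, 1 -> torch.int8,
#    2 -> torch.int16, 3 -> torch.int32, 4 -> torch.int64
integer_dtypes = [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
assert torch.int32 in integer_dtypes and torch.float32 not in integer_dtypes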


def aten〇detach〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

Original file line number Diff line number Diff line change
@@ -964,6 +964,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)"
)
emit("aten::gcd : (Tensor, Tensor) -> (Tensor)", has_verifier=True)

# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")
49 changes: 49 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
@@ -6845,3 +6845,52 @@ def forward(self):
@register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule())
def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


class GCDModule(torch.nn.Module):
@export
@annotate_args([None, [(4, 4), torch.int32, True], [(4, 4), torch.int32, True]])
def forward(self, A, B):
return torch.gcd(A, B)


@register_test_case(module_factory=lambda: GCDModule())
def GCDModule_I32(module, tu: TestUtils):
A = tu.randint(4, 4, low=1, high=100).to(dtype=torch.int32)
B = tu.randint(4, 4, low=1, high=100).to(dtype=torch.int32)
module.forward(A, B)


class GCDBatchedModule(torch.nn.Module):
@export
@annotate_args(
[None, [(4, 4, 4), torch.int32, True], [(4, 4, 4), torch.int32, True]]
)
def forward(self, A, B):
return torch.gcd(A, B)


@register_test_case(module_factory=lambda: GCDBatchedModule())
def GCDBatchedModule_I32(module, tu: TestUtils):
A = tu.randint(4, 4, 4, low=1, high=100).to(dtype=torch.int32)
B = tu.randint(4, 4, 4, low=1, high=100).to(dtype=torch.int32)
module.forward(A, B)


class GCDDynamicModule(torch.nn.Module):
@export
@annotate_args(
[None, [(-1, -1, -1), torch.int32, True], [(-1, -1, -1), torch.int32, True]]
)
def forward(self, A, B):
return torch.gcd(A, B)


@register_test_case(module_factory=lambda: GCDDynamicModule())
def GCDDynamicModule_I32(module, tu: TestUtils):
A = tu.randint(3, 4, 4, low=1, high=100).to(dtype=torch.int32)
B = tu.randint(3, 4, 4, low=1, high=100).to(dtype=torch.int32)
module.forward(A, B)
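
The single-element 'other' path that the verifier and shape function special-case is not exercised by the modules above; a hypothetical additional test (illustrative only, not part of this commit) could look like:

class GCDSingleElementModule(torch.nn.Module):
    @export
    @annotate_args([None, [(4, 4), torch.int32, True], [(1,), torch.int32, True]])
    def forward(self, A, B):
        return torch.gcd(A, B)


@register_test_case(module_factory=lambda: GCDSingleElementModule())
def GCDSingleElementModule_I32(module, tu: TestUtils):
    A = tu.randint(4, 4, low=1, high=100).to(dtype=torch.int32)
    B = tu.randint(1, low=1, high=100).to(dtype=torch.int32)
    module.forward(A, B)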
2 changes: 0 additions & 2 deletions projects/pt1/tools/e2e_test.sh
Original file line number Diff line number Diff line change
@@ -8,6 +8,4 @@ cd "$src_dir"

# Ensure PYTHONPATH is set for export to child processes, even if empty.
export PYTHONPATH=${PYTHONPATH-}
source $project_dir/.env

python -m e2e_testing.main "$@"
