From 7673a8ff2851d4ffc3a79584eddb0ddef741e704 Mon Sep 17 00:00:00 2001 From: Bratislav Filipovic Date: Thu, 19 Sep 2024 14:55:21 +0200 Subject: [PATCH] [TorchToLinalg]Lower torch.gcd to linalg and scf Add verify() method to check if tensors are of integer type. Also check if tensors are of same shape, or if the second tensor is a single element tensor. Add e2e tests. Put them into onnx and stablehlo xfailed sets. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++++ lib/Conversion/TorchToLinalg/Linear.cpp | 110 ++++++++++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 34 ++++++ .../Transforms/AbstractInterpLibrary.cpp | 81 ++++++++++--- projects/pt1/e2e_testing/xfail_sets.py | 6 + .../build_tools/abstract_interp_lib_gen.py | 11 ++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 49 ++++++++ projects/pt1/tools/e2e_test.sh | 2 - 9 files changed, 302 insertions(+), 17 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0b1a8b25720e..edadc94ddc19 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13148,6 +13148,31 @@ def Torch_AtenStftOp : Torch_Op<"aten.stft", [ }]; } +def Torch_AtenGcdOp : Torch_Op<"aten.gcd", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::gcd : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGcdOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenGcdOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 52765411bd73..5bedc826f2be 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -13,6 +13,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -213,6 +215,112 @@ class ConvertAtenMmOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenGcdOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(torch::Torch::AtenGcdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto self = adaptor.getSelf(); // tensor A + auto other = adaptor.getOther(); // tensor B of the same size + auto loc = op.getLoc(); + + TensorType resultType = + cast(getTypeConverter()->convertType(op.getType())); + + auto gcdPayloadBody = [&](OpBuilder &b, Location loc, + ValueRange payloadArgs) { + auto A = payloadArgs[0]; + A = b.create(loc, A); + auto B = payloadArgs[1]; + B = b.create(loc, B); + auto two = b.create(loc, 2, A.getType()); + auto one = b.create(loc, 1, A.getType()); + auto zero = b.create(loc, 0, A.getType()); + + auto trailingZeroConditionBlock = [&](mlir::OpBuilder &b, + mlir::Location loc, + mlir::ValueRange whileArgs) { + auto current = whileArgs[0]; + auto counter = whileArgs[1]; + auto currentAndOne = b.create(loc, current, one); + auto cmp = b.create( + loc, mlir::arith::CmpIPredicate::sgt, currentAndOne, one); + b.create(loc, cmp, + ValueRange{current, counter}); + }; + auto trailingZerosBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc, + mlir::ValueRange args) { + auto current = args[0]; + auto counter = args[1]; + auto divided = b.create(loc, current, two); + auto newCounter = b.create(loc, counter, one); + b.create( + loc, ValueRange{divided.getResult(), newCounter.getResult()}); + }; + + auto AtrailingZerosOp = b.create( + loc, TypeRange{A.getType(), zero.getType()}, ValueRange{A, zero}, + trailingZeroConditionBlock, trailingZerosBodyBlock); + auto BtrailingZerosOp = b.create( + loc, TypeRange{B.getType(), zero.getType()}, ValueRange{B, zero}, + trailingZeroConditionBlock, trailingZerosBodyBlock); + + Value AtrailingZerosCount = AtrailingZerosOp.getResult(0); + Value BtrailingZerosCount = BtrailingZerosOp.getResult(0); + auto smalerZerosCount = b.create( + loc, AtrailingZerosCount, BtrailingZerosCount); + auto shiftedA = b.create(loc, A, smalerZerosCount); + auto shiftedB = b.create(loc, B, smalerZerosCount); + + auto findGcdConditionBlock = [&](mlir::OpBuilder &b, mlir::Location loc, + mlir::ValueRange args) { + Value min = b.create(loc, args[0], args[1]); + Value max = + b.create(loc, payloadArgs[0], payloadArgs[1]); + + auto cmp = b.create( + loc, mlir::arith::CmpIPredicate::ne, min, zero); + b.create(loc, cmp, ValueRange{min, max}); + }; + auto findGcdBodyBlock = [&](mlir::OpBuilder &b, mlir::Location loc, + mlir::ValueRange args) { + Value min = args[0]; + Value max = args[1]; + max = b.create(loc, max, min); + + auto maxTrailingZerosOp = b.create( + loc, TypeRange{B.getType(), zero.getType()}, ValueRange{max, zero}, + trailingZeroConditionBlock, trailingZerosBodyBlock); + Value maxTrailingZerosCount = maxTrailingZerosOp.getResult(0); + max = b.create(loc, max, maxTrailingZerosCount); + b.create(loc, ValueRange{min, max}); + }; + + auto findGcdWhileOp = b.create( + loc, TypeRange{shiftedA.getType(), shiftedB.getType()}, + ValueRange{shiftedA, shiftedB}, findGcdConditionBlock, + findGcdBodyBlock); + + Value gcdResult = findGcdWhileOp.getResult(1); + gcdResult = + b.create(loc, gcdResult, smalerZerosCount); + + b.create(loc, gcdResult); + }; + + other = torch_to_linalg::createElementwiseLinalgGeneric( + rewriter, loc, ValueRange{self, other}, + cast(self.getType()).getElementType(), gcdPayloadBody); + + rewriter.replaceOpWithNewOp(op, resultType, other); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenFlipOp : public OpConversionPattern { public: @@ -1400,4 +1508,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index bed228671de1..d45e7cff8840 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -5524,3 +5524,37 @@ LogicalResult AtenRot90Op::verify() { return success(); } + +LogicalResult AtenGcdOp::verify() { + + auto selfType = cast(getSelf().getType()); + auto otherType = cast(getOther().getType()); + + if (!selfType.hasDtype() || !selfType.hasSizes() || !otherType.hasDtype() || + !otherType.hasSizes()) + return success(); + + auto selfShape = selfType.getSizes(); + auto otherShape = selfType.getSizes(); + int64_t selfRank = selfShape.size(); + int64_t otherRank = otherShape.size(); + auto selfDtype = selfType.getDtype(); + + if (!isa(selfDtype)) + return emitOpError("expected an integer type for input tensor, but got ") + << selfDtype; + + if (otherRank == 1 && otherShape[0] == 1) + return success(); + + if (selfRank != otherRank) + return emitOpError("Tensors must be of same rank or second tensor must be " + "a single element tensor"); + + for (int i = 0; i < selfRank; i++) { + if (selfShape[i] != otherShape[i]) + return emitOpError("Dimensions od tensors font match in dim ") << i; + } + + return success(); +} diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 59cf69393ded..5c38a9d74550 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6639,6 +6639,72 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %8 : !torch.tuple, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.gcd\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Shapes must be the same or 'other' must be a single element tensor.\"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.eq.int_list %arg0, %arg1 : !torch.list, !torch.list -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.eq.int %5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %6 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" }\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gcd\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: aten.gcd works only with integer types\"\n" +" %false = torch.constant.bool false\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11238,21 +11304,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %3 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" -" }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %int11 = torch.constant.int 11\n" -" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index bdb4d7f47e7d..1f2d586df5bc 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -923,6 +923,9 @@ "SplitTensorNegativeDimModule_basic", "SplitWithSizesListUnpackModule_basic", "SplitWithSizes_Module_basic", + "GCDBatchedModule_I32", + "GCDDynamicModule_I32", + "GCDModule_I32", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -3126,6 +3129,9 @@ "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "UnfoldModule_basic", + "GCDBatchedModule_I32", + "GCDDynamicModule_I32", + "GCDModule_I32", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index bc49757ee9d3..5cc558f0f415 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -265,6 +265,17 @@ def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]: shape = upstream_shape_functions.zero_dim_tensor(A) return shape, shape +def aten〇gcd〡shape(self: List[int], other: List[int]) -> List[int]: + assert self == other or (len(other) == 1 and other[0]==0), "Shapes must be the same or 'other' must be a single element tensor." + return self + +def aten〇gcd〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + assert is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype), "aten.gcd works only with integer types" + return self_dtype + + def aten〇detach〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5f53e17b9d17..f5a3d660858f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -964,6 +964,7 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?) -> (Tensor)" ) + emit("aten::gcd : (Tensor, Tensor) -> (Tensor)", has_verifier=True) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 9b4dbe659b6f..dde3c30744de 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -6845,3 +6845,52 @@ def forward(self): @register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule()) def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils): module.forward() + + +# ============================================================================== + + +class GCDModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4), torch.int32, True], [(4, 4), torch.int32, True]]) + def forward(self, A, B): + return torch.gcd(A, B) + + +@register_test_case(module_factory=lambda: GCDModule()) +def GCDModule_I32(module, tu: TestUtils): + A = tu.rand(4, 4).to(dtype=torch.int32) + B = tu.rand(4, 4).to(dtype=torch.int32) + module.forward(A, B) + + +class GCDBatchedModule(torch.nn.Module): + @export + @annotate_args( + [None, [(4, 4, 4), torch.int32, True], [(4, 4, 4), torch.int32, True]] + ) + def forward(self, A, B): + return torch.gcd(A, B) + + +@register_test_case(module_factory=lambda: GCDBatchedModule()) +def GCDBatchedModule_I32(module, tu: TestUtils): + A = tu.rand(4, 4, 4).to(dtype=torch.int32) + B = tu.rand(4, 4, 4).to(dtype=torch.int32) + module.forward(A, B) + + +class GCDDynamicModule(torch.nn.Module): + @export + @annotate_args( + [None, [(-1, -1, -1), torch.int32, True], [(-1, -1, -1), torch.int32, True]] + ) + def forward(self, A, B): + return torch.gcd(A, B) + + +@register_test_case(module_factory=lambda: GCDDynamicModule()) +def GCDDynamicModule_I32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.int32) + B = tu.rand(3, 4, 4).to(dtype=torch.int32) + module.forward(A, B) diff --git a/projects/pt1/tools/e2e_test.sh b/projects/pt1/tools/e2e_test.sh index a16929302a78..73d3361b6414 100755 --- a/projects/pt1/tools/e2e_test.sh +++ b/projects/pt1/tools/e2e_test.sh @@ -8,6 +8,4 @@ cd "$src_dir" # Ensure PYTHONPATH is set for export to child processes, even if empty. export PYTHONPATH=${PYTHONPATH-} -source $project_dir/.env - python -m e2e_testing.main "$@"