From a3688d9d34bf9d0802f96402f524a6ec58270cc6 Mon Sep 17 00:00:00 2001 From: M N Ganesan Date: Wed, 9 Oct 2024 13:09:53 +0000 Subject: [PATCH] [Frontend][ArgParse] Compile with default(LLVM) target and build with BYOC(#17454) It is a unique use-case to check the default target(LLVM), though TVM is built with BYOC(MRVL-ON) The config of Codegen(BYOC) contains default values for configuration/options, it is extracted during _generate_codegen_args. In command line processing, validate_target_args checks if there are add-on options and it expects that particular target to be given explicitly in command line. Here, it is test for default (LLVM) path only, hence validate_target_args need to ignore the codegen's configuration for default target. Signed-off-by: M N Ganesan --- python/tvm/driver/tvmc/target.py | 4 ++++ tests/python/driver/tvmc/test_target_options.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index b5eee0482377..9e9a9ada5c1b 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -179,6 +179,10 @@ def validate_targets(parse_targets, additional_target_options=None): ) if additional_target_options is not None: + # Add-on target options are passed from codegen's config(BYOC) which has pass_default=True + # Eg: --target="llvm" + if len(tvm_targets) == 1: + return for target_name in additional_target_options: if not any([target for target in parse_targets if target["name"] == target_name]): first_option = list(additional_target_options[target_name].keys())[0] diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index d98a8d588e22..0820a0f92f2e 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -86,6 +86,20 @@ def test_default_arg_for_mrvl_hybrid(): assert parsed.target_mrvl_num_tiles == 8 +@tvm.testing.requires_mrvl +def test_default_arg_for_mrvl_hybrid(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--target=mrvl, llvm", + ] + ) + assert parsed.target == "mrvl, llvm" + assert parsed.target_mrvl_mcpu == "cn10ka" + assert parsed.target_mrvl_num_tiles == 8 + + @tvm.testing.requires_cmsisnn def test_mapping_target_args(): parser = argparse.ArgumentParser()