[Bug] Inconsistent Module Structure in Relax Transform and Build Failure with InlinePrivateFunctions() #17479

Open
Thrsu opened this issue Oct 21, 2024 · 0 comments
Labels: needs-triage, type: bug

Applying relax.transform.InlinePrivateFunctions() to a Relax module through tvm.transform.Sequential and applying it directly produce modules whose structures are inconsistent with each other. In addition, calling relax.build() on the module after applying the transformation directly fails with the following internal error:

InternalError: Check failed: (slot->value_computed) is false: PrimExpr T.int64(4) * n * m in function I.GlobalVar("main") has not been computed.

Expected behavior

The modules produced by applying relax.transform.InlinePrivateFunctions() through Sequential and by applying it directly should have the same structure, and the resulting module should compile with relax.build() without any internal errors.

Actual behavior

  • The structure of the resulting module differs between the two ways of applying the transformation.
  • Calling relax.build() after applying the transformation raises an internal error reporting that the expression T.int64(4) * n * m, which depends on the symbolic variables n and m, has not been computed.

Steps to reproduce

import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_x2: T.handle, var_y2: T.handle, var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n, m = T.int64(), T.int64()
        x2 = T.match_buffer(var_x2, (n, m))
        y2 = T.match_buffer(var_y2, (n, m))
        T_add = T.match_buffer(var_T_add, (n, m))
        for ax0, ax1 in T.grid(n, m):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x2[v_ax0, v_ax1], y2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = x2[v_ax0, v_ax1] + y2[v_ax0, v_ax1]

    @R.function(private=True)
    def main_inner(x2: R.Tensor(("n", "m"), dtype="float32"), y2: R.Tensor(("n", "m"), dtype="float32")) -> R.Tensor(("n", "m"), dtype="float32"):
        n = T.int64()
        m = T.int64()
        cls = Module
        sum_inner = R.call_tir(cls.add, (x2, y2), out_sinfo=R.Tensor((n, m), dtype="float32"))
        return sum_inner

    @R.function
    def main(x1: R.Tensor((10, 5), dtype="float32"), y1: R.Tensor((10, 5), dtype="float32")) -> R.Tensor((10, 5), dtype="float32"):
        cls = Module
        sum_main: R.Tensor((10, 5), dtype="float32") = cls.main_inner(x1, y1)
        return sum_main

mod = Module

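# Apply the pass wrapped in a Sequential.
mod_seq = tvm.transform.Sequential([relax.transform.InlinePrivateFunctions(),])(mod)
# Apply the same pass directly.
mod = relax.transform.InlinePrivateFunctions()(mod)
# Structural comparison of the two results (disabled here).
#tvm.ir.assert_structural_equal(mod_seq, mod)
# Print the first shape value recorded in the struct info of the first binding in main.
print(mod["main"].body.blocks[0].bindings[0].value.sinfo_args[0].shape.values[0])
# Build the directly transformed module; this is the step that raises the InternalError.
with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(mod, target='llvm')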
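
To see where the two results diverge rather than only that they do, one option is to diff their printed TVMScript. The helper below is a minimal sketch that uses only the Python standard library; `script_diff` and its parameter names are hypothetical and not part of TVM, and the usage shown in the trailing comment assumes that `mod_seq` and `mod` from the script above can each be printed with `.script()`.

```python
import difflib


def script_diff(lhs: str, rhs: str,
                lhs_name: str = "sequential",
                rhs_name: str = "direct") -> str:
    """Return a unified diff of two text dumps, or an empty string if they match."""
    diff_lines = difflib.unified_diff(
        lhs.splitlines(),
        rhs.splitlines(),
        fromfile=lhs_name,
        tofile=rhs_name,
        lineterm="",  # inputs were split without line endings, so add none
    )
    return "\n".join(diff_lines)


# Hypothetical usage with the modules from the reproduction script
# (assumes both modules support .script()):
#   print(script_diff(mod_seq.script(), mod.script()))
```

An empty result would mean the printed forms are identical, in which case any remaining difference would have to lie in something the printer does not show.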

Could you please help confirm if this is a bug in TVM or an issue with my usage?
