Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ONNX Rewriter and IR to simplify the mnb_to_qdq pass #1482

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

justinchuby
Copy link
Contributor

Describe your changes

Checklist before requesting a review

  • Add unit tests for this change.
  • Make sure all tests can pass.
  • Update documents if necessary.
  • Lint and apply fixes to your code by running lintrunner -a
  • Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
  • Is this PR including examples changes? If yes, please remember to update example documentation in a follow-up PR.

(Optional) Issue link

Co-authored-by: Jambay Kinley <[email protected]>
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed

# Add Logic handling input 3

unpacked_weight_arrays = _unpack_weights(

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name \_unpack\_weights.
See https://docs.astral.sh/ruff/rules/undefined-name
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
@@ -7,8 +7,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict

import ml_dtypes

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'ml_dtypes' is not used.
@@ -7,8 +7,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict

import ml_dtypes

Check warning

Code scanning / lintrunner

PYLINT/W0611 Warning

Unused import ml_dtypes (unused-import)
See unused-import.
@@ -7,8 +7,10 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict

import ml_dtypes

Check warning

Code scanning / lintrunner

RUFF/F401 Warning

ml\_dtypes imported but unused.
See https://docs.astral.sh/ruff/rules/unused-import
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
olive/passes/onnx/mnb_to_qdq.py Fixed Show fixed Hide fixed
matmul = op.Add(matmul, bias)
return matmul

replace_mat_mul_n_bits = orp.RewriteRule(

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable replace_mat_mul_n_bits is not used.
return False
g_idx = g_idx.constant_value.numpy()
trivial_g_idx = np.arange(k, dtype=np.int32) // block_size
if not np.array_equal(g_idx, trivial_g_idx):

Check warning

Code scanning / lintrunner

RUFF/SIM103 Warning

Return the negated condition directly.
See https://docs.astral.sh/ruff/rules/needless-bool
g_idx = g_idx.constant_value.numpy()
trivial_g_idx = np.arange(k, dtype=np.int32) // block_size
if not np.array_equal(g_idx, trivial_g_idx):
# TODO: We can log why the pattern is not matched here

Check warning

Code scanning / lintrunner

RUFF/TD002 Warning

Missing author in TODO; try: # TODO(<author_name>): ... or # TODO @<author_name>: ....
See https://docs.astral.sh/ruff/rules/missing-todo-author
matmul = op.Add(matmul, bias)
return matmul

replace_mat_mul_n_bits = orp.RewriteRule(

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning

Unused variable 'replace_mat_mul_n_bits' (unused-variable)
See unused-variable.
matmul = op.Add(matmul, bias)
return matmul

replace_mat_mul_n_bits = orp.RewriteRule(

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable replace\_mat\_mul\_n\_bits is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable
graph: ir.Graph = context.graph
return value in graph.initializers.values()

def mat_mul_n_bits_pattern_check(context, *, q_weight, g_idx, mat_mul_n_bits_out: ir.Value, **_) -> bool:
Copy link
Contributor

@jambayk jambayk Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does q_weight here match for the input right before g_idx or it is whatever it is in the mat_mul_n_bits_pattern signature? The input before g_idx is qzero and can be optional. we want to check the second input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inputs of the pattern-function (mat_mul_n_bits_pattern) are bound to values in the graph, and these values are passed in as keyword-arguments to the rewrite function here. So, the order here doesn't really matter, though I usually just copy-paste and use the same argument list for both.

del node.meta["N"]

# TODO(justinchuby): Register and remove initializers
ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Use a more robust version conversion process

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants