Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reassociate: add global reassociation algorithm (#6598) #6641

Open
wants to merge 1 commit into
base: release-1.8.2405
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 122 additions & 2 deletions lib/Transforms/Scalar/Reassociate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
Expand All @@ -37,6 +37,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
using namespace llvm;
Expand Down Expand Up @@ -161,6 +162,13 @@ namespace {
DenseMap<BasicBlock*, unsigned> RankMap;
DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
SetVector<AssertingVH<Instruction> > RedoInsts;

// Arbitrary, but prevents quadratic behavior.
static const unsigned GlobalReassociateLimit = 10;
static const unsigned NumBinaryOps =
Instruction::BinaryOpsEnd - Instruction::BinaryOpsBegin;
DenseMap<std::pair<Value *, Value *>, unsigned> PairMap[NumBinaryOps];

bool MadeChange;
public:
static char ID; // Pass identification, replacement for typeid
Expand Down Expand Up @@ -196,6 +204,7 @@ namespace {
void EraseInst(Instruction *I);
void OptimizeInst(Instruction *I);
Instruction *canonicalizeNegConstExpr(Instruction *I);
void BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT);
};
}

Expand Down Expand Up @@ -2234,18 +2243,127 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) {
return;
}

if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) {
// Find the pair with the highest count in the pairmap and move it to the
// back of the list so that it can later be CSE'd.
// example:
// a*b*c*d*e
// if c*e is the most "popular" pair, we can express this as
// (((c*e)*d)*b)*a
unsigned Max = 1;
unsigned BestRank = 0;
std::pair<unsigned, unsigned> BestPair;
unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin;
for (unsigned i = 0; i < Ops.size() - 1; ++i)
for (unsigned j = i + 1; j < Ops.size(); ++j) {
unsigned Score = 0;
Value *Op0 = Ops[i].Op;
Value *Op1 = Ops[j].Op;
if (std::less<Value *>()(Op1, Op0))
std::swap(Op0, Op1);
auto it = PairMap[Idx].find({Op0, Op1});
if (it != PairMap[Idx].end())
Score += it->second;

unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank);
if (Score > Max || (Score == Max && MaxRank < BestRank)) {
BestPair = {i, j};
Max = Score;
BestRank = MaxRank;
}
}
if (Max > 1) {
auto Op0 = Ops[BestPair.first];
auto Op1 = Ops[BestPair.second];
Ops.erase(&Ops[BestPair.second]);
Ops.erase(&Ops[BestPair.first]);
Ops.push_back(Op0);
Ops.push_back(Op1);
}
}
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
RewriteExprTree(I, Ops);
}

/// Populate PairMap with how often each unordered pair of leaf operands
/// occurs across the associative expression trees reached through \p RPOT.
///
/// PairMap is indexed by (opcode - Instruction::BinaryOpsBegin). A given pair
/// is counted at most once per expression tree, and its two values are stored
/// in a canonical (pointer-ordered) order so that lookups can use the same
/// normalization. Trees with more than GlobalReassociateLimit leaves are not
/// recorded at all.
///
/// \param RPOT the basic blocks to scan, in the order the traversal yields
///        them.
void Reassociate::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) {
  for (BasicBlock *BB : RPOT) {
    for (Instruction &I : *BB) {
      if (!I.isAssociative())
        continue;

      // Only start from the root of an expression tree: an instruction whose
      // single user has the same opcode is an interior node of that user's
      // tree and is reached from there.
      if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode())
        continue;

      // Collect all leaf operands of this reassociable expression.
      // Since Reassociate has already been run once, we can assume things
      // are already canonical according to Reassociation's regime.
      SmallVector<Value *, 8> Worklist = {I.getOperand(0), I.getOperand(1)};
      SmallVector<Value *, 8> Ops;
      while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) {
        Value *Op = Worklist.pop_back_val();
        Instruction *OpI = dyn_cast<Instruction>(Op);
        // Anything that is not a single-use instruction of the same opcode is
        // a leaf of the tree.
        if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) {
          Ops.push_back(Op);
          continue;
        }
        // Be paranoid about self-referencing expressions in unreachable code.
        if (OpI->getOperand(0) != OpI)
          Worklist.push_back(OpI->getOperand(0));
        if (OpI->getOperand(1) != OpI)
          Worklist.push_back(OpI->getOperand(1));
      }

      // Skip extremely long expressions.
      const unsigned NumOps = Ops.size();
      if (NumOps > GlobalReassociateLimit)
        continue;

      // Add all pairwise combinations of operands to the pair map.
      // The loops are bounded by NumOps itself rather than "NumOps - 1": the
      // self-reference filter above may drop operands, and with an empty Ops
      // the unsigned subtraction would wrap around instead of ending the loop.
      const unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin;
      SmallSet<std::pair<Value *, Value *>, 32> Visited;
      for (unsigned i = 0; i < NumOps; ++i) {
        for (unsigned j = i + 1; j < NumOps; ++j) {
          // Canonicalize operand orderings.
          Value *Op0 = Ops[i];
          Value *Op1 = Ops[j];
          if (std::less<Value *>()(Op1, Op0))
            std::swap(Op0, Op1);

          // Count each distinct pair only once per expression tree.
          const std::pair<Value *, Value *> Key(Op0, Op1);
          if (!Visited.insert(Key).second)
            continue;

          // A missing entry is value-initialized to zero before the increment.
          ++PairMap[BinaryIdx][Key];
        }
      }
    }
  }
}

bool Reassociate::runOnFunction(Function &F) {
if (skipOptnoneFunction(F))
return false;

// Calculate the rank map for F
BuildRankMap(F);

// Build the pair map before running reassociate.
// Technically this would be more accurate if we did it after one round
// of reassociation, but in practice it doesn't seem to help much on
// real-world code, so don't waste the compile time running reassociate
// twice.
// If a user wants, they could explicitly run reassociate twice in their
// pass pipeline for further potential gains.
// It might also be possible to update the pair map during runtime, but the
// overhead of that may be large if there are many reassociable chains.
// TODO: RPOT
// Get the function's basic blocks in Reverse Post Order. This order is used
// by BuildRankMap to pre-calculate ranks correctly. It also excludes dead
// basic blocks (it has been seen that the analysis in this pass could hang
// when analysing dead basic blocks).
ReversePostOrderTraversal<Function *> RPOT(&F);
BuildPairMap(RPOT);

MadeChange = false;
for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
// Optimize every instruction in the basic block.
Expand All @@ -2268,9 +2386,11 @@ bool Reassociate::runOnFunction(Function &F) {
}
}

// We are done with the rank map.
// We are done with the rank map and pair map.
RankMap.clear();
ValueRankMap.clear();
for (auto &Entry : PairMap)
Entry.clear();

return MadeChange;
}
15 changes: 15 additions & 0 deletions test/Transforms/Reassociate/basictest.ll
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,18 @@ define i32 @test15(i32 %X1, i32 %X2, i32 %X3) {
; CHECK-LABEL: @test15
; CHECK: and i1 %A, %B
}

; test17: the product of %X3 and %X4 is common to both multiply chains that
; feed the xor, so the expected output forms that product once and reuses it
; in the two remaining multiplies.
; CHECK-LABEL: @test17
; CHECK: %[[A:.*]] = mul i32 %X4, %X3
; CHECK-NEXT: %[[C:.*]] = mul i32 %[[A]], %X1
; CHECK-NEXT: %[[D:.*]] = mul i32 %[[A]], %X2
; CHECK-NEXT: %[[E:.*]] = xor i32 %[[C]], %[[D]]
; CHECK-NEXT: ret i32 %[[E]]
define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) {
%A = mul i32 %X3, %X1
%B = mul i32 %X3, %X2
%C = mul i32 %A, %X4
%D = mul i32 %B, %X4
%E = xor i32 %C, %D
ret i32 %E
}
Loading