Branch data Line data Source code
1 : : // This file is a part of Julia. License is MIT: https://julialang.org/license 2 : : 3 : : #include "llvm-version.h" 4 : : #include "passes.h" 5 : : 6 : : #include <llvm-c/Core.h> 7 : : #include <llvm-c/Types.h> 8 : : 9 : : #include <llvm/ADT/Statistic.h> 10 : : #include <llvm/IR/Value.h> 11 : : #include <llvm/IR/LegacyPassManager.h> 12 : : #include <llvm/IR/PassManager.h> 13 : : #include <llvm/IR/Function.h> 14 : : #include <llvm/IR/Instructions.h> 15 : : #include <llvm/IR/IntrinsicInst.h> 16 : : #include <llvm/IR/Module.h> 17 : : #include <llvm/IR/Operator.h> 18 : : #include <llvm/IR/IRBuilder.h> 19 : : #include <llvm/IR/Verifier.h> 20 : : #include <llvm/Pass.h> 21 : : #include <llvm/Support/Debug.h> 22 : : 23 : : #include "julia.h" 24 : : #include "julia_assert.h" 25 : : 26 : : #define DEBUG_TYPE "combine_muladd" 27 : : #undef DEBUG 28 : : 29 : : using namespace llvm; 30 : : STATISTIC(TotalContracted, "Total number of multiplies marked for FMA"); 31 : : 32 : : /** 33 : : * Combine 34 : : * ``` 35 : : * %v0 = fmul ... %a, %b 36 : : * %v = fadd fast ... %v0, %c 37 : : * ``` 38 : : * to 39 : : * `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... %c)` 40 : : * when `%v0` has no other use 41 : : */ 42 : : 43 : : // Return true if we changed the mulOp 44 : 0 : static bool checkCombine(Value *maybeMul) 45 : : { 46 : 0 : auto mulOp = dyn_cast<Instruction>(maybeMul); 47 [ # # # # : 0 : if (!mulOp || mulOp->getOpcode() != Instruction::FMul) # # ] 48 : 0 : return false; 49 [ # # ]: 0 : if (!mulOp->hasOneUse()) 50 : 0 : return false; 51 : : // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us. 52 : 0 : auto fmf = mulOp->getFastMathFlags(); 53 [ # # ]: 0 : if (!fmf.allowContract()) { 54 : 0 : ++TotalContracted; 55 : 0 : fmf.setAllowContract(true); 56 : 0 : mulOp->copyFastMathFlags(fmf); 57 : 0 : return true; 58 : : } 59 : 0 : return false; 60 : : } 61 : : 62 : 10 : static bool combineMulAdd(Function &F) 63 : : { 64 : 10 : bool modified = false; 65 [ + + ]: 46 : for (auto &BB: F) { 66 [ + + ]: 474 : for (auto it = BB.begin(); it != BB.end();) { 67 : 438 : auto &I = *it; 68 : 438 : it++; 69 [ - - + ]: 438 : switch (I.getOpcode()) { 70 : 0 : case Instruction::FAdd: { 71 [ # # ]: 0 : if (!I.isFast()) 72 : 0 : continue; 73 [ # # # # ]: 0 : modified |= checkCombine(I.getOperand(0)) || checkCombine(I.getOperand(1)); 74 : 0 : break; 75 : : } 76 : 0 : case Instruction::FSub: { 77 [ # # ]: 0 : if (!I.isFast()) 78 : 0 : continue; 79 [ # # # # ]: 0 : modified |= checkCombine(I.getOperand(0)) || checkCombine(I.getOperand(1)); 80 : 0 : break; 81 : : } 82 : 438 : default: 83 : 438 : break; 84 : : } 85 : : } 86 : : } 87 [ - + ]: 10 : assert(!verifyFunction(F, &errs())); 88 : 10 : return modified; 89 : : } 90 : : 91 : 0 : PreservedAnalyses CombineMulAdd::run(Function &F, FunctionAnalysisManager &AM) 92 : : { 93 [ # # ]: 0 : if (combineMulAdd(F)) { 94 : 0 : return PreservedAnalyses::allInSet<CFGAnalyses>(); 95 : : } 96 : 0 : return PreservedAnalyses::all(); 97 : : } 98 : : 99 : : 100 : : struct CombineMulAddLegacy : public FunctionPass { 101 : : static char ID; 102 : 2 : CombineMulAddLegacy() : FunctionPass(ID) 103 : 2 : {} 104 : : 105 : : private: 106 : 10 : bool runOnFunction(Function &F) override { 107 : 10 : return combineMulAdd(F); 108 : : } 109 : : }; 110 : : 111 : : char CombineMulAddLegacy::ID = 0; 112 : : static RegisterPass<CombineMulAddLegacy> X("CombineMulAdd", "Combine mul and add to muladd", 113 : : false /* Only looks at CFG */, 114 : : false /* Analysis Pass */); 115 : : 116 : 2 : Pass *createCombineMulAddPass() 117 : : { 118 : 2 : return new CombineMulAddLegacy(); 119 : : } 120 : : 121 : 0 : extern "C" JL_DLLEXPORT void LLVMExtraAddCombineMulAddPass_impl(LLVMPassManagerRef PM) 122 : : { 123 : 0 : unwrap(PM)->add(createCombineMulAddPass()); 124 : 0 : }