LCOV - code coverage report
Current view: top level - src - llvm-muladd.cpp (source / functions) Hit Total Coverage
Test: [build process] commit ef510b1f346f4c9f9d86eaceace5ca54961a1dbc Lines: 17 47 36.2 %
Date: 2022-07-17 01:01:28 Functions: 4 7 57.1 %
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: 6 33 18.2 %

           Branch data     Line data    Source code
       1                 :            : // This file is a part of Julia. License is MIT: https://julialang.org/license
       2                 :            : 
       3                 :            : #include "llvm-version.h"
       4                 :            : #include "passes.h"
       5                 :            : 
       6                 :            : #include <llvm-c/Core.h>
       7                 :            : #include <llvm-c/Types.h>
       8                 :            : 
       9                 :            : #include <llvm/ADT/Statistic.h>
      10                 :            : #include <llvm/IR/Value.h>
      11                 :            : #include <llvm/IR/LegacyPassManager.h>
      12                 :            : #include <llvm/IR/PassManager.h>
      13                 :            : #include <llvm/IR/Function.h>
      14                 :            : #include <llvm/IR/Instructions.h>
      15                 :            : #include <llvm/IR/IntrinsicInst.h>
      16                 :            : #include <llvm/IR/Module.h>
      17                 :            : #include <llvm/IR/Operator.h>
      18                 :            : #include <llvm/IR/IRBuilder.h>
      19                 :            : #include <llvm/IR/Verifier.h>
      20                 :            : #include <llvm/Pass.h>
      21                 :            : #include <llvm/Support/Debug.h>
      22                 :            : 
      23                 :            : #include "julia.h"
      24                 :            : #include "julia_assert.h"
      25                 :            : 
      26                 :            : #define DEBUG_TYPE "combine_muladd"
      27                 :            : #undef DEBUG
      28                 :            : 
      29                 :            : using namespace llvm;
      30                 :            : STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");
      31                 :            : 
      32                 :            : /**
      33                 :            :  * Combine
      34                 :            :  * ```
      35                 :            :  * %v0 = fmul ... %a, %b
      36                 :            :  * %v = fadd fast ... %v0, %c
      37                 :            :  * ```
      38                 :            :  * to
      39                 :            :  * `%v = call fast @llvm.fmuladd.<...>(... %a, ... %b, ... %c)`
      40                 :            :  * when `%v0` has no other use
      41                 :            :  */
      42                 :            : 
      43                 :            : // Return true if we changed the mulOp
      44                 :          0 : static bool checkCombine(Value *maybeMul)
      45                 :            : {
      46                 :          0 :     auto mulOp = dyn_cast<Instruction>(maybeMul);
      47   [ #  #  #  #  :          0 :     if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
                   #  # ]
      48                 :          0 :         return false;
      49         [ #  # ]:          0 :     if (!mulOp->hasOneUse())
      50                 :          0 :         return false;
      51                 :            :     // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
      52                 :          0 :     auto fmf = mulOp->getFastMathFlags();
      53         [ #  # ]:          0 :     if (!fmf.allowContract()) {
      54                 :          0 :         ++TotalContracted;
      55                 :          0 :         fmf.setAllowContract(true);
      56                 :          0 :         mulOp->copyFastMathFlags(fmf);
      57                 :          0 :         return true;
      58                 :            :     }
      59                 :          0 :     return false;
      60                 :            : }
      61                 :            : 
      62                 :         10 : static bool combineMulAdd(Function &F)
      63                 :            : {
      64                 :         10 :     bool modified = false;
      65         [ +  + ]:         46 :     for (auto &BB: F) {
      66         [ +  + ]:        474 :         for (auto it = BB.begin(); it != BB.end();) {
      67                 :        438 :             auto &I = *it;
      68                 :        438 :             it++;
      69      [ -  -  + ]:        438 :             switch (I.getOpcode()) {
      70                 :          0 :             case Instruction::FAdd: {
      71         [ #  # ]:          0 :                 if (!I.isFast())
      72                 :          0 :                     continue;
      73   [ #  #  #  # ]:          0 :                 modified |= checkCombine(I.getOperand(0)) || checkCombine(I.getOperand(1));
      74                 :          0 :                 break;
      75                 :            :             }
      76                 :          0 :             case Instruction::FSub: {
      77         [ #  # ]:          0 :                 if (!I.isFast())
      78                 :          0 :                     continue;
      79   [ #  #  #  # ]:          0 :                 modified |= checkCombine(I.getOperand(0)) || checkCombine(I.getOperand(1));
      80                 :          0 :                 break;
      81                 :            :             }
      82                 :        438 :             default:
      83                 :        438 :                 break;
      84                 :            :             }
      85                 :            :         }
      86                 :            :     }
      87         [ -  + ]:         10 :     assert(!verifyFunction(F, &errs()));
      88                 :         10 :     return modified;
      89                 :            : }
      90                 :            : 
      91                 :          0 : PreservedAnalyses CombineMulAdd::run(Function &F, FunctionAnalysisManager &AM)
      92                 :            : {
      93         [ #  # ]:          0 :     if (combineMulAdd(F)) {
      94                 :          0 :         return PreservedAnalyses::allInSet<CFGAnalyses>();
      95                 :            :     }
      96                 :          0 :     return PreservedAnalyses::all();
      97                 :            : }
      98                 :            : 
      99                 :            : 
     100                 :            : struct CombineMulAddLegacy : public FunctionPass {
     101                 :            :     static char ID;
     102                 :          2 :     CombineMulAddLegacy() : FunctionPass(ID)
     103                 :          2 :     {}
     104                 :            : 
     105                 :            : private:
     106                 :         10 :     bool runOnFunction(Function &F) override {
     107                 :         10 :         return combineMulAdd(F);
     108                 :            :     }
     109                 :            : };
     110                 :            : 
     111                 :            : char CombineMulAddLegacy::ID = 0;
     112                 :            : static RegisterPass<CombineMulAddLegacy> X("CombineMulAdd", "Combine mul and add to muladd",
     113                 :            :                                      false /* Only looks at CFG */,
     114                 :            :                                      false /* Analysis Pass */);
     115                 :            : 
     116                 :          2 : Pass *createCombineMulAddPass()
     117                 :            : {
     118                 :          2 :     return new CombineMulAddLegacy();
     119                 :            : }
     120                 :            : 
     121                 :          0 : extern "C" JL_DLLEXPORT void LLVMExtraAddCombineMulAddPass_impl(LLVMPassManagerRef PM)
     122                 :            : {
     123                 :          0 :     unwrap(PM)->add(createCombineMulAddPass());
     124                 :          0 : }

Generated by: LCOV version 1.14