LCOV - code coverage report
Current view: top level - src - llvm-simdloop.cpp (source / functions) Hit Total Coverage
Test: [test only] commit 0f242327d2cc9bd130497f44b6350c924185606a Lines: 119 143 83.2 %
Date: 2022-07-16 23:42:53 Functions: 8 11 72.7 %
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: 60 80 75.0 %

           Branch data     Line data    Source code
       1                 :            : // This file is a part of Julia. License is MIT: https://julialang.org/license
       2                 :            : 
       3                 :            : #include "llvm-version.h"
       4                 :            : #include "passes.h"
       5                 :            : 
       6                 :            : // This file defines a LLVM pass that:
       7                 :            : // 1. Set's loop information in form of metadata
       8                 :            : // 2. If the metadata contains `julia.simdloop` finds reduction chains and marks
       9                 :            : //    floating-point operations as fast-math. `See enableUnsafeAlgebraIfReduction`.
      10                 :            : // 3. If the metadata contains `julia.ivdep` marks all memory accesses in the loop
      11                 :            : //    as independent of each other.
      12                 :            : //
      13                 :            : // The pass hinges on a call to a marker function that has metadata attached to it.
      14                 :            : // To construct the pass call `createLowerSimdLoopPass`.
      15                 :            : 
      16                 :            : #include "support/dtypes.h"
      17                 :            : 
      18                 :            : #include <llvm-c/Core.h>
      19                 :            : #include <llvm-c/Types.h>
      20                 :            : 
      21                 :            : #include <llvm/ADT/Statistic.h>
      22                 :            : #include <llvm/Analysis/LoopPass.h>
      23                 :            : #include <llvm/IR/LegacyPassManager.h>
      24                 :            : #include <llvm/IR/Instructions.h>
      25                 :            : #include <llvm/IR/Metadata.h>
      26                 :            : #include <llvm/IR/Verifier.h>
      27                 :            : #include <llvm/Support/Debug.h>
      28                 :            : 
      29                 :            : #include "julia_assert.h"
      30                 :            : 
      31                 :            : #define DEBUG_TYPE "lower_simd_loop"
      32                 :            : 
      33                 :            : using namespace llvm;
      34                 :            : 
      35                 :            : STATISTIC(TotalMarkedLoops, "Total number of loops marked with simdloop");
      36                 :            : STATISTIC(IVDepLoops, "Number of loops with no loop-carried dependencies");
      37                 :            : STATISTIC(SimdLoops, "Number of loops with SIMD instructions");
      38                 :            : STATISTIC(IVDepInstructions, "Number of instructions marked ivdep");
      39                 :            : STATISTIC(ReductionChains, "Number of reduction chains folded");
      40                 :            : STATISTIC(ReductionChainLength, "Total sum of instructions folded from reduction chain");
      41                 :            : STATISTIC(AddChains, "Addition reduction chains");
      42                 :            : STATISTIC(MulChains, "Multiply reduction chains");
      43                 :            : 
      44                 :            : namespace {
      45                 :            : 
      46                 :       3456 : static unsigned getReduceOpcode(Instruction *J, Instruction *operand)
      47                 :            : {
      48   [ +  +  -  +  :       3456 :     switch (J->getOpcode()) {
                      + ]
      49                 :         36 :     case Instruction::FSub:
      50         [ -  + ]:         36 :         if (J->getOperand(0) != operand)
      51                 :          0 :             return 0;
      52                 :            :         JL_FALLTHROUGH;
      53                 :            :     case Instruction::FAdd:
      54                 :       2102 :         return Instruction::FAdd;
      55                 :          0 :     case Instruction::FDiv:
      56         [ #  # ]:          0 :         if (J->getOperand(0) != operand)
      57                 :          0 :             return 0;
      58                 :            :         JL_FALLTHROUGH;
      59                 :            :     case Instruction::FMul:
      60                 :         41 :         return Instruction::FMul;
      61                 :       1313 :     default:
      62                 :       1313 :         return 0;
      63                 :            :     }
      64                 :            : }
      65                 :            : 
      66                 :            : /// If Phi is part of a reduction cycle of FAdd, FSub, FMul or FDiv,
      67                 :            : /// mark the ops as permitting reassociation/commuting.
      68                 :            : /// As of LLVM 4.0, FDiv is not handled by the loop vectorizer
      69                 :      25399 : static void enableUnsafeAlgebraIfReduction(PHINode *Phi, Loop *L)
      70                 :            : {
      71                 :            :     typedef SmallVector<Instruction*, 8> chainVector;
      72                 :      25399 :     chainVector chain;
      73                 :            :     Instruction *J;
      74                 :      25399 :     unsigned opcode = 0;
      75                 :      25399 :     for (Instruction *I = Phi; ; I=J) {
      76                 :      27540 :         J = NULL;
      77                 :            :         // Find the user of instruction I that is within loop L.
      78         [ +  + ]:      56968 :         for (User *UI : I->users()) { /*}*/
      79                 :      51705 :             Instruction *U = cast<Instruction>(UI);
      80         [ +  + ]:      51705 :             if (L->contains(U)) {
      81         [ +  + ]:      49817 :                 if (J) {
      82                 :            :                     LLVM_DEBUG(dbgs() << "LSL: not a reduction var because op has two internal uses: " << *I << "\n");
      83                 :      22277 :                     return;
      84                 :            :                 }
      85                 :      27540 :                 J = U;
      86                 :            :             }
      87                 :            :         }
      88         [ -  + ]:       5263 :         if (!J) {
      89                 :            :             LLVM_DEBUG(dbgs() << "LSL: chain prematurely terminated at " << *I << "\n");
      90                 :          0 :             return;
      91                 :            :         }
      92         [ +  + ]:       5263 :         if (J == Phi) {
      93                 :            :             // Found the entire chain.
      94                 :       1807 :             break;
      95                 :            :         }
      96         [ +  + ]:       3456 :         if (opcode) {
      97                 :            :             // Check that arithmetic op matches prior arithmetic ops in the chain.
      98         [ +  + ]:          4 :             if (getReduceOpcode(J, I) != opcode) {
      99                 :            :                 LLVM_DEBUG(dbgs() << "LSL: chain broke at " << *J << " because of wrong opcode\n");
     100                 :          2 :                 return;
     101                 :            :             }
     102                 :            :         }
     103                 :            :         else {
     104                 :            :             // First arithmetic op in the chain.
     105                 :       3452 :             opcode = getReduceOpcode(J, I);
     106         [ +  + ]:       3452 :             if (!opcode) {
     107                 :            :                 LLVM_DEBUG(dbgs() << "LSL: first arithmetic op in chain is uninteresting" << *J << "\n");
     108                 :       1313 :                 return;
     109                 :            :             }
     110                 :            :         }
     111                 :       2141 :         chain.push_back(J);
     112                 :       2141 :     }
     113      [ +  +  - ]:       1807 :     switch (opcode) {
     114                 :       1769 :         case Instruction::FAdd:
     115                 :       1769 :             ++AddChains;
     116                 :       1769 :             break;
     117                 :         38 :         case Instruction::FMul:
     118                 :         38 :             ++MulChains;
     119                 :         38 :             break;
     120                 :            :     }
     121                 :       1807 :     ++ReductionChains;
     122         [ +  + ]:       3616 :     for (chainVector::const_iterator K=chain.begin(); K!=chain.end(); ++K) {
     123                 :            :         LLVM_DEBUG(dbgs() << "LSL: marking " << **K << "\n");
     124                 :       1809 :         (*K)->setFast(true);
     125                 :       1809 :         ++ReductionChainLength;
     126                 :            :     }
     127                 :            : }
     128                 :            : 
     129                 :      15136 : static bool markLoopInfo(Module &M, Function *marker, function_ref<LoopInfo &(Function &)> GetLI)
     130                 :            : {
     131                 :      15136 :     bool Changed = false;
     132                 :      15136 :     std::vector<Instruction*> ToDelete;
     133         [ +  + ]:      35917 :     for (User *U : marker->users()) {
     134                 :      20781 :         ++TotalMarkedLoops;
     135                 :      20781 :         Instruction *I = cast<Instruction>(U);
     136                 :      20781 :         ToDelete.push_back(I);
     137                 :            : 
     138                 :      20781 :         LoopInfo &LI = GetLI(*I->getParent()->getParent());
     139                 :      20781 :         Loop *L = LI.getLoopFor(I->getParent());
     140                 :      20781 :         I->removeFromParent();
     141         [ +  + ]:      20781 :         if (!L)
     142                 :         77 :             continue;
     143                 :            : 
     144                 :            :         LLVM_DEBUG(dbgs() << "LSL: loopinfo marker found\n");
     145                 :      20704 :         bool simd = false;
     146                 :      20704 :         bool ivdep = false;
     147                 :      41408 :         SmallVector<Metadata *, 8> MDs;
     148                 :            : 
     149                 :      20704 :         BasicBlock *Lh = L->getHeader();
     150                 :            :         LLVM_DEBUG(dbgs() << "LSL: loop header: " << *Lh << "\n");
     151                 :            : 
     152                 :            :         // Reserve first location for self reference to the LoopID metadata node.
     153                 :      20704 :         TempMDTuple TempNode = MDNode::getTemporary(Lh->getContext(), None);
     154                 :      20704 :         MDs.push_back(TempNode.get());
     155                 :            : 
     156                 :            :         // Walk `julia.loopinfo` metadata and filter out `julia.simdloop` and `julia.ivdep`
     157         [ +  - ]:      20704 :         if (I->hasMetadataOtherThanDebugLoc()) {
     158                 :      20704 :             MDNode *JLMD= I->getMetadata("julia.loopinfo");
     159         [ +  - ]:      20704 :             if (JLMD) {
     160                 :            :                 LLVM_DEBUG(dbgs() << "LSL: has julia.loopinfo metadata with " << JLMD->getNumOperands() <<" operands\n");
     161         [ +  + ]:      41412 :                 for (unsigned i = 0, ie = JLMD->getNumOperands(); i < ie; ++i) {
     162                 :      20708 :                     Metadata *Op = JLMD->getOperand(i);
     163                 :      20708 :                     const MDString *S = dyn_cast<MDString>(Op);
     164         [ +  - ]:      20708 :                     if (S) {
     165                 :            :                         LLVM_DEBUG(dbgs() << "LSL: found " << S->getString() << "\n");
     166         [ +  - ]:      20708 :                         if (S->getString().startswith("julia")) {
     167         [ +  + ]:      20708 :                             if (S->getString().equals("julia.simdloop"))
     168                 :      20704 :                                 simd = true;
     169         [ +  + ]:      20708 :                             if (S->getString().equals("julia.ivdep"))
     170                 :          4 :                                 ivdep = true;
     171                 :      20708 :                             continue;
     172                 :            :                         }
     173                 :            :                     }
     174                 :          0 :                     MDs.push_back(Op);
     175                 :            :                 }
     176                 :            :             }
     177                 :            :         }
     178                 :            : 
     179                 :            :         LLVM_DEBUG(dbgs() << "LSL: simd: " << simd << " ivdep: " << ivdep << "\n");
     180                 :            : 
     181                 :      20704 :         MDNode *n = L->getLoopID();
     182         [ +  + ]:      20704 :         if (n) {
     183                 :            :             // Loop already has a LoopID so copy over Metadata
     184                 :            :             // original loop id is operand 0
     185         [ -  + ]:         44 :             for (unsigned i = 1, ie = n->getNumOperands(); i < ie; ++i) {
     186                 :          0 :                 Metadata *Op = n->getOperand(i);
     187                 :          0 :                 MDs.push_back(Op);
     188                 :            :             }
     189                 :            :         }
     190                 :      20704 :         MDNode *LoopID = MDNode::getDistinct(Lh->getContext(), MDs);
     191                 :            :         // Replace the temporary node with a self-reference.
     192                 :      20704 :         LoopID->replaceOperandWith(0, LoopID);
     193                 :      20704 :         L->setLoopID(LoopID);
     194         [ -  + ]:      20704 :         assert(L->getLoopID());
     195                 :            : 
     196                 :      20704 :         MDNode *m = MDNode::get(Lh->getContext(), ArrayRef<Metadata *>(LoopID));
     197                 :            : 
     198                 :            :         // If ivdep is true we assume that there is no memory dependency between loop iterations
     199                 :            :         // This is a fairly strong assumption and does often not hold true for generic code.
     200         [ +  + ]:      20704 :         if (ivdep) {
     201                 :          4 :             ++IVDepLoops;
     202                 :            :             // Mark memory references so that Loop::isAnnotatedParallel will return true for this loop.
     203         [ +  + ]:         24 :             for (BasicBlock *BB : L->blocks()) {
     204         [ +  + ]:        136 :                for (Instruction &I : *BB) {
     205         [ +  + ]:        116 :                    if (I.mayReadOrWriteMemory()) {
     206                 :         28 :                        ++IVDepInstructions;
     207                 :         28 :                        I.setMetadata(LLVMContext::MD_mem_parallel_loop_access, m);
     208                 :            :                    }
     209                 :            :                }
     210                 :            :             }
     211         [ -  + ]:          4 :             assert(L->isAnnotatedParallel());
     212                 :            :         }
     213                 :            : 
     214         [ +  - ]:      20704 :         if (simd) {
     215                 :      20704 :             ++SimdLoops;
     216                 :            :             // Mark floating-point reductions as okay to reassociate/commute.
     217         [ +  - ]:      46103 :             for (BasicBlock::iterator I = Lh->begin(), E = Lh->end(); I != E; ++I) {
     218         [ +  + ]:      46103 :                 if (PHINode *Phi = dyn_cast<PHINode>(I))
     219                 :      25399 :                     enableUnsafeAlgebraIfReduction(Phi, L);
     220                 :            :                 else
     221                 :      20704 :                     break;
     222                 :            :             }
     223                 :            :         }
     224                 :            : 
     225                 :      20704 :         Changed = true;
     226                 :            :     }
     227                 :            : 
     228         [ +  + ]:      35917 :     for (Instruction *I : ToDelete)
     229                 :      20781 :         I->deleteValue();
     230                 :      15136 :     marker->eraseFromParent();
     231                 :            : 
     232         [ -  + ]:      15136 :     assert(!verifyModule(M));
     233                 :      15136 :     return Changed;
     234                 :            : }
     235                 :            : 
     236                 :            : } // end anonymous namespace
     237                 :            : 
     238                 :            : 
     239                 :            : /// This pass should run after reduction variables have been converted to phi nodes,
     240                 :            : /// otherwise floating-point reductions might not be recognized as such and
     241                 :            : /// prevent SIMDization.
     242                 :            : 
     243                 :            : 
     244                 :          0 : PreservedAnalyses LowerSIMDLoop::run(Module &M, ModuleAnalysisManager &AM)
     245                 :            : {
     246                 :          0 :     Function *loopinfo_marker = M.getFunction("julia.loopinfo_marker");
     247                 :            : 
     248         [ #  # ]:          0 :     if (!loopinfo_marker)
     249                 :          0 :         return PreservedAnalyses::all();
     250                 :            : 
     251                 :            :     FunctionAnalysisManager &FAM =
     252                 :          0 :       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
     253                 :            : 
     254                 :          0 :     auto GetLI = [&FAM](Function &F) -> LoopInfo & {
     255                 :          0 :         return FAM.getResult<LoopAnalysis>(F);
     256                 :          0 :     };
     257                 :            : 
     258         [ #  # ]:          0 :     if (markLoopInfo(M, loopinfo_marker, GetLI)) {
     259                 :          0 :         auto preserved = PreservedAnalyses::allInSet<CFGAnalyses>();
     260                 :          0 :         preserved.preserve<LoopAnalysis>();
     261                 :          0 :         return preserved;
     262                 :            :     }
     263                 :            : 
     264                 :          0 :     return PreservedAnalyses::all();
     265                 :            : }
     266                 :            : 
     267                 :            : namespace {
     268                 :            : class LowerSIMDLoopLegacy : public ModulePass {
     269                 :            :     //LowerSIMDLoop Impl;
     270                 :            : 
     271                 :            : public:
     272                 :            :   static char ID;
     273                 :            : 
     274                 :       1530 :   LowerSIMDLoopLegacy() : ModulePass(ID) {
     275                 :       1530 :   }
     276                 :            : 
     277                 :     359224 :   bool runOnModule(Module &M) override {
     278                 :     359224 :     bool Changed = false;
     279                 :            : 
     280                 :     359224 :     Function *loopinfo_marker = M.getFunction("julia.loopinfo_marker");
     281                 :            : 
     282                 :      20781 :     auto GetLI = [this](Function &F) -> LoopInfo & {
     283                 :      20781 :         return getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
     284                 :     359224 :     };
     285                 :            : 
     286         [ +  + ]:     359224 :     if (loopinfo_marker)
     287                 :      15136 :         Changed |= markLoopInfo(M, loopinfo_marker, GetLI);
     288                 :            : 
     289                 :     359224 :     return Changed;
     290                 :            :   }
     291                 :            : 
     292                 :       1530 :   void getAnalysisUsage(AnalysisUsage &AU) const override
     293                 :            :   {
     294                 :       1530 :       ModulePass::getAnalysisUsage(AU);
     295                 :       1530 :       AU.addRequired<LoopInfoWrapperPass>();
     296                 :       1530 :       AU.addPreserved<LoopInfoWrapperPass>();
     297                 :       1530 :       AU.setPreservesCFG();
     298                 :       1530 :   }
     299                 :            : };
     300                 :            : 
     301                 :            : } // end anonymous namespace
     302                 :            : 
     303                 :            : char LowerSIMDLoopLegacy::ID = 0;
     304                 :            : 
     305                 :            : static RegisterPass<LowerSIMDLoopLegacy> X("LowerSIMDLoop", "LowerSIMDLoop Pass",
     306                 :            :                                      false /* Only looks at CFG */,
     307                 :            :                                      false /* Analysis Pass */);
     308                 :            : 
     309                 :       1530 : JL_DLLEXPORT Pass *createLowerSimdLoopPass()
     310                 :            : {
     311                 :       1530 :     return new LowerSIMDLoopLegacy();
     312                 :            : }
     313                 :            : 
     314                 :          0 : extern "C" JL_DLLEXPORT void LLVMExtraAddLowerSimdLoopPass_impl(LLVMPassManagerRef PM)
     315                 :            : {
     316                 :          0 :     unwrap(PM)->add(createLowerSimdLoopPass());
     317                 :          0 : }

Generated by: LCOV version 1.14