LCOV - code coverage report
Current view: top level - src - llvm-demote-float16.cpp (source / functions) Hit Total Coverage
Test: [test only] commit 0f242327d2cc9bd130497f44b6350c924185606a Lines: 73 82 89.0 %
Date: 2022-07-16 23:42:53 Functions: 4 6 66.7 %
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: 27 30 90.0 %

           Branch data     Line data    Source code
       1                 :            : // This file is a part of Julia. License is MIT: https://julialang.org/license
       2                 :            : 
       3                 :            : // This pass finds floating-point operations on 16-bit (half precision) values, and replaces
       4                 :            : // them by equivalent operations on 32-bit (single precision) values surrounded by a fpext
       5                 :            : // and fptrunc. This ensures that the exact semantics of IEEE floating-point are preserved.
       6                 :            : //
       7                 :            : // Without this pass, back-ends that do not natively support half-precision (e.g. x86_64)
       8                 :            : // similarly pattern-match half-precision operations with single-precision equivalents, but
       9                 :            : // without truncating after every operation. Doing so breaks floating-point operations that
      10                 :            : // assume precise semantics, such as Dekker arithmetic (as used in twiceprecision.jl).
      11                 :            : //
      12                 :            : // This pass is intended to run late in the pipeline, and should not be followed by
      13                 :            : // instcombine. A run of GVN is recommended to clean-up identical conversions.
      14                 :            : 
      15                 :            : #include "llvm-version.h"
      16                 :            : 
      17                 :            : #include "support/dtypes.h"
      18                 :            : #include "passes.h"
      19                 :            : 
      20                 :            : #include <llvm/Pass.h>
      21                 :            : #include <llvm/ADT/Statistic.h>
      22                 :            : #include <llvm/IR/IRBuilder.h>
      23                 :            : #include <llvm/IR/LegacyPassManager.h>
      24                 :            : #include <llvm/IR/PassManager.h>
      25                 :            : #include <llvm/IR/Module.h>
      26                 :            : #include <llvm/IR/Verifier.h>
      27                 :            : #include <llvm/Support/Debug.h>
      28                 :            : 
      29                 :            : #define DEBUG_TYPE "demote_float16"
      30                 :            : 
      31                 :            : using namespace llvm;
      32                 :            : 
      33                 :            : STATISTIC(TotalChanged, "Total number of instructions changed");
      34                 :            : STATISTIC(TotalExt, "Total number of FPExt instructions inserted");
      35                 :            : STATISTIC(TotalTrunc, "Total number of FPTrunc instructions inserted");
      36                 :            : #define INST_STATISTIC(Opcode) STATISTIC(Opcode##Changed, "Number of " #Opcode " instructions changed")
      37                 :            : INST_STATISTIC(FNeg);
      38                 :            : INST_STATISTIC(FAdd);
      39                 :            : INST_STATISTIC(FSub);
      40                 :            : INST_STATISTIC(FMul);
      41                 :            : INST_STATISTIC(FDiv);
      42                 :            : INST_STATISTIC(FRem);
      43                 :            : INST_STATISTIC(FCmp);
      44                 :            : #undef INST_STATISTIC
      45                 :            : 
      46                 :            : namespace {
      47                 :            : 
      48                 :     689527 : static bool demoteFloat16(Function &F)
      49                 :            : {
      50                 :     689527 :     auto &ctx = F.getContext();
      51                 :     689527 :     auto T_float16 = Type::getHalfTy(ctx);
      52                 :     689527 :     auto T_float32 = Type::getFloatTy(ctx);
      53                 :            : 
      54                 :    1379050 :     SmallVector<Instruction *, 0> erase;
      55         [ +  + ]:    7279390 :     for (auto &BB : F) {
      56         [ +  + ]:   75288700 :         for (auto &I : BB) {
      57         [ +  + ]:   68698800 :             switch (I.getOpcode()) {
      58                 :     566628 :             case Instruction::FNeg:
      59                 :            :             case Instruction::FAdd:
      60                 :            :             case Instruction::FSub:
      61                 :            :             case Instruction::FMul:
      62                 :            :             case Instruction::FDiv:
      63                 :            :             case Instruction::FRem:
      64                 :            :             case Instruction::FCmp:
      65                 :     566628 :                 break;
      66                 :   68132200 :             default:
      67                 :   68137000 :                 continue;
      68                 :            :             }
      69                 :            : 
      70                 :            :             // skip @fastmath operations
      71                 :            :             // TODO: more fine-grained check (afn?)
      72         [ +  + ]:     566628 :             if (I.isFast())
      73                 :       4841 :                 continue;
      74                 :            : 
      75                 :    1123570 :             IRBuilder<> builder(&I);
      76                 :            : 
      77                 :            :             // extend Float16 operands to Float32
      78                 :     561787 :             bool OperandsChanged = false;
      79                 :    1123570 :             SmallVector<Value *, 2> Operands(I.getNumOperands());
      80         [ +  + ]:    1665800 :             for (size_t i = 0; i < I.getNumOperands(); i++) {
      81                 :    1104020 :                 Value *Op = I.getOperand(i);
      82         [ +  + ]:    1104020 :                 if (Op->getType() == T_float16) {
      83                 :      30520 :                     ++TotalExt;
      84                 :      30520 :                     Op = builder.CreateFPExt(Op, T_float32);
      85                 :      30520 :                     OperandsChanged = true;
      86                 :            :                 }
      87                 :    1104020 :                 Operands[i] = (Op);
      88                 :            :             }
      89                 :            : 
      90                 :            :             // recreate the instruction if any operands changed,
      91                 :            :             // truncating the result back to Float16
      92         [ +  + ]:     561787 :             if (OperandsChanged) {
      93                 :            :                 Value *NewI;
      94                 :      16116 :                 ++TotalChanged;
      95   [ +  +  +  +  :      16116 :                 switch (I.getOpcode()) {
             +  +  +  - ]
      96                 :       1712 :                 case Instruction::FNeg:
      97                 :            :                     assert(Operands.size() == 1);
      98                 :       1712 :                     ++FNegChanged;
      99                 :       1712 :                     NewI = builder.CreateFNeg(Operands[0]);
     100                 :       1712 :                     break;
     101                 :       2686 :                 case Instruction::FAdd:
     102                 :            :                     assert(Operands.size() == 2);
     103                 :       2686 :                     ++FAddChanged;
     104                 :       2686 :                     NewI = builder.CreateFAdd(Operands[0], Operands[1]);
     105                 :       2686 :                     break;
     106                 :       1914 :                 case Instruction::FSub:
     107                 :            :                     assert(Operands.size() == 2);
     108                 :       1914 :                     ++FSubChanged;
     109                 :       1914 :                     NewI = builder.CreateFSub(Operands[0], Operands[1]);
     110                 :       1914 :                     break;
     111                 :       4141 :                 case Instruction::FMul:
     112                 :            :                     assert(Operands.size() == 2);
     113                 :       4141 :                     ++FMulChanged;
     114                 :       4141 :                     NewI = builder.CreateFMul(Operands[0], Operands[1]);
     115                 :       4141 :                     break;
     116                 :        438 :                 case Instruction::FDiv:
     117                 :            :                     assert(Operands.size() == 2);
     118                 :        438 :                     ++FDivChanged;
     119                 :        438 :                     NewI = builder.CreateFDiv(Operands[0], Operands[1]);
     120                 :        438 :                     break;
     121                 :         19 :                 case Instruction::FRem:
     122                 :            :                     assert(Operands.size() == 2);
     123                 :         19 :                     ++FRemChanged;
     124                 :         19 :                     NewI = builder.CreateFRem(Operands[0], Operands[1]);
     125                 :         19 :                     break;
     126                 :       5206 :                 case Instruction::FCmp:
     127                 :            :                     assert(Operands.size() == 2);
     128                 :       5206 :                     ++FCmpChanged;
     129                 :      10412 :                     NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
     130                 :       5206 :                                               Operands[0], Operands[1]);
     131                 :       5206 :                     break;
     132                 :          0 :                 default:
     133                 :          0 :                     abort();
     134                 :            :                 }
     135                 :      16116 :                 cast<Instruction>(NewI)->copyMetadata(I);
     136                 :      16116 :                 cast<Instruction>(NewI)->copyFastMathFlags(&I);
     137         [ +  + ]:      16116 :                 if (NewI->getType() != I.getType()) {
     138                 :      10910 :                     ++TotalTrunc;
     139                 :      10910 :                     NewI = builder.CreateFPTrunc(NewI, I.getType());
     140                 :            :                 }
     141                 :      16116 :                 I.replaceAllUsesWith(NewI);
     142                 :      16116 :                 erase.push_back(&I);
     143                 :            :             }
     144                 :            :         }
     145                 :            :     }
     146                 :            : 
     147         [ +  + ]:     689527 :     if (erase.size() > 0) {
     148         [ +  + ]:      17916 :         for (auto V : erase)
     149                 :      16116 :             V->eraseFromParent();
     150                 :            :         assert(!verifyFunction(F, &errs()));
     151                 :       1800 :         return true;
     152                 :            :     }
     153                 :            :     else
     154                 :     687727 :         return false;
     155                 :            : }
     156                 :            : 
     157                 :            : } // end anonymous namespace
     158                 :            : 
     159                 :          0 : PreservedAnalyses DemoteFloat16::run(Function &F, FunctionAnalysisManager &AM)
     160                 :            : {
     161         [ #  # ]:          0 :     if (demoteFloat16(F)) {
     162                 :          0 :         return PreservedAnalyses::allInSet<CFGAnalyses>();
     163                 :            :     }
     164                 :          0 :     return PreservedAnalyses::all();
     165                 :            : }
     166                 :            : 
     167                 :            : namespace {
     168                 :            : 
     169                 :            : struct DemoteFloat16Legacy : public FunctionPass {
     170                 :            :     static char ID;
     171                 :       1052 :     DemoteFloat16Legacy() : FunctionPass(ID){};
     172                 :            : 
     173                 :            : private:
     174                 :     689527 :     bool runOnFunction(Function &F) override {
     175                 :     689527 :         return demoteFloat16(F);
     176                 :            :     }
     177                 :            : };
     178                 :            : 
     179                 :            : char DemoteFloat16Legacy::ID = 0;
     180                 :            : static RegisterPass<DemoteFloat16Legacy>
     181                 :            :         Y("DemoteFloat16",
     182                 :            :           "Demote Float16 operations to Float32 equivalents.",
     183                 :            :           false,
     184                 :            :           false);
     185                 :            : } // end anonymous namespac
     186                 :            : 
     187                 :       1052 : Pass *createDemoteFloat16Pass()
     188                 :            : {
     189                 :       1052 :     return new DemoteFloat16Legacy();
     190                 :            : }
     191                 :            : 
     192                 :          0 : extern "C" JL_DLLEXPORT void LLVMExtraAddDemoteFloat16Pass_impl(LLVMPassManagerRef PM)
     193                 :            : {
     194                 :          0 :     unwrap(PM)->add(createDemoteFloat16Pass());
     195                 :          0 : }

Generated by: LCOV version 1.14