Branch data Line data Source code
1 : : // This file is a part of Julia. License is MIT: https://julialang.org/license 2 : : 3 : : // This pass finds floating-point operations on 16-bit (half precision) values, and replaces 4 : : // them by equivalent operations on 32-bit (single precision) values surrounded by a fpext 5 : : // and fptrunc. This ensures that the exact semantics of IEEE floating-point are preserved. 6 : : // 7 : : // Without this pass, back-ends that do not natively support half-precision (e.g. x86_64) 8 : : // similarly pattern-match half-precision operations with single-precision equivalents, but 9 : : // without truncating after every operation. Doing so breaks floating-point operations that 10 : : // assume precise semantics, such as Dekker arithmetic (as used in twiceprecision.jl). 11 : : // 12 : : // This pass is intended to run late in the pipeline, and should not be followed by 13 : : // instcombine. A run of GVN is recommended to clean-up identical conversions. 14 : : 15 : : #include "llvm-version.h" 16 : : 17 : : #include "support/dtypes.h" 18 : : #include "passes.h" 19 : : 20 : : #include <llvm/Pass.h> 21 : : #include <llvm/ADT/Statistic.h> 22 : : #include <llvm/IR/IRBuilder.h> 23 : : #include <llvm/IR/LegacyPassManager.h> 24 : : #include <llvm/IR/PassManager.h> 25 : : #include <llvm/IR/Module.h> 26 : : #include <llvm/IR/Verifier.h> 27 : : #include <llvm/Support/Debug.h> 28 : : 29 : : #define DEBUG_TYPE "demote_float16" 30 : : 31 : : using namespace llvm; 32 : : 33 : : STATISTIC(TotalChanged, "Total number of instructions changed"); 34 : : STATISTIC(TotalExt, "Total number of FPExt instructions inserted"); 35 : : STATISTIC(TotalTrunc, "Total number of FPTrunc instructions inserted"); 36 : : #define INST_STATISTIC(Opcode) STATISTIC(Opcode##Changed, "Number of " #Opcode " instructions changed") 37 : : INST_STATISTIC(FNeg); 38 : : INST_STATISTIC(FAdd); 39 : : INST_STATISTIC(FSub); 40 : : INST_STATISTIC(FMul); 41 : : INST_STATISTIC(FDiv); 42 : : INST_STATISTIC(FRem); 43 : : INST_STATISTIC(FCmp); 44 : : #undef INST_STATISTIC 45 : : 46 : : namespace { 47 : : 48 : 689527 : static bool demoteFloat16(Function &F) 49 : : { 50 : 689527 : auto &ctx = F.getContext(); 51 : 689527 : auto T_float16 = Type::getHalfTy(ctx); 52 : 689527 : auto T_float32 = Type::getFloatTy(ctx); 53 : : 54 : 1379050 : SmallVector<Instruction *, 0> erase; 55 [ + + ]: 7279390 : for (auto &BB : F) { 56 [ + + ]: 75288700 : for (auto &I : BB) { 57 [ + + ]: 68698800 : switch (I.getOpcode()) { 58 : 566628 : case Instruction::FNeg: 59 : : case Instruction::FAdd: 60 : : case Instruction::FSub: 61 : : case Instruction::FMul: 62 : : case Instruction::FDiv: 63 : : case Instruction::FRem: 64 : : case Instruction::FCmp: 65 : 566628 : break; 66 : 68132200 : default: 67 : 68137000 : continue; 68 : : } 69 : : 70 : : // skip @fastmath operations 71 : : // TODO: more fine-grained check (afn?) 72 [ + + ]: 566628 : if (I.isFast()) 73 : 4841 : continue; 74 : : 75 : 1123570 : IRBuilder<> builder(&I); 76 : : 77 : : // extend Float16 operands to Float32 78 : 561787 : bool OperandsChanged = false; 79 : 1123570 : SmallVector<Value *, 2> Operands(I.getNumOperands()); 80 [ + + ]: 1665800 : for (size_t i = 0; i < I.getNumOperands(); i++) { 81 : 1104020 : Value *Op = I.getOperand(i); 82 [ + + ]: 1104020 : if (Op->getType() == T_float16) { 83 : 30520 : ++TotalExt; 84 : 30520 : Op = builder.CreateFPExt(Op, T_float32); 85 : 30520 : OperandsChanged = true; 86 : : } 87 : 1104020 : Operands[i] = (Op); 88 : : } 89 : : 90 : : // recreate the instruction if any operands changed, 91 : : // truncating the result back to Float16 92 [ + + ]: 561787 : if (OperandsChanged) { 93 : : Value *NewI; 94 : 16116 : ++TotalChanged; 95 [ + + + + : 16116 : switch (I.getOpcode()) { + + + - ] 96 : 1712 : case Instruction::FNeg: 97 : : assert(Operands.size() == 1); 98 : 1712 : ++FNegChanged; 99 : 1712 : NewI = builder.CreateFNeg(Operands[0]); 100 : 1712 : break; 101 : 2686 : case Instruction::FAdd: 102 : : assert(Operands.size() == 2); 103 : 2686 : ++FAddChanged; 104 : 2686 : NewI = builder.CreateFAdd(Operands[0], Operands[1]); 105 : 2686 : break; 106 : 1914 : case Instruction::FSub: 107 : : assert(Operands.size() == 2); 108 : 1914 : ++FSubChanged; 109 : 1914 : NewI = builder.CreateFSub(Operands[0], Operands[1]); 110 : 1914 : break; 111 : 4141 : case Instruction::FMul: 112 : : assert(Operands.size() == 2); 113 : 4141 : ++FMulChanged; 114 : 4141 : NewI = builder.CreateFMul(Operands[0], Operands[1]); 115 : 4141 : break; 116 : 438 : case Instruction::FDiv: 117 : : assert(Operands.size() == 2); 118 : 438 : ++FDivChanged; 119 : 438 : NewI = builder.CreateFDiv(Operands[0], Operands[1]); 120 : 438 : break; 121 : 19 : case Instruction::FRem: 122 : : assert(Operands.size() == 2); 123 : 19 : ++FRemChanged; 124 : 19 : NewI = builder.CreateFRem(Operands[0], Operands[1]); 125 : 19 : break; 126 : 5206 : case Instruction::FCmp: 127 : : assert(Operands.size() == 2); 128 : 5206 : ++FCmpChanged; 129 : 10412 : NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(), 130 : 5206 : Operands[0], Operands[1]); 131 : 5206 : break; 132 : 0 : default: 133 : 0 : abort(); 134 : : } 135 : 16116 : cast<Instruction>(NewI)->copyMetadata(I); 136 : 16116 : cast<Instruction>(NewI)->copyFastMathFlags(&I); 137 [ + + ]: 16116 : if (NewI->getType() != I.getType()) { 138 : 10910 : ++TotalTrunc; 139 : 10910 : NewI = builder.CreateFPTrunc(NewI, I.getType()); 140 : : } 141 : 16116 : I.replaceAllUsesWith(NewI); 142 : 16116 : erase.push_back(&I); 143 : : } 144 : : } 145 : : } 146 : : 147 [ + + ]: 689527 : if (erase.size() > 0) { 148 [ + + ]: 17916 : for (auto V : erase) 149 : 16116 : V->eraseFromParent(); 150 : : assert(!verifyFunction(F, &errs())); 151 : 1800 : return true; 152 : : } 153 : : else 154 : 687727 : return false; 155 : : } 156 : : 157 : : } // end anonymous namespace 158 : : 159 : 0 : PreservedAnalyses DemoteFloat16::run(Function &F, FunctionAnalysisManager &AM) 160 : : { 161 [ # # ]: 0 : if (demoteFloat16(F)) { 162 : 0 : return PreservedAnalyses::allInSet<CFGAnalyses>(); 163 : : } 164 : 0 : return PreservedAnalyses::all(); 165 : : } 166 : : 167 : : namespace { 168 : : 169 : : struct DemoteFloat16Legacy : public FunctionPass { 170 : : static char ID; 171 : 1052 : DemoteFloat16Legacy() : FunctionPass(ID){}; 172 : : 173 : : private: 174 : 689527 : bool runOnFunction(Function &F) override { 175 : 689527 : return demoteFloat16(F); 176 : : } 177 : : }; 178 : : 179 : : char DemoteFloat16Legacy::ID = 0; 180 : : static RegisterPass<DemoteFloat16Legacy> 181 : : Y("DemoteFloat16", 182 : : "Demote Float16 operations to Float32 equivalents.", 183 : : false, 184 : : false); 185 : : } // end anonymous namespac 186 : : 187 : 1052 : Pass *createDemoteFloat16Pass() 188 : : { 189 : 1052 : return new DemoteFloat16Legacy(); 190 : : } 191 : : 192 : 0 : extern "C" JL_DLLEXPORT void LLVMExtraAddDemoteFloat16Pass_impl(LLVMPassManagerRef PM) 193 : : { 194 : 0 : unwrap(PM)->add(createDemoteFloat16Pass()); 195 : 0 : }