Branch data Line data Source code
1 : : // This file is a part of Julia. License is MIT: https://julialang.org/license 2 : : 3 : : // This pass finds floating-point operations on 16-bit (half precision) values, and replaces 4 : : // them by equivalent operations on 32-bit (single precision) values surrounded by a fpext 5 : : // and fptrunc. This ensures that the exact semantics of IEEE floating-point are preserved. 6 : : // 7 : : // Without this pass, back-ends that do not natively support half-precision (e.g. x86_64) 8 : : // similarly pattern-match half-precision operations with single-precision equivalents, but 9 : : // without truncating after every operation. Doing so breaks floating-point operations that 10 : : // assume precise semantics, such as Dekker arithmetic (as used in twiceprecision.jl). 11 : : // 12 : : // This pass is intended to run late in the pipeline, and should not be followed by 13 : : // instcombine. A run of GVN is recommended to clean-up identical conversions. 14 : : 15 : : #include "llvm-version.h" 16 : : 17 : : #include "support/dtypes.h" 18 : : #include "passes.h" 19 : : 20 : : #include <llvm/Pass.h> 21 : : #include <llvm/ADT/Statistic.h> 22 : : #include <llvm/IR/IRBuilder.h> 23 : : #include <llvm/IR/LegacyPassManager.h> 24 : : #include <llvm/IR/PassManager.h> 25 : : #include <llvm/IR/Module.h> 26 : : #include <llvm/IR/Verifier.h> 27 : : #include <llvm/Support/Debug.h> 28 : : 29 : : #define DEBUG_TYPE "demote_float16" 30 : : 31 : : using namespace llvm; 32 : : 33 : : STATISTIC(TotalChanged, "Total number of instructions changed"); 34 : : STATISTIC(TotalExt, "Total number of FPExt instructions inserted"); 35 : : STATISTIC(TotalTrunc, "Total number of FPTrunc instructions inserted"); 36 : : #define INST_STATISTIC(Opcode) STATISTIC(Opcode##Changed, "Number of " #Opcode " instructions changed") 37 : : INST_STATISTIC(FNeg); 38 : : INST_STATISTIC(FAdd); 39 : : INST_STATISTIC(FSub); 40 : : INST_STATISTIC(FMul); 41 : : INST_STATISTIC(FDiv); 42 : : INST_STATISTIC(FRem); 43 : : INST_STATISTIC(FCmp); 44 : : #undef INST_STATISTIC 45 : : 46 : : namespace { 47 : : 48 : 137811 : static bool demoteFloat16(Function &F) 49 : : { 50 : 137811 : auto &ctx = F.getContext(); 51 : 137811 : auto T_float16 = Type::getHalfTy(ctx); 52 : 137811 : auto T_float32 = Type::getFloatTy(ctx); 53 : : 54 : 275622 : SmallVector<Instruction *, 0> erase; 55 [ + + ]: 2012080 : for (auto &BB : F) { 56 [ + + ]: 27951800 : for (auto &I : BB) { 57 [ + + ]: 26077500 : switch (I.getOpcode()) { 58 : 7948 : case Instruction::FNeg: 59 : : case Instruction::FAdd: 60 : : case Instruction::FSub: 61 : : case Instruction::FMul: 62 : : case Instruction::FDiv: 63 : : case Instruction::FRem: 64 : : case Instruction::FCmp: 65 : 7948 : break; 66 : 26069500 : default: 67 : 26069500 : continue; 68 : : } 69 : : 70 : : // skip @fastmath operations 71 : : // TODO: more fine-grained check (afn?) 72 [ - + ]: 7948 : if (I.isFast()) 73 : 0 : continue; 74 : : 75 : 15896 : IRBuilder<> builder(&I); 76 : : 77 : : // extend Float16 operands to Float32 78 : 7948 : bool OperandsChanged = false; 79 : 15896 : SmallVector<Value *, 2> Operands(I.getNumOperands()); 80 [ + + ]: 23706 : for (size_t i = 0; i < I.getNumOperands(); i++) { 81 : 15758 : Value *Op = I.getOperand(i); 82 [ + + ]: 15758 : if (Op->getType() == T_float16) { 83 : 40 : ++TotalExt; 84 : 40 : Op = builder.CreateFPExt(Op, T_float32); 85 : 40 : OperandsChanged = true; 86 : : } 87 : 15758 : Operands[i] = (Op); 88 : : } 89 : : 90 : : // recreate the instruction if any operands changed, 91 : : // truncating the result back to Float16 92 [ + + ]: 7948 : if (OperandsChanged) { 93 : : Value *NewI; 94 : 20 : ++TotalChanged; 95 [ - + + - : 20 : switch (I.getOpcode()) { - - + - ] 96 : 0 : case Instruction::FNeg: 97 : : assert(Operands.size() == 1); 98 : 0 : ++FNegChanged; 99 : 0 : NewI = builder.CreateFNeg(Operands[0]); 100 : 0 : break; 101 : 4 : case Instruction::FAdd: 102 : : assert(Operands.size() == 2); 103 : 4 : ++FAddChanged; 104 : 4 : NewI = builder.CreateFAdd(Operands[0], Operands[1]); 105 : 4 : break; 106 : 6 : case Instruction::FSub: 107 : : assert(Operands.size() == 2); 108 : 6 : ++FSubChanged; 109 : 6 : NewI = builder.CreateFSub(Operands[0], Operands[1]); 110 : 6 : break; 111 : 0 : case Instruction::FMul: 112 : : assert(Operands.size() == 2); 113 : 0 : ++FMulChanged; 114 : 0 : NewI = builder.CreateFMul(Operands[0], Operands[1]); 115 : 0 : break; 116 : 0 : case Instruction::FDiv: 117 : : assert(Operands.size() == 2); 118 : 0 : ++FDivChanged; 119 : 0 : NewI = builder.CreateFDiv(Operands[0], Operands[1]); 120 : 0 : break; 121 : 0 : case Instruction::FRem: 122 : : assert(Operands.size() == 2); 123 : 0 : ++FRemChanged; 124 : 0 : NewI = builder.CreateFRem(Operands[0], Operands[1]); 125 : 0 : break; 126 : 10 : case Instruction::FCmp: 127 : : assert(Operands.size() == 2); 128 : 10 : ++FCmpChanged; 129 : 20 : NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(), 130 : 10 : Operands[0], Operands[1]); 131 : 10 : break; 132 : 0 : default: 133 : 0 : abort(); 134 : : } 135 : 20 : cast<Instruction>(NewI)->copyMetadata(I); 136 : 20 : cast<Instruction>(NewI)->copyFastMathFlags(&I); 137 [ + + ]: 20 : if (NewI->getType() != I.getType()) { 138 : 10 : ++TotalTrunc; 139 : 10 : NewI = builder.CreateFPTrunc(NewI, I.getType()); 140 : : } 141 : 20 : I.replaceAllUsesWith(NewI); 142 : 20 : erase.push_back(&I); 143 : : } 144 : : } 145 : : } 146 : : 147 [ + + ]: 137811 : if (erase.size() > 0) { 148 [ + + ]: 30 : for (auto V : erase) 149 : 20 : V->eraseFromParent(); 150 : : assert(!verifyFunction(F, &errs())); 151 : 10 : return true; 152 : : } 153 : : else 154 : 137801 : return false; 155 : : } 156 : : 157 : : } // end anonymous namespace 158 : : 159 : 0 : PreservedAnalyses DemoteFloat16::run(Function &F, FunctionAnalysisManager &AM) 160 : : { 161 [ # # ]: 0 : if (demoteFloat16(F)) { 162 : 0 : return PreservedAnalyses::allInSet<CFGAnalyses>(); 163 : : } 164 : 0 : return PreservedAnalyses::all(); 165 : : } 166 : : 167 : : namespace { 168 : : 169 : : struct DemoteFloat16Legacy : public FunctionPass { 170 : : static char ID; 171 : 17 : DemoteFloat16Legacy() : FunctionPass(ID){}; 172 : : 173 : : private: 174 : 137811 : bool runOnFunction(Function &F) override { 175 : 137811 : return demoteFloat16(F); 176 : : } 177 : : }; 178 : : 179 : : char DemoteFloat16Legacy::ID = 0; 180 : : static RegisterPass<DemoteFloat16Legacy> 181 : : Y("DemoteFloat16", 182 : : "Demote Float16 operations to Float32 equivalents.", 183 : : false, 184 : : false); 185 : : } // end anonymous namespac 186 : : 187 : 17 : Pass *createDemoteFloat16Pass() 188 : : { 189 : 17 : return new DemoteFloat16Legacy(); 190 : : } 191 : : 192 : 0 : extern "C" JL_DLLEXPORT void LLVMExtraAddDemoteFloat16Pass_impl(LLVMPassManagerRef PM) 193 : : { 194 : 0 : unwrap(PM)->add(createDemoteFloat16Pass()); 195 : 0 : }