diff options
Diffstat (limited to 'contrib/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp')
-rw-r--r-- | contrib/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp | 374 |
1 files changed, 374 insertions, 0 deletions
diff --git a/contrib/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp b/contrib/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp new file mode 100644 index 0000000..63c3699 --- /dev/null +++ b/contrib/llvm/lib/CodeGen/ForwardControlFlowIntegrity.cpp @@ -0,0 +1,374 @@ +//===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===// +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// \brief A pass that instruments code with fast checks for indirect calls and +/// hooks for a function to check violations. +/// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "cfi" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/JumpInstrTableInfo.h" +#include "llvm/CodeGen/ForwardControlFlowIntegrity.h" +#include "llvm/CodeGen/JumpInstrTables.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +STATISTIC(NumCFIIndirectCalls, + "Number of indirect call sites rewritten by the CFI pass"); + +char ForwardControlFlowIntegrity::ID = 0; +INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi", + "Control-Flow Integrity", true, true) +INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo); +INITIALIZE_PASS_DEPENDENCY(JumpInstrTables); +INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi", + "Control-Flow Integrity", true, true) + +ModulePass *llvm::createForwardControlFlowIntegrityPass() { + return new ForwardControlFlowIntegrity(); +} + +ModulePass *llvm::createForwardControlFlowIntegrityPass( + JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, + StringRef CFIFuncName) { + return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing, + CFIFuncName); +} + +// Checks to see if a given CallSite is making an indirect call, including +// cases where the indirect call is made through a bitcast. +static bool isIndirectCall(CallSite &CS) { + if (CS.getCalledFunction()) + return false; + + // Check the value to see if it is merely a bitcast of a function. In + // this case, it will translate to a direct function call in the resulting + // assembly, so we won't treat it as an indirect call here. + const Value *V = CS.getCalledValue(); + if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { + return !(CE->isCast() && isa<Function>(CE->getOperand(0))); + } + + // Otherwise, since we know it's a call, it must be an indirect call + return true; +} + +static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning"; + +ForwardControlFlowIntegrity::ForwardControlFlowIntegrity() + : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single), + CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") { + initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); +} + +ForwardControlFlowIntegrity::ForwardControlFlowIntegrity( + JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, + std::string CFIFuncName) + : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType), + CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) { + initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); +} + +ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {} + +void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<JumpInstrTableInfo>(); + AU.addRequired<JumpInstrTables>(); +} + +void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) { + // To get the indirect calls, we iterate over all functions and iterate over + // the list of basic blocks in each. We extract a total list of indirect calls + // before modifying any of them, since our modifications will modify the list + // of basic blocks. + for (Function &F : M) { + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + CallSite CS(&I); + if (!(CS && isIndirectCall(CS))) + continue; + + Value *CalledValue = CS.getCalledValue(); + + // Don't rewrite this instruction if the indirect call is actually just + // inline assembly, since our transformation will generate an invalid + // module in that case. + if (isa<InlineAsm>(CalledValue)) + continue; + + IndirectCalls.push_back(&I); + } + } + } +} + +void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M, + CFITables &CFIT) { + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + for (Instruction *I : IndirectCalls) { + CallSite CS(I); + Value *CalledValue = CS.getCalledValue(); + + // Get the function type for this call and look it up in the tables. + Type *VTy = CalledValue->getType(); + PointerType *PTy = dyn_cast<PointerType>(VTy); + Type *EltTy = PTy->getElementType(); + FunctionType *FunTy = dyn_cast<FunctionType>(EltTy); + FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy); + ++NumCFIIndirectCalls; + Constant *JumpTableStart = nullptr; + Constant *JumpTableMask = nullptr; + Constant *JumpTableSize = nullptr; + + // Some call sites have function types that don't correspond to any + // address-taken function in the module. This happens when function pointers + // are passed in from external code. + auto it = CFIT.find(TransformedTy); + if (it == CFIT.end()) { + // In this case, make sure that the function pointer will change by + // setting the mask and the start to be 0 so that the transformed + // function is 0. + JumpTableStart = ConstantInt::get(Int64Ty, 0); + JumpTableMask = ConstantInt::get(Int64Ty, 0); + JumpTableSize = ConstantInt::get(Int64Ty, 0); + } else { + JumpTableStart = it->second.StartValue; + JumpTableMask = it->second.MaskValue; + JumpTableSize = it->second.Size; + } + + rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask, + JumpTableSize); + } + + return; +} + +bool ForwardControlFlowIntegrity::runOnModule(Module &M) { + JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>(); + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); + + // JumpInstrTableInfo stores information about the alignment of each entry. + // The alignment returned by JumpInstrTableInfo is alignment in bytes, not + // in the exponent. + ByteAlignment = JITI->entryByteAlignment(); + LogByteAlignment = llvm::Log2_64(ByteAlignment); + + // Set up tables for control-flow integrity based on information about the + // jump-instruction tables. + CFITables CFIT; + for (const auto &KV : JITI->getTables()) { + uint64_t Size = static_cast<uint64_t>(KV.second.size()); + uint64_t TableSize = NextPowerOf2(Size); + + int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment; + Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue); + Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size); + + // The base of the table is defined to be the first jumptable function in + // the table. + Function *First = KV.second.begin()->second; + Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy); + CFIT[KV.first].StartValue = JumpTableStartValue; + CFIT[KV.first].MaskValue = JumpTableMaskValue; + CFIT[KV.first].Size = JumpTableSize; + } + + if (CFIT.empty()) + return false; + + getIndirectCalls(M); + + if (!CFIEnforcing) { + addWarningFunction(M); + } + + // Update the instructions with the check and the indirect jump through our + // table. + updateIndirectCalls(M, CFIT); + + return true; +} + +void ForwardControlFlowIntegrity::addWarningFunction(Module &M) { + PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext()); + + // Get the type of the Warning Function: void (i8*, i8*), + // where the first argument is the name of the function in which the violation + // occurs, and the second is the function pointer that violates CFI. + SmallVector<Type *, 2> WarningFunArgs; + WarningFunArgs.push_back(CharPtrTy); + WarningFunArgs.push_back(CharPtrTy); + FunctionType *WarningFunTy = + FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false); + + if (!CFIFuncName.empty()) { + Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy); + if (!FailureFun) + report_fatal_error("Could not get or insert the function specified by" + " -cfi-func-name"); + } else { + // The default warning function swallows the warning and lets the call + // continue, since there's no generic way for it to print out this + // information. + Function *WarningFun = M.getFunction(cfi_failure_func_name); + if (!WarningFun) { + WarningFun = + Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage, + cfi_failure_func_name, &M); + } + + BasicBlock *Entry = + BasicBlock::Create(M.getContext(), "entry", WarningFun, 0); + ReturnInst::Create(M.getContext(), Entry); + } +} + +void ForwardControlFlowIntegrity::rewriteFunctionPointer( + Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart, + Constant *JumpTableMask, Constant *JumpTableSize) { + IRBuilder<> TempBuilder(I); + + Type *OrigFunType = FunPtr->getType(); + + BasicBlock *CurBB = cast<BasicBlock>(I->getParent()); + Function *CurF = cast<Function>(CurBB->getParent()); + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + + Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty); + Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty); + + Value *NewFunPtr = nullptr; + Value *Check = nullptr; + switch (CFIType) { + case CFIntegrity::Sub: { + // This is the subtract, mask, and add version. + // Subtract from the base. + Value *Sub = TempBuilder.CreateSub(TI, TStartInt); + + // Mask the difference to force this to be a table offset. + Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask); + + // Add it back to the base. + Value *Result = TempBuilder.CreateAdd(And, TStartInt); + + // Convert it back into a function pointer that we can call. + NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); + break; + } + case CFIntegrity::Ror: { + // This is the subtract and rotate version. + // Rotate right by the alignment value. The optimizer should recognize + // this sequence as a rotation. + + // This cast is safe, since unsigned is always a subset of uint64_t. + uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment); + Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64); + Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64); + + // Subtract from the base. + Value *Sub = TempBuilder.CreateSub(TI, TStartInt); + + // Create the equivalent of a rotate-right instruction. + Value *Shr = TempBuilder.CreateLShr(Sub, RightShift); + Value *Shl = TempBuilder.CreateShl(Sub, LeftShift); + Value *Or = TempBuilder.CreateOr(Shr, Shl); + + // Perform unsigned comparison to check for inclusion in the table. + Check = TempBuilder.CreateICmpULT(Or, JumpTableSize); + NewFunPtr = FunPtr; + break; + } + case CFIntegrity::Add: { + // This is the mask and add version. + // Mask the function pointer to turn it into an offset into the table. + Value *And = TempBuilder.CreateAnd(TI, JumpTableMask); + + // Then or this offset to the base and get the pointer value. + Value *Result = TempBuilder.CreateAdd(And, TStartInt); + + // Convert it back into a function pointer that we can call. + NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); + break; + } + } + + if (!CFIEnforcing) { + // If a check hasn't been added (in the rotation version), then check to see + // if it's the same as the original function. This check determines whether + // or not we call the CFI failure function. + if (!Check) + Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr); + BasicBlock *InvalidPtrBlock = + BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0); + BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I); + + // Remove the unconditional branch that connects the two blocks. + TerminatorInst *TermInst = CurBB->getTerminator(); + TermInst->eraseFromParent(); + + // Add a conditional branch that depends on the Check above. + BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB); + + // Call the warning function for this pointer, then continue. + Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock); + insertWarning(M, InvalidPtrBlock, BI, FunPtr); + } else { + // Modify the instruction to call this value. + CallSite CS(I); + CS.setCalledFunction(NewFunPtr); + } +} + +void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block, + Instruction *I, Value *FunPtr) { + Function *ParentFun = cast<Function>(Block->getParent()); + + // Get the function to call right before the instruction. + Function *WarningFun = nullptr; + if (CFIFuncName.empty()) { + WarningFun = M.getFunction(cfi_failure_func_name); + } else { + WarningFun = M.getFunction(CFIFuncName); + } + + assert(WarningFun && "Could not find the CFI failure function"); + + Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); + + IRBuilder<> WarningInserter(I); + // Create a mergeable GlobalVariable containing the name of the function. + Value *ParentNameGV = + WarningInserter.CreateGlobalString(ParentFun->getName()); + Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy); + Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy); + WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr); +} |