#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include <vector>

using namespace llvm;

namespace {
struct OurStrengthReduction : public FunctionPass {
  std::vector<Instruction *> InstructionsToRemove;

  static char ID;
  OurStrengthReduction() : FunctionPass(ID) {}

  bool isConstInt(Value *Operand)
  {
    return isa<ConstantInt>(Operand);
  }

  int getValueFromConstInt(Value *Operand)
  {
    ConstantInt *ConstInt = dyn_cast<ConstantInt>(Operand);
    return ConstInt->getSExtValue();
  }

  bool isPowerOfTwo(int value)
  {
    return (value & std::numeric_limits<int>::max()) == value;
  }

  int powerOfTwo(int value)
  {
    int mask = 1;
    int power = 0;

    while(true) {
      if ((value & mask) == value) {
        return power;
      }
      mask <<= 1;
      power++;
    }
  }

  bool runOnFunction(Function &F) override {
    for (BasicBlock &BB : F) {
      for (Instruction &Instr : BB) {
        if (BinaryOperator *BinaryOp = dyn_cast<BinaryOperator>(&Instr)) {
          Value *LeftOperand = BinaryOp->getOperand(0);
          Value *RightOperand = BinaryOp->getOperand(1);
          IRBuilder Builder(Instr.getContext());
          if (MulOperator *Mul = dyn_cast<MulOperator>(BinaryOp)) {
            (void) Mul;
            if (isConstInt(LeftOperand) && isPowerOfTwo(getValueFromConstInt(LeftOperand)) && !isConstInt(RightOperand)) {
              int power = powerOfTwo(getValueFromConstInt(LeftOperand));
              Instruction *LeftShift = (Instruction *) Builder.CreateShl(RightOperand, power);
              LeftShift->insertAfter(&Instr);
              Instr.replaceAllUsesWith(LeftShift);
              InstructionsToRemove.push_back(&Instr);
            }
            else if (isConstInt(RightOperand) && isPowerOfTwo(getValueFromConstInt(RightOperand)) && !isConstInt(LeftOperand)) {
              int power = powerOfTwo(getValueFromConstInt(RightOperand));
              Instruction *LeftShift = (Instruction *) Builder.CreateShl(LeftOperand, power);
              LeftShift->insertAfter(&Instr);
              Instr.replaceAllUsesWith(LeftShift);
              InstructionsToRemove.push_back(&Instr);
            }
          }
          else if (std::string(Instr.getOpcodeName()) == "srem" && isConstInt(RightOperand) &&
                   !isConstInt(LeftOperand) && isPowerOfTwo(getValueFromConstInt(RightOperand))) {
              Instruction *And = (Instruction *) Builder.CreateAnd(LeftOperand, getValueFromConstInt(RightOperand) - 1);
              And->insertAfter(&Instr);
              Instr.replaceAllUsesWith(And);
              InstructionsToRemove.push_back(&Instr);
          }
        }
      }
    }

    for (Instruction *Instr : InstructionsToRemove) {
      Instr->eraseFromParent();
    }

    return true;
  }
};
}

// Out-of-line definition of the pass's static ID member. Per LLVM's pass
// documentation, the legacy pass manager keys passes on this object's
// address; the stored value itself is conventional.
char OurStrengthReduction::ID = 0;

// Register the pass with the legacy pass registry under the command-line
// argument "our-s-r". It is neither CFG-only nor an analysis pass.
// NOTE(review): "Documentation" looks like a placeholder description; confirm
// whether a real pass name was intended.
static RegisterPass<OurStrengthReduction> X(
    "our-s-r", "Documentation",
    /*CFGOnly=*/false,
    /*is_analysis=*/false);