#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"

#include <vector>

using namespace llvm;

namespace {
struct OurSimpleStrengthReductionPass : public FunctionPass {
  std::vector<Instruction *>InstructionsToRemove;

  static char ID;
  OurSimpleStrengthReductionPass() : FunctionPass(ID) {}

  bool isConstantInt(Value *Operand)
  {
    return isa<ConstantInt>(Operand);
  }

  int getConstantValue(Value *Operand)
  {
    ConstantInt *Const = dyn_cast<ConstantInt>(Operand);
    return Const->getSExtValue();
  }

  bool isPowerOfTwo(int value)
  {
    return (value & std::numeric_limits<int>::max()) == value;
  }

  int getPowerOfTwo(int value) {
    int power = 0;
    int mask = 1;

    while (true) {
      if (mask & value) {
        return power;
      }
      power++;
      mask <<= 1;
    }
  }

  bool runOnFunction(Function &F) override {
    for (BasicBlock &BB : F) {
      for (Instruction &Instr : BB) {
        if (BinaryOperator *BinaryOp = dyn_cast<BinaryOperator>(&Instr)) {
          Value *LeftOperand = BinaryOp->getOperand(0);
          Value *RightOperand = BinaryOp->getOperand(1);
          IRBuilder<> Builder(Instr.getContext());

          if (isa<MulOperator>(BinaryOp)) {
            if (isConstantInt(LeftOperand) && !isConstantInt(RightOperand) && isPowerOfTwo(getConstantValue(LeftOperand))) {
              int power = getPowerOfTwo(getConstantValue(LeftOperand));
              Instruction *Shift = (Instruction *)Builder.CreateShl(RightOperand, power);
              Shift->insertAfter(&Instr);
              Instr.replaceAllUsesWith(Shift);
              InstructionsToRemove.push_back(&Instr);
            }
            else if (isConstantInt(RightOperand) && !isConstantInt(LeftOperand) && isPowerOfTwo(getConstantValue(RightOperand))) {
              int power = getPowerOfTwo(getConstantValue(RightOperand));
              Instruction *Shift = (Instruction *)Builder.CreateShl(LeftOperand, power);
              Shift->insertAfter(&Instr);
              Instr.replaceAllUsesWith(Shift);
              InstructionsToRemove.push_back(&Instr);
            }
          }
          else if (std::string(Instr.getOpcodeName()) == "srem") {
            if (!isConstantInt(LeftOperand) && isConstantInt(RightOperand) && isPowerOfTwo(getConstantValue(RightOperand))) {
              Instruction *And = (Instruction *) Builder.CreateAnd(LeftOperand, getConstantValue(RightOperand) - 1);
              And->insertAfter(&Instr);
              Instr.replaceAllUsesWith(And);
              InstructionsToRemove.push_back(&Instr);
            }
          }
        }
      }
    }

    for (Instruction *Instr : InstructionsToRemove) {
      Instr->eraseFromParent();
    }
    return true;
  }
};
}

// Out-of-line definition of the pass's static identifier.
char OurSimpleStrengthReductionPass::ID = 0;

// Registers the pass under the command-line argument "our-s-r" with the given
// description; it is neither CFG-only nor an analysis pass.
static RegisterPass<OurSimpleStrengthReductionPass>
    X(/*PassArg=*/"our-s-r",
      /*Name=*/"Our simple strength reduction",
      /*CFGOnly=*/false,
      /*is_analysis=*/false);