#include "llvm/Pass.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/IRBuilder.h"
#include "ConstantPropagationInstruction.h"
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using namespace llvm;

namespace {
struct OurConstantPropagationPass : public FunctionPass {
  std::vector<Value *> Variables;
  std::vector<ConstantPropagationInstruction *> Instructions;

  static char ID;
  OurConstantPropagationPass() : FunctionPass(ID) {}

  void createConstantPropagationInstructions(Function &F) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        ConstantPropagationInstruction *Instr = new ConstantPropagationInstruction(&I);
        Instructions.push_back(Instr);

        Instruction *Predecessor = I.getPrevNonDebugInstruction();
        if (Predecessor == nullptr) {
          for (BasicBlock *BBPredecessor : predecessors(&BB)) {
            Instr->addPredecessor(*std::find_if(Instructions.begin(), Instructions.end(),
            [BBPredecessor](auto &CPI){
                                return CPI->getInstruction() == BBPredecessor->getTerminator(); }));
          }
        }
        else {
          Instr->addPredecessor(Instructions[Instructions.size() - 2]);
        }
      }
    }
  }

  void findVariables(Function &F) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        if (isa<AllocaInst>(&I)) {
          Variables.push_back(&I);
        }
      }
    }
  }

  void setVariables()
  {
    for (ConstantPropagationInstruction *CPI : Instructions) {
      CPI->addVariables(Variables);
    }
  }

  bool checkRuleOne(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    for (ConstantPropagationInstruction *Predecessor : CPI->getPredecessors()) {
      if (Predecessor->getStatusAfter(Variable) == Top) {
        return CPI->getStatusBefore(Variable) == Top;
      }
    }

    return true;
  }

  void applyRuleOne(ConstantPropagationInstruction *CPI, Value *Variable)
  {
      CPI->setStatusBefore(Variable, Top);
  }

  bool checkRuleTwo(ConstantPropagationInstruction *CPI, Value *Variable) {
    std::unordered_set<int> Values;

    for (ConstantPropagationInstruction *Predecessor : CPI->getPredecessors()) {
      if (Predecessor->getStatusAfter(Variable) == Const) {
        Values.insert(Predecessor->getValueAfter(Variable));
      }
    }

    if (Values.size() > 1) {
      return CPI->getStatusBefore(Variable) == Top;
    }

    return true;
  }

  void applyRuleTwo(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    CPI->setStatusBefore(Variable, Top);
  }

  bool checkRuleThree(ConstantPropagationInstruction *CPI, Value *Variable) {
    std::unordered_set<int> Values;

    for (ConstantPropagationInstruction *Predecessor : CPI->getPredecessors()) {
      if (Predecessor->getStatusAfter(Variable) == Const) {
        Values.insert(Predecessor->getValueAfter(Variable));
      }
      else if (Predecessor->getStatusAfter(Variable) == Top) {
        return true;
      }
    }

    if (Values.size() == 1) {
      return CPI->getStatusBefore(Variable) == Const && CPI->getValueBefore(Variable) == *Values.begin();
    }

    return true;
  }

  void applyRuleThree(ConstantPropagationInstruction *CPI, Value *Variable, int Value)
  {
    CPI->setStatusBefore(Variable, Const, Value);
  }

  bool checkRuleFour(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    for (ConstantPropagationInstruction *Predecessor : CPI->getPredecessors()) {
      if (Predecessor->getStatusAfter(Variable) == Top || Predecessor->getStatusAfter(Variable) == Const) {
        return true;
      }
    }

    if (CPI->getPredecessors().size() == 0) {
      return true;
    }

    return CPI->getStatusBefore(Variable) == Bottom;
  }

  void applyRuleFour(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    CPI->setStatusBefore(Variable, Bottom);
  }

  bool checkRuleFive(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    if (CPI->getStatusBefore(Variable) == Bottom) {
      return CPI->getStatusAfter(Variable) == Bottom;
    }

    return true;
  }

  void applyRuleFive(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    CPI->setStatusAfter(Variable, Bottom);
  }

  bool checkRuleSix(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    if (isa<StoreInst>(CPI->getInstruction()) && CPI->getInstruction()->getOperand(1) == Variable) {
      if (ConstantInt *ConstInt = dyn_cast<ConstantInt>(CPI->getInstruction()->getOperand(0))) {
        return CPI->getStatusAfter(Variable) == Const &&
               CPI->getValueAfter(Variable) == ConstInt->getSExtValue();
      }
    }

    return true;
  }

  void applyRuleSix(ConstantPropagationInstruction *CPI, Value *Variable, int Value)
  {
    CPI->setStatusAfter(Variable, Const, Value);
  }

  bool checkRuleSeven(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    if (isa<StoreInst>(CPI->getInstruction()) && CPI->getInstruction()->getOperand(1) == Variable) {
      if (!isa<ConstantInt>(CPI->getInstruction()->getOperand(0))) {
        return CPI->getStatusAfter(Variable) == Top;
      }
    }

    return true;
  }

  void applyRuleSeven(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    CPI->setStatusAfter(Variable, Top);
  }

  bool checkRuleEight(ConstantPropagationInstruction *CPI, Value *Variable) {
    if (isa<StoreInst>(CPI->getInstruction())) {
      if (CPI->getInstruction()->getOperand(1) == Variable) {
        return true;
      }
    }

    return CPI->getStatusBefore(Variable) == CPI->getStatusAfter(Variable);
  }

  void applyRuleEight(ConstantPropagationInstruction *CPI, Value *Variable)
  {
    CPI->setStatusAfter(Variable, CPI->getStatusBefore(Variable), CPI->getValueBefore(Variable));
  }

  void setStatusForFirstInstruction()
  {
    for (Value *Variable : Variables) {
      Instructions.front()->setStatusBefore(Variable, Top);
    }
  }

  void propagateVariable(Value *Variable)
  {
    bool RuleApplied;

    while (true) {
      RuleApplied = false;

      for (ConstantPropagationInstruction *CPI : Instructions) {
        if (!checkRuleOne(CPI, Variable)) {
          applyRuleOne(CPI, Variable);
          RuleApplied = true;
        } else if (!checkRuleTwo(CPI, Variable)) {
          applyRuleTwo(CPI, Variable);
          RuleApplied = true;
        } else if (!checkRuleThree(CPI, Variable)) {
          int Value;
          for (ConstantPropagationInstruction *Predecessor :
               CPI->getPredecessors()) {
            if (Predecessor->getStatusAfter(Variable) == Const) {
              Value = Predecessor->getValueAfter(Variable);
              break;
            }
          }
          applyRuleThree(CPI, Variable, Value);
          RuleApplied = true;
        } else if (!checkRuleFour(CPI, Variable)) {
          applyRuleFour(CPI, Variable);
          RuleApplied = true;
        } else if (!checkRuleFive(CPI, Variable)) {
          applyRuleFive(CPI, Variable);
          RuleApplied = true;
        } else if (!checkRuleSix(CPI, Variable)) {
          ConstantInt *ConstInt =
              dyn_cast<ConstantInt>(CPI->getInstruction()->getOperand(0));
          applyRuleSix(CPI, Variable, ConstInt->getSExtValue());
          RuleApplied = true;
        } else if (!checkRuleSeven(CPI, Variable)) {
          applyRuleSeven(CPI, Variable);
          RuleApplied = true;
        } else if (!checkRuleEight(CPI, Variable)) {
          applyRuleEight(CPI, Variable);
          RuleApplied = true;
        }
      }

      if (!RuleApplied) {
        break;
      }
    }
  }

  void runAlgorithm()
  {
    for (Value *Variable : Variables) {
      propagateVariable(Variable);
    }
  }

  void modifyIR()
  {
    std::unordered_map<Value *, Value *>VariablesMap;

    for (ConstantPropagationInstruction *CPI : Instructions) {
      if (isa<LoadInst>(CPI->getInstruction())) {
        VariablesMap[CPI->getInstruction()] = CPI->getInstruction()->getOperand(0);
      }
    }

    for (ConstantPropagationInstruction *CPI : Instructions) {

      if (isa<StoreInst>(CPI->getInstruction())) {
        Value *Operand = VariablesMap[CPI->getInstruction()->getOperand(0)];
        if (CPI->getStatusBefore(Operand) == Const) {
          int Value = CPI->getValueBefore(Operand);
          ConstantInt *ConstInt = ConstantInt::get(Type::getInt32Ty(CPI->getInstruction()->getContext()), Value);
          CPI->getInstruction()->getOperand(0)->replaceAllUsesWith(ConstInt);
        }
      }
      else if (isa<BinaryOperator>(CPI->getInstruction()) || isa<ICmpInst>(CPI->getInstruction())) {
        Value *Lhs = CPI->getInstruction()->getOperand(0), *Rhs = CPI->getInstruction()->getOperand(1);

        if (VariablesMap[Lhs] != nullptr && CPI->getStatusBefore(VariablesMap[Lhs]) == Const) {
          Lhs->replaceAllUsesWith(ConstantInt::get(Type::getInt32Ty(CPI->getInstruction()->getContext()),
              CPI->getValueBefore(VariablesMap[Lhs])));
        }

        if (VariablesMap[Rhs] != nullptr && CPI->getStatusBefore(VariablesMap[Rhs]) == Const) {
          Rhs->replaceAllUsesWith(ConstantInt::get(Type::getInt32Ty(CPI->getInstruction()->getContext()),
                                                   CPI->getValueBefore(VariablesMap[Rhs])));
        }
      }
    }
  }

  bool runOnFunction(Function &F) override {
    createConstantPropagationInstructions(F);
    findVariables(F);
    setVariables();
    setStatusForFirstInstruction();
    runAlgorithm();
    modifyIR();

    return true;
  }
};
}

// Definition of the static pass identifier declared inside the pass struct.
char OurConstantPropagationPass::ID = 0;

// Registers the pass under the command-line argument "propagation".
static RegisterPass<OurConstantPropagationPass>
    X("propagation", "Our simple constant propagation pass",
      /*CFGOnly=*/false,
      /*is_analysis=*/false);