#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>

using namespace llvm;

namespace {
struct OurCSE : public FunctionPass {
  std::unordered_map<Value *, Value *> VariablesMap;
  std::vector<Instruction *> SavedInstructions;

  static char ID;
  OurCSE() : FunctionPass(ID) {}

  void mapVariables(Function &F) {
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        if (isa<LoadInst>(&I)) {
          VariablesMap[&I] = I.getOperand(0);
        }
      }
    }
  }

  bool shouldSave(Instruction &I) {
    if (isa<CallInst>(&I)) {
      return true;
    }

    if (BinaryOperator *BinaryOp = dyn_cast<BinaryOperator>(&I)) {
      if (isa<AddOperator>(BinaryOp) || isa<SubOperator>(BinaryOp) || isa<MulOperator>(BinaryOp) ||
          isa<SDivOperator>(BinaryOp)) {
        return true;
      }
    }

    return false;
  }

  bool haveTheSameType(Instruction *I1, Instruction *I2) {
    return I1->getOpcode() == I2->getOpcode();
  }

  bool isConstantInt(Value *Operand) {
    return isa<ConstantInt>(Operand);
  }

  int getValue(Value *Operand) {
    ConstantInt *Const = dyn_cast<ConstantInt>(Operand);
    return Const->getSExtValue();
  }

  bool haveTheSameOperandsForCall(Instruction *I1, Instruction *I2) {
    if (I1->getNumOperands() != I2->getNumOperands()) {
      return false;
    }

    size_t n = I1->getNumOperands();
    for (size_t i = 0; i < n - 1; i++) {
      if (isConstantInt(I1->getOperand(i))) {
        if (!isConstantInt(I2->getOperand(i))) {
          return false;
        }
        if (getValue(I1->getOperand(i)) != getValue(I2->getOperand(i))) {
          return false;
        }
      }
      else {
        if (isConstantInt(I2->getOperand(i))) {
          return false;
        }
      }

      if (VariablesMap[I1->getOperand(i)] != VariablesMap[I2->getOperand(i)]) {
        return false;
      }
    }

    return I1->getOperand(n - 1) == I2->getOperand(n - 1);
  }

  bool haveTheSameOperands(Instruction *I1, Instruction *I2) {
    if (isConstantInt(I1->getOperand(0)) && isConstantInt(I2->getOperand(0)) &&
        isConstantInt(I1->getOperand(1)) && isConstantInt(I2->getOperand(1))) {
      return false;
    }

    if ((isConstantInt(I1->getOperand(0)) && !isConstantInt(I2->getOperand(0))) ||
        (!isConstantInt(I1->getOperand(0)) && isConstantInt(I2->getOperand(0)))) {
      return false;
    }

    if ((isConstantInt(I1->getOperand(1)) && !isConstantInt(I2->getOperand(1))) ||
        (!isConstantInt(I1->getOperand(1)) && isConstantInt(I2->getOperand(1)))) {
      return false;
    }

    if (isConstantInt(I1->getOperand(0)) && isConstantInt(I2->getOperand(0))) {
      if (getValue(I1->getOperand(0)) != getValue(I2->getOperand(0))) {
        return false;
      }
      return VariablesMap[I1->getOperand(1)] == VariablesMap[I2->getOperand(1)];
    }

    if (isConstantInt(I1->getOperand(1)) && isConstantInt(I2->getOperand(1))) {
      if (getValue(I1->getOperand(1)) != getValue(I2->getOperand(1))) {
        return false;
      }
      return VariablesMap[I1->getOperand(0)] == VariablesMap[I2->getOperand(0)];
    }

    return VariablesMap[I1->getOperand(0)] == VariablesMap[I2->getOperand(0)] &&
           VariablesMap[I1->getOperand(1)] == VariablesMap[I2->getOperand(1)];
  }

  bool haveTheSameOperandsCommutative(Instruction *I1, Instruction *I2) {
    if (isConstantInt(I1->getOperand(0)) && isConstantInt(I2->getOperand(0)) &&
        isConstantInt(I1->getOperand(1)) && isConstantInt(I2->getOperand(1))) {
      return false;
    }

    Value *Operand11 = I1->getOperand(0), *Operand12 = I1->getOperand(1),
          *Operand21 = I2->getOperand(0), *Operand22 = I2->getOperand(1);

    if (isConstantInt(Operand11)) {
      if (!isConstantInt(Operand21) && !isConstantInt(Operand22)) {
        return false;
      }

      if (isConstantInt(Operand21)) {
       if (getValue(Operand11) != getValue(Operand21)) {
         return false;
       }
       else {
         return VariablesMap[Operand12] == VariablesMap[Operand22];
       }
      }

      if (isConstantInt(Operand22)) {
        if (getValue(Operand11) != getValue(Operand22)) {
          return false;
        } else {
          return VariablesMap[Operand12] == VariablesMap[Operand21];
        }
      }
    }

    if (isConstantInt(Operand12)) {
      if (!isConstantInt(Operand21) && !isConstantInt(Operand22)) {
        return false;
      }

      if (isConstantInt(Operand21)) {
        if (getValue(Operand12) != getValue(Operand21)) {
          return false;
        }
        else {
          return VariablesMap[Operand11] == VariablesMap[Operand22];
        }
      }

      if (isConstantInt(Operand22)) {
        if (getValue(Operand12) != getValue(Operand22)) {
          return false;
        }
        else {
          return VariablesMap[Operand11] == VariablesMap[Operand21];
        }
      }
    }

    return (VariablesMap[Operand11] == VariablesMap[Operand21] && VariablesMap[Operand12] == VariablesMap[Operand22]) ||
           (VariablesMap[Operand11] == VariablesMap[Operand22] && VariablesMap[Operand12] == VariablesMap[Operand21]);
  }

  bool isCommutative(Instruction &I) {
    return isa<AddOperator>(&I) || isa<MulOperator>(&I);
  }

  Instruction *alreadySaved(Instruction &I) {
    for (Instruction *SavedInstr : SavedInstructions) {
      if (haveTheSameType(&I, SavedInstr)) {
        if (isCommutative(I)) {
          if (haveTheSameOperandsCommutative(&I, SavedInstr)) {
            return SavedInstr;
          }
        }
        else if (isa<CallInst> (&I)) {
          CallInst *Call1 = dyn_cast<CallInst>(&I);
          CallInst *Call2 = dyn_cast<CallInst>(SavedInstr);
          if (Call1->getCalledFunction() != Call2->getCalledFunction()) {
            return nullptr;
          }
          if (haveTheSameOperandsForCall(&I, SavedInstr)) {
            return SavedInstr;
          }
        }
        else if (haveTheSameOperands(&I, SavedInstr)) {
          return SavedInstr;
        }
      }
    }

    return nullptr;
  }

  void maybeDelete(Instruction &I) {
    std::vector<Instruction *> InstructionsToDelete;

    for (Instruction *SavedInstr : SavedInstructions) {
      if (SavedInstr == I.getOperand(1)) {
        InstructionsToDelete.push_back(SavedInstr);
      }
      else {
        for (size_t i = 0; i < SavedInstr->getNumOperands(); i++) {
          if (I.getOperand(1) == VariablesMap[SavedInstr->getOperand(i)]) {
            InstructionsToDelete.push_back(SavedInstr);
          }
        }
      }
    }

    for (Instruction *Instr : InstructionsToDelete) {
      SavedInstructions.erase(std::find(SavedInstructions.begin(),
                                        SavedInstructions.end(),
                                        Instr));
    }
  }

  bool runOnFunction(Function &F) override {
    VariablesMap.clear();
    SavedInstructions.clear();

    mapVariables(F);

    Instruction *SavedInstruction;

    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        if (shouldSave(I)) {
          if((SavedInstruction = alreadySaved(I)) != nullptr) {
            I.replaceAllUsesWith(SavedInstruction);
          }
          else {
            SavedInstructions.push_back(&I);
          }
        }
        else if (isa<StoreInst>(&I)) {
          maybeDelete(I);
        }
      }
    }

    return true;
  }
};
}

// Definition of the static pass identifier that the OurCSE constructor hands
// to FunctionPass.
char OurCSE::ID = 0;
// Registers OurCSE under the argument string "our-cse" with the description
// "Our simple CSE implementation"; both boolean flags are passed as false.
static RegisterPass<OurCSE> X("our-cse", "Our simple CSE implementation",
                             false /* Only looks at CFG */,
                             false /* Analysis Pass */);