#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopUnrollAnalyzer.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Pass.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

namespace {
struct OurLoopUnrollingPass : public LoopPass {
  std::vector<BasicBlock *> LoopBasicBlocks;
  std::unordered_map<Value *, Value *> VariablesMap;
  Value *LoopCounter, *LoopBound;
  int BoundValue;

  static char ID; // Pass identification, replacement for typeid
  OurLoopUnrollingPass() : LoopPass(ID) {}

  void MapVariables(Loop *L)
  {
    Function *F = L->getHeader()->getParent();
    for (BasicBlock &BB : *F) {
      for (Instruction &I : BB) {
        if (isa<LoadInst>(&I)) {
          VariablesMap[&I] = I.getOperand(0);
        }
      }
    }
  }

  void findLoopCounterAndBound(Loop *L)
  {
    for (Instruction &I : *L->getHeader()) {
      if (isa<ICmpInst>(&I)) {
        LoopCounter = VariablesMap[I.getOperand(0)];
        LoopBound = VariablesMap[I.getOperand(1)];
        if (ConstantInt *ConstInt = dyn_cast<ConstantInt>(LoopBound)) {
          BoundValue = ConstInt->getSExtValue();
        }
      }
    }
  }

  void partialUnrolling(Loop *L)
  {
    std::vector<Instruction *> LoopInstructions;
    std::unordered_map<Value *, Value *> Mapping;
    std::unordered_map<Value *, Value *> LoadMapping;

    BasicBlock *LoopBody = LoopBasicBlocks[1];

    for (Instruction &I : *LoopBody) {
      if (!I.isTerminator()) {
        LoopInstructions.push_back(&I);
      }
    }

    int Factor = 3;

    Instruction *Copy;

    for (int i = 1; i <= Factor; i++) {
      for (Instruction *I : LoopInstructions) {
        Copy = I->clone();
        Copy->insertBefore(LoopBody->getTerminator());

        if (isa<LoadInst>(Copy) && Copy->getOperand(0) == LoopCounter) {
          Instruction *Add = (Instruction *) BinaryOperator::CreateAdd(Copy,
                             ConstantInt::get(Type::getInt32Ty(Copy->getContext()), i));

          Add->insertAfter(Copy);
          LoadMapping[Copy] = Add;
        }

        Mapping[I] = Copy;

        for (size_t j = 0; j < Copy->getNumOperands(); j++) {
          if (Mapping.find(Copy->getOperand(j)) != Mapping.end()) {
            Copy->setOperand(j, Mapping[Copy->getOperand(j)]);
          }
          if (LoadMapping.find(Copy->getOperand(j)) != LoadMapping.end()) {
            Copy->setOperand(j, LoadMapping[Copy->getOperand(j)]);
          }
        }
      }
    }
  }

  void unrollLoop(Loop *L)
  {
    partialUnrolling(L);
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    LoopBasicBlocks = L->getBlocksVector();
    MapVariables(L);
    findLoopCounterAndBound(L);
    unrollLoop(L);

    return true;
  }
}; // end of struct OurLoopUnrollingPass
}  // end of anonymous namespace

// The legacy pass manager identifies a pass by the address of ID; the stored
// value itself is irrelevant.
char OurLoopUnrollingPass::ID{0};

// Register with the legacy pass manager under the argument "loop-unrolling"
// (display name left empty). The pass is neither CFG-only nor an analysis.
static RegisterPass<OurLoopUnrollingPass>
    X("loop-unrolling", "", /*CFGOnly=*/false, /*is_analysis=*/false);