#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/LoopUnrollAnalyzer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using namespace llvm;

namespace {
struct OurLoopFissionPass : public LoopPass {
  std::vector<BasicBlock *> LoopBasicBlocks;
  
  static char ID; // Pass identification, replacement for typeid
  OurLoopFissionPass() : LoopPass(ID) {}

  BasicBlock *copyLoop(Loop *L)
  {
    BasicBlock *Exit = L->getExitBlock();
    std::vector<BasicBlock *> LoopBasicBlocksCopy;
    std::unordered_map<Value *, Value *> Mapping;
    std::unordered_map<BasicBlock *, BasicBlock *> BasicBlocksMapping;
    Instruction *Clone;
    IRBuilder<> Builder(Exit);
    BasicBlock *NewBasicBlock;

    for (BasicBlock *BB : LoopBasicBlocks) {
      NewBasicBlock = BasicBlock::Create(Exit->getContext(), "", Exit->getParent(), Exit);
      LoopBasicBlocksCopy.push_back(NewBasicBlock);
      BasicBlocksMapping[BB] = NewBasicBlock;
    }

    for (BasicBlock *BB : LoopBasicBlocks) {
      NewBasicBlock = BasicBlocksMapping[BB];
      Builder.SetInsertPoint(NewBasicBlock);
      for (Instruction &I : *BB) {
        Clone = I.clone();
        Mapping[&I] = Clone;
        Builder.Insert(Clone);

        for (size_t i = 0; i < Clone->getNumOperands(); i++) {
          if (Mapping.find(Clone->getOperand(i)) != Mapping.end()) {
            Clone->setOperand(i, Mapping[Clone->getOperand(i)]);
          }
        }
      }
    }

    for (BasicBlock *BB : LoopBasicBlocksCopy) {
      for (size_t i = 0; i < BB->getTerminator()->getNumSuccessors(); i++) {
        BasicBlock *Successor = BB->getTerminator()->getSuccessor(i);
        if (BasicBlocksMapping.find(Successor) != BasicBlocksMapping.end()) {
          BB->getTerminator()->setSuccessor(i, BasicBlocksMapping[Successor]);
        }
      }
    }

    std::unordered_set<BasicBlock *> BlocksToDelete;
    BasicBlock *BlockToStart = findIfBasicBlock(LoopBasicBlocksCopy, true);
    BasicBlock *BlockToStop = findIfBasicBlock(LoopBasicBlocksCopy, false);
    deleteAllBlocksFrom(BlockToStart, BlockToStop, BlocksToDelete);
    LoopBasicBlocksCopy.front()->getTerminator()->setSuccessor(0, BlockToStop);

    for (BasicBlock *BB : BlocksToDelete) {
      BB->eraseFromParent();
    }
    return LoopBasicBlocksCopy.front();
  }

  void loopFission(Loop *L)
  {
    BasicBlock *LoopCopy = copyLoop(L);
    LoopBasicBlocks.front()->getTerminator()->setSuccessor(1, LoopCopy);
  }

  BasicBlock *findIfBasicBlock(std::vector<BasicBlock *> &LoopBasicBlocks, bool findFirst)
  {
    BasicBlock *LastBranchBlock = nullptr;

    for (size_t i = 1; i < LoopBasicBlocks.size(); i++) {
      for (Instruction &I : *LoopBasicBlocks[i]) {
        if (isa<ICmpInst>(&I)) {
          if (findFirst) {
            return LoopBasicBlocks[i];
          }
          LastBranchBlock = LoopBasicBlocks[i];
        }
      }
    }

    return LastBranchBlock;
  }

  void deleteAllBlocksFrom(BasicBlock *Current, BasicBlock *BlockToStop,
                           std::unordered_set<BasicBlock *> &BlocksToDelete)
  {
    BlocksToDelete.insert(Current);

    for (size_t i = 0; i < Current->getTerminator()->getNumSuccessors(); i++) {
      BasicBlock *Successor = Current->getTerminator()->getSuccessor(i);
      if (BlocksToDelete.find(Successor) == BlocksToDelete.end() && Successor != BlockToStop) {
        deleteAllBlocksFrom(Successor, BlockToStop, BlocksToDelete);
      }
    }
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    LoopBasicBlocks = L->getBlocksVector();
    loopFission(L);

    BasicBlock *BranchBlock = findIfBasicBlock(LoopBasicBlocks, true);
    BranchInst *Branch = dyn_cast<BranchInst>(BranchBlock->getTerminator()->getSuccessor(1)->getTerminator());
    bool isConditional = Branch->isConditional();

    std::unordered_set<BasicBlock *> BlocksToDelete;
    // Imamo else granu
    if (!isConditional) {
      // Jedini successor else basic blocka
      deleteAllBlocksFrom(Branch->getSuccessor(0), L->getLoopLatch(), BlocksToDelete);
      // True basic block preusmerava na loop latch
      BranchBlock->getTerminator()->getSuccessor(0)->getTerminator()->setSuccessor(0, L->getLoopLatch());
      // False basic block preusmerava na loop latch
      BranchBlock->getTerminator()->getSuccessor(1)->getTerminator()->setSuccessor(0, L->getLoopLatch());
    }
    else {
      // Krecemo brisanje od drugog if-a, koji je zapravo prvi successor BranchBlocka
      deleteAllBlocksFrom(BranchBlock->getTerminator()->getSuccessor(1), L->getLoopLatch(), BlocksToDelete);
      // Ako uslov u if-u nije ispunjen, idemo na loop latch
      BranchBlock->getTerminator()->setSuccessor(1, L->getLoopLatch());
      // Ako je uslov bio ispunjen, nakon true blocka idemo na loop latch
      BranchBlock->getTerminator()->getSuccessor(0)->getTerminator()->setSuccessor(0, L->getLoopLatch());
    }

    for (BasicBlock *BB : BlocksToDelete) {
      BB->eraseFromParent();
    }

    return true;
  }
}; // end of struct OurLoopFissionPass
}  // end of anonymous namespace

// The address of ID (not its value) identifies the pass to the legacy
// pass manager.
char OurLoopFissionPass::ID = 0;

// Expose the pass to `opt` as -loop-fission. It is registered as a
// transformation that may change the CFG, not as an analysis.
static RegisterPass<OurLoopFissionPass> X(
    "loop-fission", "", /*CFGOnly=*/false, /*is_analysis=*/false);