//
// Created by strahinja on 3/8/24.
//

#include "OurCallGraph.h"


/// Builds the call graph of module \p M, rooted at its function named `main`.
///
/// The module's name is stored in ModuleName (dumpGraphToFile() later uses it
/// to name the .dot output). The graph itself is filled in by createCallGraph(),
/// starting from `main`.
///
/// If \p M has no function named `main`, an error is printed to errs() and the
/// whole process terminates with exit code 1.
OurCallGraph::OurCallGraph(llvm::Module &M)
{
  ModuleName = M.getName().str();

  // Happy path: an entry point exists, so start the traversal there.
  if (Function *Root = M.getFunction("main")) {
    createCallGraph(Root);
    return;
  }

  // No `main`: there is nothing to root the traversal at.
  errs() << "Invalid IR!\n";
  exit(1);
}

/// Records \p F and, recursively, every function reachable from it through
/// direct calls into AdjacencyList (caller -> set of callees).
///
/// A function is walked at most once. A callee is only visited when it has no
/// entry in AdjacencyList yet, and each visit registers its function before
/// scanning the body. Together these keep recursive and mutually recursive
/// calls from looping forever.
///
/// Functions without a body (declarations, e.g. external library functions)
/// end up as nodes with no outgoing edges. Note that the C++ recursion depth
/// grows with the length of the longest call chain.
void OurCallGraph::createCallGraph(llvm::Function *F)
{
  if (F == nullptr)
    return;

  // Register F as a node even if it makes no calls (e.g. a trivial `main`),
  // so it still appears in the dumped graph. This also marks F as visited.
  (void)AdjacencyList[F];

  for (const BasicBlock &BB : *F) {
    for (const auto &Instr : BB) {
      // CallBase covers plain `call` as well as `invoke` (emitted by C++ code
      // for calls that may throw) and `callbr`.
      const auto *Call = dyn_cast<CallBase>(&Instr);
      if (Call == nullptr)
        continue;

      // getCalledFunction() is null for indirect calls (through a function
      // pointer). The target is unknown statically, so skip them rather than
      // inserting/recursing on a null Function*.
      Function *Callee = Call->getCalledFunction();
      if (Callee == nullptr)
        continue;

      AdjacencyList[F].insert(Callee);

      // Only descend into functions we have not visited yet.
      if (AdjacencyList.find(Callee) == AdjacencyList.end())
        createCallGraph(Callee);
    }
  }
}

void OurCallGraph::dumpGraphToFile()
{
  std::ofstream File;
  File.open(ModuleName + ".dot");

  File << "digraph \"Call graph: " << ModuleName << "\" {\n";
  File << "\tlabel=\"Our call graph: " << ModuleName << "\";\n\n";

  for (const auto &p : AdjacencyList) {
    File << "\tNode" << p.first << " [shape=record,label=\"{" << p.first->getName().str() << "}\"];\n";

    for (Function *Callee : p.second) {
      File << "\tNode" << p.first << " -> " << "Node" << Callee << ";\n";
    }
  }

  File << "}\n";
}