tarjan算法非递归实现求强连通分量

Tarjan算法用于求有向图的强连通分量(强连通分量:有向图中任意两个顶点都互相可达的极大子图)。无向图中没有"强连通"的概念,Tarjan算法的变体在无向图上用于求割点、桥和双连通分量。

此代码来源于华为软件精英挑战赛(软挑)的赛题:在有向带权图中找出所有长度为3~7的环路。

tarjan算法伪代码:

tarjan(u){
  DFN[u]=Low[u]=++Index     // 为节点u设定次序编号和Low初值
  S.push(u)                 // 将节点u压入栈S中
  for each (u, v) in E      // 枚举每一条以u为起点的边
    if (v is not visited)   // 如果节点v未被访问过
        tarjan(v)           // 继续向下找
        Low[u] = min(Low[u], Low[v])
    else if (v in S)        // 如果节点v还在栈内
        Low[u] = min(Low[u], DFN[v])
  if (DFN[u] == Low[u])     // 如果节点u是强连通分量的根
    repeat
        v = S.pop           // 将v退栈,v为该强连通分量中的一个顶点
        print v
    until (u == v)
}

 Java实现:

/**
 * One directed edge of a graph stored in "chained forward star" form: all edges that
 * share a source node form a singly linked list threaded through {@link #next}.
 */
class Edge {
    /** Index of the next edge with the same source node (the graph code uses -1 to end the list). */
    public int next;
    /** Destination node of this edge. */
    public int end;
    /** Weight of this edge. */
    public int w;

    /**
     * @param next index of the next edge with the same source node
     * @param v    destination node
     * @param w    edge weight
     */
    public Edge(int next, int v, int w) {
        this.w = w;
        this.end = v;
        this.next = next;
    }

    /** Returns {@code " next end w"}, each field preceded by a single space. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(' ').append(next);
        sb.append(' ').append(end);
        sb.append(' ').append(w);
        return sb.toString();
    }
}

public class Graph {
    final int MaxEdges = 2800000;
    public Edge[] edges = new Edge[MaxEdges];
    public int cnt;
    public List> path;
    public LinkedHashSet strPath;
    private String inputFileName;
    private String outputFileName;
    private Logger logger;
    final private int MinLen = 3;  //环的最小长度
    final private int MaxLen = 7;  //环的最大长度
    Map head = new HashMap<>();

    Set visited = new HashSet<>();
    Set endNodesSet;
    public int dfsCount = 0;
    /*-------------------------------- tarjan 变量--------------------------------------------*/
    int visitTime = 0;
    Deque stack = new ArrayDeque<>();
    Set stackSet = new HashSet<>();
    public List> tarRes = new LinkedList<>();
    Map dfn = new HashMap<>(head.size());
    Map low = new HashMap<>(head.size());
    Set tarVisited = new HashSet<>();
    int resCnt = 0;


    public Graph(String inputFileName, String outputFileName){
        init(inputFileName, outputFileName);
        try {
            loadFile();
        } catch (IOException e) {
            logger.info("Fail: loadFile...");
        }

//---------------------------tarjan------------------------
//        tarjan();
//        ----------  dfs --------------
//        findLoop();
//        output();
//        System.out.println("dfs调用次数:" + dfsCount);

    }

    private void Tarjan() {
        long s = System.currentTimeMillis();
        for (int node : head.keySet()) {
            if (!endNodesSet.contains(node) || tarVisited.contains(node)) continue;
            tarjan(node);
        }
        long e = System.currentTimeMillis();
        System.out.println("tarjan time: " + (double) (e - s) / 1000);
        System.out.println("强连通分量个数:" + tarRes.size());
        sort(tarRes);
        output("src/data/Tarjan.txt", tarRes);
    }

    public void init(String inputFileName, String outputFileName) {
        logger = Logger.getLogger("Graph");
        cnt = 0;
        path = new LinkedList<>();
        strPath = new LinkedHashSet<>();
        this.inputFileName = inputFileName;
        this.outputFileName = outputFileName;
        endNodesSet = new HashSet<>();
    }
    //读取文件,按照链式前向星的方法为图添加边
    public void loadFile() throws IOException {
        long startTime = System.currentTimeMillis();
        System.out.println("inputFile: " + inputFileName);
        File f = new File(inputFileName);
        InputStreamReader inputStreamReader = new InputStreamReader(new FileInputStream(f), "utf-8");
        BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
        String lineText = null;
        while ((lineText = bufferedReader.readLine()) != null) {
            String[] data = lineText.split(",");
            addEdge(Integer.parseInt(data[0]), Integer.parseInt(data[1]), Integer.parseInt(data[2]));

        }
        long endTime = System.currentTimeMillis();
        System.out.println("read file and create graph: " + (double) (endTime - startTime) / 1000);
        System.out.println("edges num: " + cnt);
    }
    /*
    * @Description add edge set
    * @param null
    * @Return
    * @Author sunwb
    * @Date 2020/3/29
    */
    public void addEdge(int u, int v, int w) {
        if (!head.containsKey(u))
            head.put(u, -1);
        Edge e = new Edge(head.get(u), v, 0);
        edges[cnt] = e;
        head.put(u, cnt++);
        endNodesSet.add(v);
    }

    public void findLoop(){
        long start = System.currentTimeMillis();
        for (int node : head.keySet()) {
            if (!endNodesSet.contains(node) || visited.contains(node)) continue;
            LinkedHashSet nodeList = new LinkedHashSet<>();
            nodeList.add(node);
            dfs(node, node, nodeList);
        }
        long end = System.currentTimeMillis();
        System.out.println("findLoop time: " + (double) (end - start) / 1000);
        sort(path);
    }

    //dfs寻找长度为3~7的环路,第一版代码,没有任何优化,纯暴力dfs
    public void dfs(int root, int node, LinkedHashSet nodeList) {
        dfsCount++;
        if (!head.containsKey(node)) return;
        int index = head.get(node);
        if (index < 0) return;
        visited.add(node);
        while (index != -1) {
            Edge e = edges[index];
            if (nodeList.contains(e.end)) {
                List list = new ArrayList<>(7);
                boolean flag = false;
                for (int x : nodeList) {  //此处可加入路径长度计数器,避免list的构建
                    if (!flag && x == e.end) {
                        flag = true;
                    }
                    if (flag) {
                        list.add(x);
                        visited.add(x);
                    }
                }
                if (list.size()<=MaxLen && list.size()>=MinLen) {
                    path.add(change(list));
                }
            } else {
                nodeList.add(e.end);
                dfs(root, e.end, nodeList);
            }
            index = e.next;
            if (index < 0) {
                nodeList.remove(node);
                return;
            }
        }
    }
    /*
    * @Description 改变list的顺序,以最小节点开始
    * @param list :一条循环路径
    * @Return java.util.List
    * @Author sunwb
    * @Date 2020/4/2 23:08
    **/
    private List change(List list) {
        int min = 0;
        for (int i = 1; i < list.size();i++) {
            if (list.get(i) < list.get(min))
                min = i;
        }
        List l = new ArrayList<>(list.size());
        for (int i = 0; i < list.size() - min; i++) {
            l.add(list.get(min+i));
        }
        for (int i = 0; i < min; i++) {
            l.add(list.get(i));
        }
        return l;
    }

    private void sort(List> path) {
        Collections.sort(path, new Comparator>() {
            @Override
            public int compare(List o1, List o2) {
                if (o1.size() != o2.size()) return o1.size()-o2.size();
                else {
                    for (int i = 0; i < o1.size(); i++) {
                        if (o1.get(i) != o2.get(i) ) return o1.get(i) - o2.get(i);
                    }
                }
                return 0;
            }
        });
    }

    /*
    * @Description 将结果输出至文件
    * @param filename
    * @param path
    * @Return void
    * @Author sunwb
    * @Date 2020/4/2 23:09
    **/
    public void output() {
        long start = System.currentTimeMillis();
        for (List l : path) {
            String s = l.toString();
            strPath.add(s.substring(1, s.length()- 1));
        }
        long e1 = System.currentTimeMillis();
        System.out.println("del the same list: " + (double)(e1 - start)/1000);
        try {
            File file = new File(outputFileName);
            if (!file.exists())
                file.createNewFile();
            FileWriter fw = new FileWriter(file.getAbsoluteFile());
            BufferedWriter bufferedWriter = new BufferedWriter(fw);
            bufferedWriter.write(String.valueOf(strPath.size()));
            bufferedWriter.newLine();
            for (String s : strPath) {
                bufferedWriter.write(s);
                bufferedWriter.newLine();
            }
            bufferedWriter.close();
        } catch (IOException e) {
            System.out.println("Fail: create file!");
        }
        long e2 = System.currentTimeMillis();
        System.out.println("output file: " + (double)(e2 - start)/1000);
    }

    /*--------------------------------------------------------------------
    tarjan(u){
  DFN[u]=Low[u]=++Index // 为节点u设定次序编号和Low初值
  Stack.push(u)   // 将节点u压入栈中
  for each (u, v) in E // 枚举每一条边
    if (v is not visted) // 如果节点v未被访问过
        tarjan(v) // 继续向下找
        Low[u] = min(Low[u], Low[v])
    else if (v in S) // 如果节点u还在栈内
        Low[u] = min(Low[u], DFN[v])
  if (DFN[u] == Low[u]) // 如果节点u是强连通分量的根
  repeat v = S.pop  // 将v退栈,为该强连通分量中一个顶点
  print v
  until (u== v)
    }---------------------------------------------------------*/
    /*
    * @Description tarjan算法,伪代码见上
    * @param u 当前遍历到的节点
    * @Return void
    * @Author sunwb
    * @Date 2020/4/3 20:41
    **/
    public void tarjan(int u) {
            dfn.put(u, visitTime);
            low.put(u, visitTime);
            visitTime++;
            stack.push(u);
            stackSet.add(u);
            tarVisited.add(u);
            if (!head.containsKey(u)) return;
            int index = head.get(u);
            while (index != -1) {
                if (!tarVisited.contains(edges[index].end)) {
                    tarjan(edges[index].end);
                    low.put(u, Math.min(low.get(u), low.get(edges[index].end)));
                } else if (stackSet.contains(edges[index].end)) {
                    low.put(u, Math.min(low.get(u), low.get(edges[index].end)));
                }
                index = edges[index].next;
            }
            if (dfn.get(u).equals(low.get(u))) {
                List list = new LinkedList<>();
                int n = stack.peek();
                if (n == u) {
//                    list.add(0, n);
                    stack.pop();
                    stackSet.remove(n);
//                    tarRes.add(list);
                    return;
                }
                while (n != u) {
                    n = stack.pop();
                    stackSet.remove(n);
                    list.add(0, n);
                }
                if (list.size()>2) tarRes.add(list); //大于等于3的环才添加
            }
    }

    /*
    * @Description 自己根据伪代码实现的非递归tarjan算法,没有优化,效率很低
    * @param
    * @Return void
    * @Author sunwb
    * @Date 2020/4/10 20:47
    **/
    public void tarjan() {
        long s = System.currentTimeMillis();
        Set visEdges = new HashSet<>();
        for (int headNode : head.keySet()) {
            if (!endNodesSet.contains(headNode) || tarVisited.contains(headNode)) continue;
            Map preNode = new HashMap<>();
            stack.add(headNode);
            preNode.put(headNode, headNode);
            while (!stack.isEmpty()) {
                if (!tarVisited.contains(headNode)) {
                    dfn.put(headNode, visitTime);
                    low.put(headNode, visitTime);
                    visitTime++;
                    tarVisited.add(headNode);
                    stackSet.add(headNode);
                    stack.push(headNode);
                }
                if (!head.containsKey(headNode)) {
                    headNode = preNode.get(headNode);
                    continue;
                }
                int index = head.get(headNode);
                while (index != -1) {
                    if (visEdges.contains(index)) {
                        index = edges[index].next;
                        if (index < 0) {
                            updateStack(headNode);
                            headNode = preNode.get(headNode);
                            break;
                        }
                        continue;
                    }
                    visEdges.add(index);
                    int end = edges[index].end;
                    if (!tarVisited.contains(end)) {
                        preNode.put(end, headNode);
                        headNode = end;
                        break;
                    } else if (stackSet.contains(end)) {
                        low.put(headNode, Math.min(low.get(headNode), low.get(end)));
                        updateLow(end, headNode, low.get(headNode));
                    }
                    index = edges[index].next;
                    if (index < 0) {
                        updateStack(headNode);
                        headNode = preNode.get(headNode);
                        break;
                    }
                }
            }
        }
        long e = System.currentTimeMillis();
        System.out.println("tarjan time: " + (double) (e - s) / 1000);
        System.out.println("强连通分量个数:" + tarRes.size());
        sort(tarRes);
        output("src/data/tar_jan.txt", tarRes);
//        printTarInfo();
    }
    //存储强连通分量,并更新栈
    private void updateStack(int headNode) {
        if (dfn.get(headNode).equals(low.get(headNode))) {
            List list = new LinkedList<>();
            int n = stack.peek();
            if (n == headNode) {
                stack.pop();
                stackSet.remove(n);
            }
            while (n != headNode) {
                n = stack.pop();
                stackSet.remove(n);
                list.add(0, n);
            }
            if (list.size()>2) tarRes.add(list); //大于等于3的环才添加
        }
    }
    //更新环路节点的low值
    private void updateLow(int end, int headNode, int val) {
        List list = new LinkedList<>();
        if (stack.isEmpty()) return;
        while (!stack.isEmpty() && stack.peek() != end) {
            int n = stack.pop();
            low.put(n, val);
            list.add(0, n);
        }
        for (int n : list) {
            stack.push(n);
        }
    }

    private void printTarInfo() {
        System.out.println("强连通分量个数:" + tarRes.size());
        for (List list : tarRes) {
            System.out.println(list.toString());
        }
//        System.out.println("--------low------- : \n" + low.toString());
    }

    public void output(String outFile, List> res) {
        try {
            File file = new File(outFile);
            if (!file.exists()) file.createNewFile();
            FileWriter fw = new FileWriter(file.getAbsoluteFile());
            BufferedWriter bw = new BufferedWriter(fw);
            bw.write(String.valueOf(res.size()));
            bw.newLine();
            for (List l : res) {
                bw.write(l.toString());
                bw.newLine();
            }
            bw.close();
        } catch (IOException e) {
            System.out.println("Fail: write to file!");
        }
    }

    /*
    * @Description 按照python库networkx中的非递归tarjan逻辑实现,贼快
    * @param
    * @Return void
    * @Author sunwb
    * @Date 2020/4/12 17:05
    **/
    public void newTarjan() {
        long s = System.currentTimeMillis();
        int visTime = 0; //preorder counter;
        Map preorder = new HashMap<>();
        Map lowlink = new HashMap<>();
        Map sccFound = new HashMap<>();
        Deque sccQueue = new ArrayDeque<>();
        for (int node : head.keySet()) {
            Deque queue = new ArrayDeque<>();
            if (!sccFound.containsKey(node)) {
                queue.push(node);
                while (!queue.isEmpty()) {
                    int v = queue.peek();
                    if (!preorder.containsKey(v)) {
                        preorder.put(v, ++visTime);
                    }
                    boolean done = true;
                    if (!head.containsKey(v)) {
                        lowlink.put(v, preorder.get(v));
                        queue.pop();
                        continue;
                    }
                    int index = head.get(v);
                    while (index != -1) {
                        int w = edges[index].end;
                        if (!preorder.containsKey(w)) {
                            queue.push(w);
                            done = false;
                            break;
                        }
                        index = edges[index].next;
                    }
                    if(done) {
                        lowlink.put(v, preorder.get(v));
                        index = head.get(v);
                        while (index != -1) {
                            int w = edges[index].end;
                            if (!sccFound.containsKey(w)) {
                                if (preorder.get(w).intValue() > preorder.get(v).intValue()) {
                                    lowlink.put(v, Math.min(lowlink.get(v), lowlink.get(w)));
                                } else {
                                    lowlink.put(v, Math.min(lowlink.get(v), preorder.get(w)));
                                }
                            }
                            index = edges[index].next;
                        }
                        queue.pop();
                        if (lowlink.get(v).equals(preorder.get(v))) {
                            sccFound.put(v, true);
                            List scc = new ArrayList<>();
                            scc.add(v);
                            while (!sccQueue.isEmpty() && preorder.get(sccQueue.peek()).intValue() > preorder.get(v).intValue()) {
                                int k = sccQueue.pop();
                                sccFound.put(k, true);
                                scc.add(k);
                            }
                            if (scc.size()>2) tarRes.add(scc);
                        } else {
                            sccQueue.push(v);
                        }
                    }
                }
            }

        }
        long e = System.currentTimeMillis();
        System.out.println("tarjan time: " + (double) (e - s) / 1000);
        System.out.println("强连通分量个数:" + tarRes.size());
        sort(tarRes);
//        System.out.println(tarRes.toString());
//        output("src/data/newTarjan.txt", tarRes);

    }
    public static void main(String[] args) {
        String inputFile = "src/data/data12.txt";
        String outputFile = "src/data/answer.txt";
//        String inputFile = "/data/test_data.txt";
//        String outputFile = "/projects/student/result.txt";
        Graph graph = new Graph(inputFile, outputFile);
        graph.newTarjan();

    }
}

C++实现:一维邻接表存图 + 非递归Tarjan

// Capacities of the static edge array (MAXE) and node array (MAXN) below.
#define MAXE  2500000
#define MAXN  1000000

struct Edge {
    uint32 to;  // destination node of the edge
    uint32 w;  // edge weight
};
struct Node {
    uint32 l;  // first index (inclusive) of this node's edges in neighborsTable
    uint32 r;  // end index (exclusive) of this node's edges in neighborsTable
};

Edge neighborsTable[MAXE];  // flat adjacency array; each node's out-edges are contiguous
Node G[MAXN];  // G[i] holds node i's edge range [l, r) in neighborsTable
// The out-edges of node i are neighborsTable[G[i].l] .. neighborsTable[G[i].r - 1].

/*************************** tarjan ***************************************/
vector> tarRes;
void tarjan() {
    clock_t start = clock();
    int visitTime = 0;
    int tarjanCnt = 0;
    unordered_map preorder;
    unordered_map lowlink;
    unordered_map sccFound;
    stack sccQueue;
    for (int node = 0; node < nodeCnt; ++node) {
        stack queue;
        if (!sccFound.count(node)) {
            queue.push(node);
            while (!queue.empty()) {
                uint v = queue.top();
                if (!preorder.count(v))
                    preorder[v] = ++visitTime;
                bool done = true;
                for (uint index = G[v].l; index < G[v].r; ++index) {
                    uint w = neighborsTable[index].to;
                    if (!preorder.count(w)) {
                        queue.push(w);
                        done = false;
                        break;
                    }
                }
                if (done) {
                    lowlink[v] = preorder[v];
                    for (uint index = G[v].l; index < G[v].r; ++index) {
                        uint w = neighborsTable[index].to;
                        if (!sccFound.count(w)) {
                            if (preorder[w] > preorder[v]) {
                                lowlink[v] = min(lowlink[v], lowlink[w]);
                            } else {
                                lowlink[v] = min(lowlink[v], preorder[w]);
                            }
                        }
                    }
                    queue.pop();
                    if (lowlink[v] == preorder[v]) {
                        sccFound[v] = true;
                        vector scc;
                        scc.emplace_back(v);
                        while (!sccQueue.empty() && preorder[sccQueue.top()] > preorder[v]) {
                            uint k = sccQueue.top();
                            sccQueue.pop();
                            sccFound[k] = true;
                            scc.emplace_back(k);
                        }
                        tarjanCnt++;
                    } else
                        sccQueue.push(v);

                }
            }
        }
    }
    clock_t end = clock();
    cout << "tarjan time: " << (double)(end - start) / CLOCKS_PER_SEC << endl;
    cout << "scc size: " << tarjanCnt << endl;

 

 

你可能感兴趣的:(C++,Java基础学习)