golang练手小项目系列(2)-并发爬虫

本系列整理了10个工作量和难度适中的Golang小项目,适合已经掌握Go语法的工程师进一步熟练语法和常用库的用法。

问题描述:

实现一个网络爬虫,以输入的URL为起点,使用广度优先顺序访问页面。

要点:

实现对多个页面的并发访问,同时访问的页面数由参数 -concurrency 指定,默认为 20。

使用 -depth 指定访问的页面深度,默认为 3。

注意已经访问过的页面不要重复访问。

扩展:

将访问到的页面写入到本地以创建目标网站的本地镜像,注意,只有指定域名下的页面需要采集,写入本地的页面里的元素的href的值需要被修改为指向镜像页面,而不是原始页面。

实现(注:以下代码属于 package main,编译时需在文件开头补上包声明):

import (

  "bytes"

  "flag"

  "fmt"

  "golang.org/x/net/html"

  "io"

  "log"

  "net/http"

  "net/url"

  "os"

  "path/filepath"

  "strings"

  "sync"

  "time"

)

// URLInfo pairs a page URL with the breadth-first depth at which it was
// discovered (the start URLs are seeded at depth 1).
type URLInfo struct {
  url string
  depth int
}

var base *url.URL

func forEachNode(n *html.Node, pre, post func(n *html.Node)){

  if pre != nil{

      pre(n)

  }

  for c := n.FirstChild; c != nil; c = c.NextSibling{

      forEachNode(c, pre, post)

  }

  if post != nil{

      post(n)

  }

}

func linkNodes(n *html.Node) []*html.Node {

  var links []*html.Node

  visitNode := func(n *html.Node) {

      if n.Type == html.ElementNode && n.Data == "a" {

        links = append(links, n)

      }

  }

  forEachNode(n, visitNode, nil)

  return links

}

func linkURLs(linkNodes []*html.Node, base *url.URL) []string {

  var urls []string

  for _, n := range linkNodes {

      for _, a := range n.Attr {

        if a.Key != "href" {

            continue

        }

        link, err := base.Parse(a.Val)

        // ignore bad and non-local URLs

        if err != nil {

            log.Printf("skipping %q: %s", a.Val, err)

            continue

        }

        if link.Host != base.Host {

            //log.Printf("skipping %q: non-local host", a.Val)

            continue

        }

        if strings.HasPrefix(link.String(), "javascript"){

            continue

        }

        urls = append(urls, link.String())

      }

  }

  return urls

}

func rewriteLocalLinks(linkNodes []*html.Node, base *url.URL) {

  for _, n := range linkNodes {

      for i, a := range n.Attr {

        if a.Key != "href" {

            continue

        }

        link, err := base.Parse(a.Val)

        if err != nil || link.Host != base.Host {

            continue // ignore bad and non-local URLs

        }

        link.Scheme = ""

        link.Host = ""

        link.User = nil

        a.Val = link.String()

        n.Attr[i] = a

      }

  }

}

func Extract(url string)(urls []string, err error){

  timeout := time.Duration(10 * time.Second)

  client := http.Client{

      Timeout: timeout,

  }

  resp, err := client.Get(url)

  if err != nil{

      fmt.Println(err)

      return nil, err

  }

  if resp.StatusCode != http.StatusOK{

      resp.Body.Close()

      return nil, fmt.Errorf("getting %s:%s", url, resp.StatusCode)

  }

  if err != nil{

      return nil, fmt.Errorf("parsing %s as HTML: %v", url, err)

  }

  u, err := base.Parse(url)

  if err != nil {

      return nil, err

  }

  if base.Host != u.Host {

      log.Printf("not saving %s: non-local", url)

      return nil, nil

  }

  var body io.Reader

  contentType := resp.Header["Content-Type"]

  if strings.Contains(strings.Join(contentType, ","), "text/html") {

      doc, err := html.Parse(resp.Body)

      resp.Body.Close()

      if err != nil {

        return nil, fmt.Errorf("parsing %s as HTML: %v", u, err)

      }

      nodes := linkNodes(doc)

      urls = linkURLs(nodes, u)

      rewriteLocalLinks(nodes, u)

      b := &bytes.Buffer{}

      err = html.Render(b, doc)

      if err != nil {

        log.Printf("render %s: %s", u, err)

      }

      body = b

  }

  err = save(resp, body)

  return urls, err

}

func crawl(url string) []string{

  list, err := Extract(url)

  if err != nil{

      log.Print(err)

  }

  return list

}

func save(resp *http.Response, body io.Reader) error {

  u := resp.Request.URL

  filename := filepath.Join(u.Host, u.Path)

  if filepath.Ext(u.Path) == "" {

      filename = filepath.Join(u.Host, u.Path, "index.html")

  }

  err := os.MkdirAll(filepath.Dir(filename), 0777)

  if err != nil {

      return err

  }

  fmt.Println("filename:", filename)

  file, err := os.Create(filename)

  if err != nil {

      return err

  }

  if body != nil {

      _, err = io.Copy(file, body)

  } else {

      _, err = io.Copy(file, resp.Body)

  }

  if err != nil {

      log.Print("save: ", err)

  }

  err = file.Close()

  if err != nil {

      log.Print("save: ", err)

  }

  return nil

}

// parallellyCrawl drives a bounded-depth, breadth-first crawl starting from
// the space-separated URLs in initialLinks, fetching at most `concurrency`
// pages concurrently and discarding links deeper than `depth`.
func parallellyCrawl(initialLinks string, concurrency, depth int){
  // worklist carries batches of discovered links back to the coordinator
  // loop; unseenLinks feeds individual unvisited links to the worker pool.
  worklist := make(chan []URLInfo, 1)
  unseenLinks := make(chan URLInfo, 1)
  // Per-URL state: 1 = queued into unseenLinks, 2 = crawl finished.
  seen := make(map[string] int)
  seenLock := sync.Mutex{}
  var urlInfos []URLInfo
  // Seed the crawl at depth 1 with each space-separated start URL.
  for _, url := range strings.Split(initialLinks, " "){
      urlInfos = append(urlInfos, URLInfo{url, 1})
  }
  // Sent from a goroutine so the coordinator's range loop below starts first.
  go func() {worklist <- urlInfos}()
  // Termination monitor: once every seen URL has finished (state 2) and the
  // worklist is drained, close both channels to stop workers and coordinator.
  go func() {
      for{
        time.Sleep(1 * time.Second)
        seenFlag := true
        seenLock.Lock()
        for k := range seen{
            if seen[k] == 1{
              seenFlag = false
            }
        }
        seenLock.Unlock()
        // NOTE(review): len(worklist) == 0 does not prove no sender is about
        // to send; a worker's hand-off goroutine could still send on the
        // closed worklist and panic. Confirm this race is acceptable.
        if seenFlag && len(worklist) == 0{
            close(unseenLinks)
            close(worklist)
            break
        }
      }
  }()
  // Worker pool: `concurrency` goroutines crawl pages from unseenLinks.
  for i := 0; i < concurrency; i++{
      go func() {
        for link := range unseenLinks{
            foundLinks := crawl(link.url)
            var urlInfos []URLInfo
            for _, u := range foundLinks{
              urlInfos = append(urlInfos, URLInfo{u, link.depth + 1})
            }
            // Hand results off from a fresh goroutine so a full worklist
            // cannot deadlock the worker; the URL is marked finished (2)
            // only after its batch has been delivered.
            go func(finishedUrl string) {
              worklist <- urlInfos
              seenLock.Lock()
              seen[finishedUrl] = 2
              seenLock.Unlock()
            }(link.url)
        }
      }()
  }
  // Coordinator: dedupe discovered links, enforce the depth limit, and feed
  // unvisited links to the workers.
  for list := range worklist{
      for _, link := range list {
        if link.depth > depth{
            continue
        }
        seenLock.Lock()
        _, ok := seen[link.url]
        seenLock.Unlock()
        if !ok{
            seenLock.Lock()
            seen[link.url] = 1
            seenLock.Unlock()
            unseenLinks <- link
        }
      }
  }
  fmt.Printf("共访问了%d个页面", len(seen))
}

func main() {

  var maxDepth int

  var concurrency int

  var initialLink string

  flag.IntVar(&maxDepth, "d", 3, "max crawl depth")

  flag.IntVar(&concurrency, "c", 20, "number of crawl goroutines")

  flag.StringVar(&initialLink, "u", "", "initial link")

  flag.Parse()

  u, err := url.Parse(initialLink)

  if err != nil {

      fmt.Fprintf(os.Stderr, "invalid url: %s\n", err)

      os.Exit(1)

  }

  base = u

  parallellyCrawl(initialLink, concurrency, maxDepth)

}

你可能感兴趣的:(golang练手小项目系列(2)-并发爬虫)