Golang Epoll初体验

本文参考自:https://colobu.com/2019/02/23/1m-go-tcp-connection,本文只是简单运行一个服务端的demo,更多信息请浏览原文


关于Golang的epoll:

Go的net包在Linux上也是基于epoll实现的,那为什么我们还要自己再封装一层epoll呢?原因在于Go将epoll封装在运行时内部,对外并没有直接提供epoll接口;而标准用法是每个连接占用一个goroutine,在百万级连接的场景下,这些goroutine的栈会占用大量内存。直接使用epoll后,只需少量goroutine就能处理所有连接。也有一些封装了epoll的库,如:evio 等。

Epoll server实现:

package main

import (
	"fmt"
	"log"
	"net"
	"net/http"
	_ "net/http/pprof" // registers /debug/pprof/ handlers on http.DefaultServeMux
	"reflect"
	"sync"
	"syscall"

	"golang.org/x/sys/unix"
)

// epoller is the process-wide epoll instance. main assigns it from MkEpoll
// and the start goroutine waits on it.
var epoller *epoll

// epoll pairs an epoll file descriptor with a table that maps every
// registered socket fd back to the net.Conn it belongs to.
type epoll struct {
	fd          int              // descriptor returned by epoll_create1
	connections map[int]net.Conn // socket fd -> registered connection
	lock        *sync.RWMutex    // guards connections
}
// MkEpoll creates a new epoll instance with an empty connection table.
//
// The epoll descriptor is opened with EPOLL_CLOEXEC so it is not inherited by
// programs started through exec. The original flags of 0 leaked it into
// child processes.
func MkEpoll() (*epoll, error) {
	fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
	if err != nil {
		return nil, err
	}
	return &epoll{
		fd:          fd,
		lock:        &sync.RWMutex{},
		connections: make(map[int]net.Conn),
	}, nil
}
// Add registers conn with the epoll instance for read-readiness and hang-up
// events. It also records conn under its fd so Wait can map events back to it.
//
// A failing epoll_ctl is returned as an error instead of panicking, so the
// caller can close this one connection and keep serving the others.
func (e *epoll) Add(conn net.Conn) error {
	// Extract file descriptor associated with the connection
	fd := socketFD(conn)

	// Hold the write lock across epoll_ctl and the map insert. Wait takes the
	// read lock only after epoll_wait returns, so it can never observe an
	// event for fd before conn has been recorded.
	e.lock.Lock()
	defer e.lock.Unlock()

	err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_ADD, fd, &unix.EpollEvent{
		Events: unix.EPOLLIN | unix.EPOLLHUP,
		Fd:     int32(fd),
	})
	if err != nil {
		return err
	}
	e.connections[fd] = conn
	if len(e.connections)%100 == 0 {
		log.Printf("add total number of connections: %v", len(e.connections))
	}
	return nil
}
// Remove unregisters conn from the epoll instance and drops it from the
// connection table. If epoll_ctl fails, that error is returned and the table
// is left untouched.
func (e *epoll) Remove(conn net.Conn) error {
	fd := socketFD(conn)
	if err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_DEL, fd, nil); err != nil {
		return err
	}

	e.lock.Lock()
	defer e.lock.Unlock()
	delete(e.connections, fd)
	// Log only on every 100th connection count to keep output manageable.
	if remaining := len(e.connections); remaining%100 == 0 {
		log.Printf("total number of connections: %v", remaining)
	}
	return nil
}
// Wait blocks in epoll_wait for at most 100 ms. It returns the connections
// whose sockets reported events, collecting up to 100 events per call.
//
// If the timeout expires with no events, the returned slice is empty and the
// error is nil. Errors from epoll_wait (including EINTR) are returned
// unchanged.
func (e *epoll) Wait() ([]net.Conn, error) {
	events := make([]unix.EpollEvent, 100)
	// The third argument is the timeout in milliseconds, not an event count.
	// The maximum number of events is len(events).
	n, err := unix.EpollWait(e.fd, events, 100)
	if err != nil {
		return nil, err
	}

	e.lock.RLock()
	defer e.lock.RUnlock()

	connections := make([]net.Conn, 0, n)
	for _, ev := range events[:n] {
		// The fd may have been removed between epoll_wait returning and the
		// lock being acquired. Skip it rather than handing back a nil conn.
		if conn, ok := e.connections[int(ev.Fd)]; ok {
			connections = append(connections, conn)
		}
	}
	return connections, nil
}
// socketFD returns the OS file descriptor backing conn, or -1 if it cannot be
// determined.
//
// Connections created by the net package implement syscall.Conn, so the
// descriptor is read through the supported SyscallConn().Control API. Other
// implementations fall back to socketFDReflect, which depends on
// standard-library internals. The previous reflection-only version panicked
// on any net.Conn without that exact layout.
func socketFD(conn net.Conn) int {
	if sc, ok := conn.(syscall.Conn); ok {
		if raw, err := sc.SyscallConn(); err == nil {
			fd := -1
			if cerr := raw.Control(func(s uintptr) { fd = int(s) }); cerr == nil {
				return fd
			}
		}
	}
	return socketFDReflect(conn)
}

// socketFDReflect walks the unexported field chain conn.fd.pfd.Sysfd used by
// the net package's connection types. It returns -1 instead of panicking when
// any step of that layout is missing.
func socketFDReflect(conn net.Conn) int {
	v := reflect.Indirect(reflect.ValueOf(conn))
	if v.Kind() != reflect.Struct {
		return -1
	}
	inner := v.FieldByName("conn")
	if !inner.IsValid() || inner.Kind() != reflect.Struct {
		return -1
	}
	netFD := reflect.Indirect(inner.FieldByName("fd"))
	if !netFD.IsValid() || netFD.Kind() != reflect.Struct {
		return -1
	}
	pfd := netFD.FieldByName("pfd")
	if !pfd.IsValid() || pfd.Kind() != reflect.Struct {
		return -1
	}
	sysfd := pfd.FieldByName("Sysfd")
	if !sysfd.IsValid() || sysfd.Kind() != reflect.Int {
		return -1
	}
	return int(sysfd.Int())
}


// main runs the epoll demo server. It:
//   - listens for TCP on :8972;
//   - serves HTTP on :6060 for pprof;
//   - creates the global epoller and starts the reader goroutine;
//   - registers every accepted connection with the epoller.
func main() {
	ln, err := net.Listen("tcp", ":8972")
	if err != nil {
		panic(err)
	}
	defer ln.Close()

	go func() {
		// A nil handler means http.DefaultServeMux, which exposes
		// /debug/pprof/ only because net/http/pprof is blank-imported.
		if err := http.ListenAndServe(":6060", nil); err != nil {
			log.Fatalf("pprof failed: %v", err)
		}
	}()

	epoller, err = MkEpoll()
	if err != nil {
		panic(err)
	}
	go start()

	for {
		conn, e := ln.Accept()
		if e != nil {
			// Retry errors the net package reports as temporary. Any other
			// error ends the accept loop and main returns.
			if ne, ok := e.(net.Error); ok && ne.Temporary() {
				log.Printf("accept temp err: %v", ne)
				continue
			}
			log.Printf("accept err: %v", e)
			return
		}
		fmt.Println("accept successful....")
		if err := epoller.Add(conn); err != nil {
			log.Printf("failed to add connection %v", err.Error())
			conn.Close()
		}
	}
}
// start is the single reader loop. It waits on the global epoller and reads
// from every connection reported ready. Connections whose Read fails
// (including EOF when the peer closes) are unregistered and closed.
func start() {
	// buf is reused across reads. Only buf[:n] is valid after each Read; a
	// message longer than len(buf) is picked up on later iterations.
	var buf = make([]byte, 20)
	for {
		connections, err := epoller.Wait()
		if err != nil {
			// EINTR only means epoll_wait was interrupted by a signal, so
			// retry without reporting it as a failure.
			if err == unix.EINTR {
				continue
			}
			log.Printf("failed to epoll wait %v", err)
			continue
		}
		for _, conn := range connections {
			if conn == nil {
				// Skip just this entry; the remaining ready conns still need
				// to be served.
				continue
			}
			n, err := conn.Read(buf)
			if err != nil {
				if err := epoller.Remove(conn); err != nil {
					log.Printf("failed to remove %v", err)
				}
				conn.Close()
				continue
			}
			fmt.Println("read message from client:", string(buf[:n]))
		}
	}
}

当Wait返回error时,可以判断一下它是否为EINTR(Interrupted system call,即系统调用被信号中断)。查阅资料可知这个错误可以安全忽略,直接重新调用epoll_wait即可。

Client:

package main

import (
	"flag"
	"fmt"
	"log"
	"net"
	"time"
)

// main is the demo client. It connects to the epoll server on 127.0.0.1:8972
// and sends "hello world\r\n" every 5 seconds. It stops as soon as a write
// fails, instead of looping forever on a dead connection.
func main() {
	flag.Parse()
	addr := "127.0.0.1:8972"
	log.Printf("连接到 %s", addr)

	c, err := net.DialTimeout("tcp", addr, 30*time.Second)
	if err != nil {
		fmt.Println("failed to connect", err)
		return
	}
	defer c.Close()
	time.Sleep(time.Millisecond)

	log.Printf("完成初始化连接")

	msg := []byte("hello world\r\n")
	for {
		// Only report success when the write actually succeeded. Stop once
		// the server has gone away.
		if _, err := c.Write(msg); err != nil {
			log.Printf("write failed: %v", err)
			return
		}
		fmt.Println("write success...")
		time.Sleep(time.Second * 5)
	}
}

你可能感兴趣的:(Go)