使用 Unix domain socket 传递文件描述符（file descriptor）的例子

刚抄完，先贴上来，之后再慢慢分析。

 

#include <apue.h>
#include <errno.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/uio.h>
#include <sys/socket.h>

/* Rendezvous path of the open server and its request keyword
   (the server side defines the same two values). */
#define CS_OPEN "/tmp/opend"
#define CL_OPEN "open"

/* Chunk size for copying a received file to stdout. */
#define BUFFSIZE 8192

/* Ask the open server for a descriptor on name; defined further down. */
int csopen(char *name, int oflag);

int
main(int argc, char* argv[])
{
    int      n, fd;
    char     buf[BUFFSIZE], line[MAXLINE];

    /* read filename to cat from stdin */
    while (fgets(line, MAXLINE, stdin) != NULL)
    {
        if (line[strlen(line) - 1] == '/n')
            line[strlen(line) - 1] = 0;

        /* open the file */
        if ((fd = csopen(line, O_RDONLY)) < 0)
            continue;

        /* and cat to stdout */
        while ((n = read(fd, buf, BUFFSIZE)) > 0)
            if (write(STDOUT_FILENO, buf, n) != n)
                err_sys("write error");

        if (n < 0)
            err_sys("read error");
        close(fd);
    }
}

#define  CONTROLLEN  CMSG_LEN(sizeof(int))

static struct cmsghdr    *cmptr = NULL;

int
recv_fd1(int fd, ssize_t (*userfunc)(int, const void *, size_t))
{
    int           newfd, nr, status;
    char          *ptr;
    char          buf[MAXLINE];
    struct iovec  iov[1];
    struct msghdr msg;

    status = -1;
    for (;;)
    {
        iov[0].iov_base = buf;
        iov[0].iov_len  = sizeof(buf);
        msg.msg_iov     = iov;
        msg.msg_iovlen  = 1;
        msg.msg_name    = NULL;
        msg.msg_namelen = 0;
        if (cmptr == NULL && (cmptr = malloc(CONTROLLEN)) == NULL)
            return -1;
        msg.msg_control    = cmptr;
        msg.msg_controllen = CONTROLLEN;
        if ((nr = recvmsg(fd, &msg, 0)) < 0)
        {
            err_sys("recvmsg error");
        }
        else if (nr == 0)
        {
            err_ret("connection closed by server");
            return -1;
        }

        for (ptr = buf; ptr < &buf[nr]; )
        {
            if (*ptr++ == 0)
            {
                if (ptr != &buf[nr-1])
                    err_dump("message format error");
                status = *ptr & 0xFF; /* prevent sign extension */
                if (status == 0)
                {
                    if (msg.msg_controllen != CONTROLLEN)
                        err_dump("status = 0 but no fd");
                    newfd = *(int *)CMSG_DATA(cmptr);
                }
                else
                {
                    newfd = -status;
                }
                nr -= 2;
            }
        }
        if (nr > 0 && (*userfunc)(STDERR_FILENO, buf, nr) != nr)
            return -1;
        if (status >= 0)
            return newfd;
    }
}

/*
 * csopen - ask the open server to open name with oflag and return the
 * resulting descriptor.
 *
 * The first call connects to the server at CS_OPEN.  The connection is
 * kept in a static variable and reused by later calls.  The request
 * "open <name> <oflag>" plus a terminating NUL is sent with one writev().
 *
 * Returns a descriptor >= 0 on success, or the negative value produced by
 * recv_fd1() on failure (any server error text has been written to stderr).
 * Connection and write failures go to err_sys().
 *
 * NOTE: the server tokenizes the request (buf_args), so names that contain
 * its delimiter characters are not supported.
 */
int
csopen(char *name, int oflag)
{
    ssize_t      len;
    char         buf[32];   /* ample for " " + any int + NUL */
    struct iovec iov[3];
    static int   csfd = -1;

    if (csfd < 0)
    {   /* open connection to conn server */
        if ((csfd = cli_conn(CS_OPEN)) < 0)
            err_sys("cli_conn error");
    }

    /* bounded: sprintf into a 10-byte buffer overflowed for negative or
       large oflag values */
    snprintf(buf, sizeof(buf), " %d", oflag);
    iov[0].iov_base = CL_OPEN " ";
    iov[0].iov_len  = strlen(CL_OPEN) + 1;
    iov[1].iov_base = name;
    iov[1].iov_len  = strlen(name);
    iov[2].iov_base = buf;
    iov[2].iov_len  = strlen(buf) + 1;  /* the NUL is always sent */
    len = (ssize_t)(iov[0].iov_len + iov[1].iov_len + iov[2].iov_len);
    if (writev(csfd, &iov[0], 3) != len)
        err_sys("writev error");

    /* read back the descriptor; error text from the server goes to
       stderr through write() */
    return recv_fd1(csfd, write);
}

#include "apue.h"
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <syslog.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>

#define CS_OPEN "/tmp/opend"
#define CL_OPEN "open"
#define MAXARGC 50
#define WHITE   " /t/n"

#define NALLOC  10

int    debug;
char   errmsg[MAXLINE];
int    oflag;
char  *pathname;
int    log_to_stderr = 1;

typedef struct
{
    int    fd;
    uid_t  uid;
}Client;

Client    *client = NULL;
int        client_size;

int   cli_args(int, char **);
int   client_add(int, uid_t);
void  client_del(int);
void  loop(void);
void  request(char *, int, int, uid_t);

/*
 * client_alloc - grow client[] by NALLOC slots.
 * The first call allocates the array; later calls extend it.  Every new
 * slot is marked free (fd == -1), and client_size is updated.
 * Allocation failure is reported through err_sys().
 */
static void
client_alloc(void)
{
    int      slot;

    client = (client == NULL)
           ? malloc(NALLOC * sizeof(Client))
           : realloc(client, (client_size + NALLOC) * sizeof(Client));
    if (client == NULL)
        err_sys("can't alloc for client array");

    /* Walk the freshly added tail and mark each slot unused. */
    for (slot = client_size + NALLOC - 1; slot >= client_size; slot--)
        client[slot].fd = -1;

    client_size += NALLOC;
}

/*
 * client_add - record a newly accepted client.
 * Called by loop() when a connection request from a new client arrives.
 * Stores fd and uid in the first free slot of client[], growing the array
 * when every slot is in use, and returns that slot's index.
 */
int
client_add(int fd, uid_t uid)
{
    int     slot;

    if (client == NULL)
        client_alloc();

    for (;;)
    {
        for (slot = 0; slot < client_size; slot++)
            if (client[slot].fd == -1)
                break;
        if (slot < client_size)
            break;              /* found a free entry */
        client_alloc();         /* array full: grow it and rescan */
    }

    client[slot].fd  = fd;
    client[slot].uid = uid;
    return slot;
}

/*
 * client_del - free the client[] slot that holds descriptor fd.
 * Called by loop() when we're done with a client.  If no slot holds fd,
 * the call does nothing.
 */
void
client_del(int fd)
{
    int     slot = 0;

    while (slot < client_size && client[slot].fd != fd)
        slot++;

    if (slot < client_size)
        client[slot].fd = -1;   /* slot can now be reused by client_add() */
}

int
main(int argc, char* argv[])
{
    int       c;

    log_open("open.serv", LOG_PID, LOG_USER);

    opterr = 0; /* don't want getopt() write to stderr */
    while ((c = getopt(argc, argv, "d")) != EOF)
    {
        switch (c)
        {
            case 'd': /* debug */
                debug = log_to_stderr = 1;
                break;

            case '?':
                err_quit("unrecongnized option: -%c", optopt);
        }
    }

//    if (debug == 0)
//        daemonize("opend");

    loop();
    return 0;
}

/*
 * send_fd1 - pass a descriptor, or an error status, to the peer on the
 * UNIX domain socket fd.
 *
 * Two data bytes are always sent: a 0 byte followed by a status byte.
 *   fd_to_send >= 0: status 0, and fd_to_send travels as SCM_RIGHTS
 *                    ancillary data.
 *   fd_to_send <  0: status is the low byte of -fd_to_send, with 0 bumped
 *                    to 1 (e.g. -256) so it is never read as success.  No
 *                    ancillary data is sent.
 *
 * Returns 0 on success, or -1 if sendmsg() did not send both bytes.
 */
int
send_fd1(int fd, int fd_to_send)
{
    struct iovec    iov[1];
    struct msghdr   msg;
    unsigned char   buf[2];
    struct cmsghdr *cmp;
    /* CMSG_SPACE-sized control buffer; the union provides cmsghdr alignment */
    union {
        struct cmsghdr hdr;
        char           space[CMSG_SPACE(sizeof(int))];
    } control;

    memset(&msg, 0, sizeof(msg));
    iov[0].iov_base  = buf;
    iov[0].iov_len   = 2;
    msg.msg_iov      = iov;
    msg.msg_iovlen   = 1;
    msg.msg_name     = NULL;
    msg.msg_namelen  = 0;

    if (fd_to_send < 0)
    {
        /* unsigned arithmetic: -INT_MIN would overflow an int */
        unsigned int status = 0u - (unsigned int)fd_to_send;

        msg.msg_control     = NULL;
        msg.msg_controllen  = 0;
        buf[1] = (unsigned char)(status & 0xFF);
        if (buf[1] == 0)
            buf[1] = 1; /* -256 */
    }
    else
    {
        /* zeroed so no uninitialized padding bytes reach the kernel */
        memset(&control, 0, sizeof(control));
        msg.msg_control    = control.space;
        msg.msg_controllen = sizeof(control.space);
        cmp = CMSG_FIRSTHDR(&msg);
        cmp->cmsg_level    = SOL_SOCKET;
        cmp->cmsg_type     = SCM_RIGHTS;
        cmp->cmsg_len      = CMSG_LEN(sizeof(int));
        /* cmsg(3): use memcpy instead of casting CMSG_DATA() to int * */
        memcpy(CMSG_DATA(cmp), &fd_to_send, sizeof(int));
        buf[1] = 0;
    }
    buf[0] = 0;
    if (sendmsg(fd, &msg, 0) != 2)
        return -1;
    return 0;
}
void
request(char *buf, int nread, int clifd, uid_t uid)
{
    int      newfd;

    if (buf[nread-1] != 0)
    {
        sprintf(errmsg,
           "request from uid %d not null terminated: %*.*s/n",
           uid, nread, nread, buf);
        send_err(clifd, -1, errmsg);
        return;
    }
    log_msg("request: %s, from uid %d", buf, uid);

    /* parse the arguments, set options */
    if (buf_args(buf, cli_args) < 0)
    {
        send_err(clifd, -1, errmsg);
        log_msg(errmsg);
        return;
    }

    if ((newfd = open(pathname, oflag)) < 0)
    {
        sprintf(errmsg, "can't open %s: %s/n",
                pathname, strerror(errno));
        send_err(clifd, -1, errmsg);
        log_msg(errmsg);
        return;
    }

    /* send the descriptor */
    if (send_fd1(clifd, newfd) < 0)
        log_sys("send_fd error");
    log_msg("sent fd %d over fd %d for %s", newfd, clifd, pathname);
    close(newfd);
}

void
loop(void)
{
    int      i, n, maxfd, maxi, listenfd, clifd, nread;
    char     buf[MAXLINE];
    uid_t    uid;
    fd_set   rset, allset;

    FD_ZERO(&allset);

    /* obtain fd to listen for client requests on */
    if ((listenfd = serv_listen(CS_OPEN)) < 0)
        log_sys("serv_listen error");

    FD_SET(listenfd, &allset);
    maxfd = listenfd;
    maxi = -1;

    for ( ; ; )
    {
        rset = allset;    /* rset gets modified each time around */
        if ((n = select(maxfd + 1, &rset, NULL, NULL, NULL)) < 0)
            log_sys("select error");

        if (FD_ISSET(listenfd, &rset))
        {
            /* accept new client request */
            if ((clifd = serv_accept(listenfd, &uid)) < 0)
                log_sys("serv_acept error: %d", clifd);
            i = client_add(clifd, uid);
            FD_SET(clifd, &allset);
            if (clifd > maxfd)
                maxfd = clifd;  /* max fd for select() */
            if (i > maxi)
                maxi = i;       /* max index in client[] array */
            log_msg("new connection: uid %d, fd %d", uid, clifd);
            continue;
        }

        for (i = 0; i <= maxi; i++) /* go through client[] array */
        {
            if ((clifd = client[i].fd) < 0) /* empty entry */
                continue;
            if (FD_ISSET(clifd, &rset))
            {
                /* read argument buffer from client */
                if ((nread = read(clifd, buf, MAXLINE)) < 0)
                {
                    log_sys("read error on fd %d", clifd);
                }
                else if (nread == 0)
                {
                    log_msg("closed: uid %d, fd %d",
                             client[i].uid, clifd);
                    client_del(clifd);
                    FD_CLR(clifd, &allset);
                    close(clifd);
                }
                else
                { /* process client's request */
                    request(buf, nread, clifd, client[i].uid);
                }
            }
        }
    }
}

/*
 * buf_args - split buf in place into WHITE-separated tokens and pass them,
 * argv-style (NULL-terminated), to optfunc.
 *
 * Returns -1 if buf has no tokens or has too many for argv[MAXARGC].
 * Otherwise returns whatever optfunc returns.  Uses strtok(), so buf is
 * modified and the function is not reentrant.
 */
int
buf_args(char *buf, int (*optfunc)(int, char **))
{
    char     *ptr, *argv[MAXARGC];
    int      argc;

    /* an argv[0] is required.  Use the token strtok() returned, not buf
       itself: with leading delimiters the first token starts after them. */
    if ((ptr = strtok(buf, WHITE)) == NULL)
        return -1;
    argv[argc = 0] = ptr;
    while ((ptr = strtok(NULL, WHITE)) != NULL)
    {
        if (++argc >= MAXARGC-1)    /* keep room for the NULL terminator */
            return -1;
        argv[argc] = ptr;
    }
    argv[++argc] = NULL;

    return (*optfunc)(argc, argv);
}

/*
 * cli_args - validate and store the arguments of one client request.
 * Called through buf_args() from request() with the request split into
 * tokens.  A valid request is exactly: "open" <pathname> <oflag>.
 *
 * On success, sets the globals pathname (which points into the request
 * buffer) and oflag, and returns 0.  On failure, leaves a message in
 * errmsg and returns -1.
 */
int
cli_args(int argc, char* argv[])
{
    char   *end;
    long    val;

    if (argc != 3 || strcmp(argv[0], CL_OPEN) != 0)
    {
        strcpy(errmsg, "usage: <pathname> <oflag>\n");
        return -1;
    }

    /* strtol rather than atoi: garbage or out-of-range flags are rejected
       instead of silently becoming 0 or wrapping */
    errno = 0;
    val = strtol(argv[2], &end, 10);
    if (end == argv[2] || *end != '\0' || errno == ERANGE ||
        val < INT_MIN || val > INT_MAX)
    {
        snprintf(errmsg, sizeof(errmsg), "invalid oflag: %s\n", argv[2]);
        return -1;
    }

    pathname = argv[1];
    oflag    = (int)val;
    return 0;
}

你可能感兴趣的:(unix,socket,File,null,domain,Descriptor)