Linux下netlink通信的实例代码

Linux下netlink通信的实例代码

源代码共分三个文件：

内核模块：netlink-exam-kern.c

用户态接收程序：netlink-exam-user-recv.c

用户态发送程序：netlink-exam-user-send.c

 

内核模块源码:

//kernel module: netlink-exam-kern.c
#include <linux/config.h>
#include <linux/module.h>
#include <linux/netlink.h>
#include <linux/sched.h>
#include <net/sock.h>
#include <linux/proc_fs.h>

/* Capacity, in bytes, of the in-kernel log buffer below. */
#define BUF_SIZE 16384
/* Kernel-side netlink socket; created in netlink_exam_init(). */
static struct sock *netlink_exam_sock;
/*
 * Log of received message payloads. process_message_thread() appends to it
 * (it is never reset), and netlink_exam_readproc() exposes it read-only
 * through the /proc entry registered in netlink_exam_init().
 * NOTE(review): the writer thread and /proc readers access it without any
 * locking.
 */
static unsigned char buffer[BUF_SIZE];
/* Number of bytes of buffer[] currently filled. */
static unsigned int buffer_tail = 0;
/* Set to 1 by netlink_exam_exit() to ask the worker thread to leave its loop. */
static int exit_flag = 0;
/* Completed by the worker thread just before it returns; unload waits on it. */
static DECLARE_COMPLETION(exit_completion);

static void recv_handler(struct sock * sk, int length)
{
        wake_up(sk->sk_sleep);
}

/*
 * Worker thread body (started by netlink_exam_init() via kernel_thread()).
 *
 * Sleeps on the netlink socket's wait queue. Each time it wakes, it drains
 * sk_receive_queue. For every well-formed message it appends the payload to
 * buffer[] (while there is room), then re-broadcasts the message to
 * multicast group 1. Loops until netlink_exam_exit() sets exit_flag, then
 * completes exit_completion so the unload path can continue.
 *
 * @data: unused (NULL is passed).
 * Returns 0.
 */
static int process_message_thread(void * data)
{
        struct sk_buff * skb = NULL;
        struct nlmsghdr * nlhdr = NULL;
        int len;
        DEFINE_WAIT(wait);

        daemonize("mynetlink");

        while (exit_flag == 0) {
                /*
                 * Get on the wait queue first, then re-check the wake-up
                 * conditions before sleeping. Sleeping unconditionally loses
                 * a wake_up() that fires between draining the queue and
                 * prepare_to_wait(). That would leave a message unprocessed
                 * until the next one arrives, and could make the
                 * wait_for_completion() in netlink_exam_exit() hang forever.
                 */
                prepare_to_wait(netlink_exam_sock->sk_sleep, &wait, TASK_INTERRUPTIBLE);
                if (exit_flag == 0 &&
                    skb_queue_empty(&netlink_exam_sock->sk_receive_queue))
                        schedule();
                finish_wait(netlink_exam_sock->sk_sleep, &wait);

                while ((skb = skb_dequeue(&netlink_exam_sock->sk_receive_queue))
                         != NULL) {
                        nlhdr = (struct nlmsghdr *)skb->data;
                        /*
                         * nlmsg_len is supplied by user space. Make sure the
                         * skb really holds a header and that the claimed
                         * length fits in the skb; otherwise the memcpy()
                         * below would read past the end of skb->data.
                         * The short-circuit keeps nlhdr from being read when
                         * the skb is too short to contain it.
                         */
                        if (skb->len < sizeof(struct nlmsghdr) ||
                            nlhdr->nlmsg_len < sizeof(struct nlmsghdr) ||
                            nlhdr->nlmsg_len > skb->len) {
                                printk("Corrupt netlink message.\n");
                                /* We own the dequeued skb; do not leak it. */
                                kfree_skb(skb);
                                continue;
                        }
                        len = nlhdr->nlmsg_len - NLMSG_LENGTH(0);
                        if (len + buffer_tail > BUF_SIZE) {
                                printk("netlink buffer is full.\n");
                        }
                        else {
                                memcpy(buffer + buffer_tail, NLMSG_DATA(nlhdr), len);
                                buffer_tail += len;
                        }
                        /* Rewrite addressing so the message goes out from the kernel to group 1. */
                        nlhdr->nlmsg_pid = 0;
                        nlhdr->nlmsg_flags = 0;
                        NETLINK_CB(skb).pid = 0;
                        NETLINK_CB(skb).dst_pid = 0;
                        NETLINK_CB(skb).dst_group = 1;
                        /* netlink_broadcast() takes ownership of skb. */
                        netlink_broadcast(netlink_exam_sock, skb, 0, 1, GFP_KERNEL);
                }
        }
        complete(&exit_completion);
        return 0;
}

/*
 * Legacy read_proc callback for /proc/netlink_exam_buffer.
 *
 * Copies the part of buffer[] that starts at @off into @page. At most
 * @count bytes are copied, further capped at PAGE_SIZE and at the number
 * of bytes logged so far. *start is set to @page so the caller takes the
 * data from the start of the page.
 *
 * Returns the number of bytes copied. Returns 0, with *eof set, once @off
 * reaches buffer_tail. @data is unused.
 */
static int netlink_exam_readproc(char *page, char **start, off_t off,
                          int count, int *eof, void *data)
{
        int chunk;

        /* Reader is already past the logged data: report end of file. */
        if (off >= buffer_tail) {
                *eof = 1;
                return 0;
        }

        chunk = (count > PAGE_SIZE) ? PAGE_SIZE : count;
        if (chunk > buffer_tail - off)
                chunk = buffer_tail - off;

        memcpy(page, buffer + off, chunk);
        *start = page;
        return chunk;
}

/*
 * Module load, in three steps:
 *   1. Create the kernel NETLINK_GENERIC socket (recv_handler as its input
 *      callback).
 *   2. Start the worker thread that processes messages.
 *   3. Register the read-only /proc/netlink_exam_buffer entry.
 *
 * Returns 0 on success or a negative errno.
 */
static int __init netlink_exam_init(void)
{
        int pid;

        netlink_exam_sock = netlink_kernel_create(NETLINK_GENERIC, 0, recv_handler, THIS_MODULE);
        if (!netlink_exam_sock) {
                printk("Fail to create netlink socket.\n");
                /*
                 * Must be a negative errno. The module loader only treats
                 * negative values as failure. Returning 1 left the module
                 * loaded with a NULL socket, which netlink_exam_exit()
                 * would later dereference.
                 */
                return -ENOMEM;
        }

        pid = kernel_thread(process_message_thread, NULL, CLONE_KERNEL);
        if (pid < 0) {
                /*
                 * Without the thread nobody would ever complete
                 * exit_completion, so unloading would hang. Fail the load
                 * instead.
                 */
                printk("Fail to create netlink thread.\n");
                sock_release(netlink_exam_sock->sk_socket);
                return pid;
        }

        /* The /proc view is optional; warn but keep the module working without it. */
        if (!create_proc_read_entry("netlink_exam_buffer", 0444, NULL, netlink_exam_readproc, 0))
                printk("Fail to create /proc/netlink_exam_buffer.\n");
        return 0;
}

/*
 * Module unload, in four steps:
 *   1. Unregister the /proc entry.
 *   2. Ask the worker thread to stop and wake it.
 *   3. Wait until the thread signals exit_completion.
 *   4. Release the kernel netlink socket.
 */
static void __exit netlink_exam_exit(void)
{
        /*
         * The /proc entry points at netlink_exam_readproc() and buffer[],
         * both of which go away with the module. Remove it first;
         * otherwise a read after rmmod jumps into unloaded code.
         */
        remove_proc_entry("netlink_exam_buffer", NULL);

        exit_flag = 1;
        wake_up(netlink_exam_sock->sk_sleep);
        wait_for_completion(&exit_completion);
        sock_release(netlink_exam_sock->sk_socket);
}

/* Register the load/unload entry points defined above. */
module_init(netlink_exam_init);
module_exit(netlink_exam_exit);
/* Declare the module's license. */
MODULE_LICENSE("GPL");


用户接收源码:

//application receiver: netlink-exam-user-recv.c
/*
 * Binds a NETLINK_GENERIC socket to multicast group 1 (the group the kernel
 * module re-broadcasts to) and prints the payload of every message it
 * receives.
 *
 * Exit status: 0 when recvmsg() returns 0 (after printing "Exit."),
 * 1 on any error.
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <linux/netlink.h>

#define MAX_MSGSIZE 1024


int main(void)
{
        struct sockaddr_nl saddr, daddr;
        struct nlmsghdr *nlhdr = NULL;
        struct msghdr msg;
        struct iovec iov;
        ssize_t ret;
        int sd;
        int status = 1;

        sd = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC);
        if (sd == -1) {
                perror("socket");
                return 1;
        }
        memset(&saddr, 0, sizeof(saddr));
        memset(&daddr, 0, sizeof(daddr));

        saddr.nl_family = AF_NETLINK;
        saddr.nl_pid = getpid();
        saddr.nl_groups = 1;    /* multicast group the kernel module broadcasts to */
        if (bind(sd, (struct sockaddr *)&saddr, sizeof(saddr)) == -1) {
                perror("bind");
                close(sd);
                return 1;
        }

        nlhdr = malloc(NLMSG_SPACE(MAX_MSGSIZE));
        if (nlhdr == NULL) {
                perror("malloc");
                close(sd);
                return 1;
        }

        while (1) {
                memset(nlhdr, 0, NLMSG_SPACE(MAX_MSGSIZE));
                /*
                 * Zero the whole msghdr. msg_control, msg_controllen and
                 * msg_flags are otherwise uninitialized stack garbage that
                 * recvmsg() would interpret.
                 */
                memset(&msg, 0, sizeof(msg));

                iov.iov_base = (void *)nlhdr;
                iov.iov_len = NLMSG_SPACE(MAX_MSGSIZE);
                msg.msg_name = (void *)&daddr;
                msg.msg_namelen = sizeof(daddr);
                msg.msg_iov = &iov;
                msg.msg_iovlen = 1;

                ret = recvmsg(sd, &msg, 0);
                if (ret == 0) {
                        printf("Exit.\n");
                        status = 0;
                        break;
                }
                else if (ret == -1) {
                        perror("recvmsg");
                        status = 1;
                        break;
                }
                if (!NLMSG_OK(nlhdr, (int)ret)) {
                        fprintf(stderr, "Malformed netlink message.\n");
                        continue;
                }
                /*
                 * The sender does not include a NUL terminator. Print exactly
                 * the payload length instead of relying on "%s" finding one.
                 */
                printf("%.*s", (int)(nlhdr->nlmsg_len - NLMSG_LENGTH(0)),
                       (char *)NLMSG_DATA(nlhdr));
        }

        free(nlhdr);
        close(sd);
        return status;
}


 

用户发送源码:

//application sender: netlink-exam-user-send.c
/*
 * Reads the text file named by argv[1] line by line (at most
 * MAX_MSGSIZE - 1 bytes per fgets() call). Each chunk is sent, without a
 * trailing NUL, as the payload of one netlink message to the kernel
 * (nl_pid 0) over a NETLINK_GENERIC socket.
 *
 * Exit status: 0 if every sendmsg() succeeded, 1 on usage, open, setup
 * or send errors.
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <linux/netlink.h>

#define MAX_MSGSIZE 1024


int main(int argc, char * argv[])
{
        FILE * fp;
        struct sockaddr_nl saddr, daddr;
        struct nlmsghdr *nlhdr = NULL;
        struct msghdr msg;
        struct iovec iov;
        int sd;
        char text_line[MAX_MSGSIZE];
        size_t text_len;
        ssize_t ret;
        int status = 0;

        if (argc < 2) {
                printf("Usage: %s atextfilename\n", argv[0]);
                exit(1);
        }

        if ((fp = fopen(argv[1], "r")) == NULL) {
                /* The %s argument was missing here before (undefined behavior). */
                printf("File %s doesn't exist.\n", argv[1]);
                exit(1);
        }

        sd = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC);
        if (sd == -1) {
                perror("socket");
                fclose(fp);
                return 1;
        }
        memset(&saddr, 0, sizeof(saddr));
        memset(&daddr, 0, sizeof(daddr));

        saddr.nl_family = AF_NETLINK;
        saddr.nl_pid = getpid();
        saddr.nl_groups = 0;
        if (bind(sd, (struct sockaddr *)&saddr, sizeof(saddr)) == -1) {
                perror("bind");
                close(sd);
                fclose(fp);
                return 1;
        }

        /* Destination: the kernel (nl_pid 0), unicast. */
        daddr.nl_family = AF_NETLINK;
        daddr.nl_pid = 0;
        daddr.nl_groups = 0;

        nlhdr = malloc(NLMSG_SPACE(MAX_MSGSIZE));
        if (nlhdr == NULL) {
                perror("malloc");
                close(sd);
                fclose(fp);
                return 1;
        }

        while (fgets(text_line, sizeof(text_line), fp)) {
                text_len = strlen(text_line);

                /*
                 * Clear the header so nlmsg_type and nlmsg_seq are not sent
                 * as uninitialized heap bytes.
                 */
                memset(nlhdr, 0, sizeof(*nlhdr));
                memcpy(NLMSG_DATA(nlhdr), text_line, text_len);
                nlhdr->nlmsg_len = NLMSG_LENGTH(text_len);
                nlhdr->nlmsg_pid = getpid();  /* self pid */
                nlhdr->nlmsg_flags = 0;

                memset(&msg, 0, sizeof(msg));
                iov.iov_base = (void *)nlhdr;
                iov.iov_len = nlhdr->nlmsg_len;
                msg.msg_name = (void *)&daddr;
                msg.msg_namelen = sizeof(daddr);
                msg.msg_iov = &iov;
                msg.msg_iovlen = 1;

                ret = sendmsg(sd, &msg, 0);
                if (ret == -1) {
                        /* Keep sending the remaining lines, but report failure in the exit status. */
                        perror("sendmsg error");
                        status = 1;
                }
        }

        free(nlhdr);
        fclose(fp);
        close(sd);
        return status;
}


 

你可能感兴趣的:(thread,linux,struct,Module,null,buffer)