/*********************************************************
* Copyright (C) 2006 VMware, Inc. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the
* Free Software Foundation version 2 and no later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* for more details.
*
* You should have received a copy of the GNU General Public License along
* with this program; if not, write to the Free Software Foundation, Inc.,
* 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*
*********************************************************/
#include "driver-config.h"
#include <linux/kernel.h>
#include <linux/version.h>
#include <linux/socket.h>
#include <linux/if_ether.h>
#include <linux/in.h>
#include <linux/ip.h>
#include "compat_skbuff.h"
#include "compat_module.h"
#include <linux/mutex.h>
#include <linux/netdevice.h>
#include <linux/version.h>
#if COMPAT_LINUX_VERSION_CHECK_LT(3, 2, 0)
# include <linux/module.h>
#else
# include <linux/export.h>
#endif
/*
* All this makes sense only if NETFILTER support is configured in our kernel.
*/
#ifdef CONFIG_NETFILTER
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/poll.h>
#include "vnetFilter.h"
#include "vnetFilterInt.h"
#include "vnetInt.h"
#include "vmnetInt.h"
// VNet_FilterLogPacket.action for dropped packets
#define VNET_FILTER_ACTION_DRP (1)
#define VNET_FILTER_ACTION_DRP_SHORT (2)
#define VNET_FILTER_ACTION_DRP_MATCH (3)
#define VNET_FILTER_ACTION_DRP_DEFAULT (4)
// VNet_FilterLogPacket.action for forwarded packets
#define VNET_FILTER_ACTION_FWD (1<<8 | 1)
#define VNET_FILTER_ACTION_FWD_LOOP (1<<8 | 5)
#define VNET_FILTER_ACTION_FWD_MATCH (1<<8 | 6)
#define VNET_FILTER_ACTION_FWD_DEFAULT (1<<8 | 7)
/* netfilter hooks for filtering. */
static nf_hookfn VNetFilterHookFn;
static struct nf_hook_ops vmnet_nf_ops[] = {
{ .hook = VNetFilterHookFn,
.owner = THIS_MODULE,
.pf = PF_INET,
.hooknum = VMW_NF_INET_LOCAL_IN,
.priority = NF_IP_PRI_FILTER - 1, },
{ .hook = VNetFilterHookFn,
.owner = THIS_MODULE,
.pf = PF_INET,
.hooknum = VMW_NF_INET_POST_ROUTING,
.priority = NF_IP_PRI_FILTER - 1, }
};
/* track if we actually set a callback in IP's filter driver */
static Bool installedFilterCallback = FALSE;
/* rules to use for filtering */
RuleSet *ruleSetHead = NULL; /* linked list of all rules */
int32 numRuleSets = 0; /* number of rule sets in ruleSetHead's linked list */
RuleSet *activeRule = NULL; /* actual rule set for filter callback to use */
/* locks to protect against concurrent accesses. */
static DEFINE_MUTEX(filterIoctlMutex); /* serialize ioctl()s from user space. */
/*
* user/netfilter hook concurrency lock.
* This spinlock doesn't scale well if/when in the future the netfilter
* callbacks can be concurrently executing on multiple threads on multiple
* CPUs, so we should revisit locking for allowing for that in the future.
*/
static DEFINE_SPINLOCK(activeRuleLock);
/*
* Logging.
*
* All logging for development build uses LOG(2, (KERN_INFO ...)) because the default
* log level is set to 1 (vnetInt.h). All ACE logging, i.e. policy driven logging, uses
* printk(KERN_INFO ...).
*/
static uint32 logLevel = VNET_FILTER_LOGLEVEL_NORMAL; /* the current log level */
static void LogPacket(uint16 action, void *header, void *data,
uint32 length, Bool drop);
static int InsertHostFilterCallback(void);
static void RemoveHostFilterCallback(void);
static RuleSet *FindRuleSetById(uint32 id, RuleSet ***prevPtr);
static int CreateRuleSet(uint32 id, uint32 defaultAction);
static void DeleteRule(Rule *rule);
static int DeleteRuleSet(uint32 id);
static int ChangeRuleSet(uint32 id, Bool enable, Bool disable, uint32 action);
static int AddIPv4Rule(uint32 id, VNet_AddIPv4Rule *rule,
VNet_IPv4Address *addressList,
VNet_IPv4Port *portList);
/*
*----------------------------------------------------------------------
*
* DropPacket --
*
* Function is used to record information regarding a packet
* being dropped.
*
* Results:
* void
*
* Side effects:
* Might store information regarding the packet.
*
*----------------------------------------------------------------------
*/
static INLINE void
DropPacket(uint16 action, // IN: reason code
void *header, // IN: packet header
void *data, // IN: packet data
uint32 length) // IN: packet length
{
LogPacket(action, header, data, length, TRUE);
}
/*
*----------------------------------------------------------------------
*
* ForwardPacket --
*
* Function is used to record information regarding a packet
* being forwarded.
*
* Results:
* void
*
* Side effects:
* Might store information regarding the packet.
*
*----------------------------------------------------------------------
*/
static INLINE void
ForwardPacket(uint16 action, // IN: reason code
void *header, // IN: packet header
void *data, // IN: packet data
uint32 length) // IN: packet length
{
#ifdef DBG
LogPacket(action, header, data, length, FALSE);
#endif
}
/*
*----------------------------------------------------------------------
*
* VNetFilterHookFn --
*
* Function is registered as a callback function with the host's
* IP stack. This function can be used to filter on specified protocols
* IP addresses, and/or local and remote ports. It makes use of the Linux
* netfilter infrastructure, by inserting this function in netfilter at a
* priority 1 higher than iptables, so that we don't have to worry about
* any existing iptables based firewall rules on the Linux hosts.
*
* Results:
* NF_ACCEPT or NF_DROP.
*
* Side effects:
* None besides those described above.
*
*----------------------------------------------------------------------
*/
#define DEBUG_HOST_FILTER 0
#if DEBUG_HOST_FILTER
#define HostFilterPrint(a) printk a
#else
#define HostFilterPrint(a)
#endif
static unsigned int
#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 13, 0)
VNetFilterHookFn(unsigned int hooknum, // IN:
#else
VNetFilterHookFn(const struct nf_hook_ops *ops, // IN:
#endif
#ifdef VMW_NFHOOK_USES_SKB
struct sk_buff *skb, // IN:
#else
struct sk_buff **pskb, // IN:
#endif
const struct net_device *in, // IN:
const struct net_device *out, // IN:
int (*okfn)(struct sk_buff *)) // IN:
{
#ifndef VMW_NFHOOK_USES_SKB
struct sk_buff *skb = *pskb;
#endif
struct iphdr *ip;
uint32 remoteAddr;
uint16 localPort;
uint16 remotePort;
uint8 *packet;
uint8 *packetHeader;
int packetLength;
RuleSet *currRuleSet;
Bool blockByDefault;
Bool transmit; /* TRUE if transmitting, FALSE is receiving */
Rule *currRule;
unsigned int verdict = NF_ACCEPT;
unsigned long flags;
/* Early checks to see we should even care. */
if (skb->protocol != htons(ETH_P_IP)) {
return verdict;
}
spin_lock_irqsave(&activeRuleLock, flags);
currRuleSet = activeRule;
// ASSERT(currRuleSet);
/*
* Function uses a local copy of ruleSetHead so that we're
* not adversely affected by any rule changes that might occur
* while this function is running.
*/
blockByDefault = currRuleSet->action == VNET_FILTER_RULE_BLOCK;
/* When the host transmits, hooknum is VMW_NF_INET_POST_ROUTING. */
/* When the host receives, hooknum is VMW_NF_INET_LOCAL_IN. */
#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 13, 0)
transmit = (hooknum == VMW_NF_INET_POST_ROUTING);
#else
transmit = (ops->hooknum == VMW_NF_INET_POST_ROUTING);
#endif
packetHeader = compat_skb_network_header(skb);
ip = (struct iphdr*)packetHeader;
if (transmit) {
/* skb all set up for us. */
packet = compat_skb_transport_header(skb);
} else {
/* skb hasn't had a chance to be processed by TCP yet. */
packet = compat_skb_network_header(skb) + (ip->ihl << 2);
}
HostFilterPrint(("PacketFilter: IP ver %d ihl %d tos %d len %d id %d\n"
" offset %d ttl %d proto %d xsum %d\n"
" src 0x%08x dest 0x%08x %s\n",
ip->version, ip->ihl, ip->tos, ip->tot_len, ip->id,
ip->frag_off, ip->ttl, ip->protocol, ip->check,
ip->saddr, ip->daddr, transmit ? "OUTGOING":"INCOMING"));
/*
* For incoming packets, there should be a skb->dev associated with it, with
* a populated L2 address length.
*/
if (skb->dev && skb->dev->hard_header_len) {
packetLength = skb->len - skb->dev->hard_header_len - (ip->ihl << 2);
} else {
/*
* In certain cases, compat_skb_mac_header() has been observed to be NULL. Don't
* know why, but in such cases, this calculation will lead to a negative
* packetLength, and the packet to be dropped.
*/
packetLength = skb->len -
(compat_skb_network_header(skb) - compat_skb_mac_header(skb)) -
(ip->ihl << 2);
}
if (packetLength < 0) {
HostFilterPrint(("PacketFilter: ill formed packet for IPv4\n"));
HostFilterPrint(("skb: len %d h.raw %p nh.raw %p mac.raw %p, packetLength %d\n",
skb->len, compat_skb_transport_header(skb),
compat_skb_network_header(skb),
compat_skb_mac_header(skb), packetLength));
verdict = NF_DROP;
DropPacket(VNET_FILTER_ACTION_DRP_SHORT, packetHeader, packet, 0);
goto out_unlock;
}
remoteAddr = transmit ? ip->daddr : ip->saddr;
/* always allow 127/8. */
if ((remoteAddr & 0xff) == 127) {
HostFilterPrint(("PacketFilter: allowing %s loopback 0x%08x\n",
transmit ? "outgoing" : "incoming",
remoteAddr));
ForwardPacket(VNET_FILTER_ACTION_FWD_LOOP,
packetHeader, packet, packetLength);
goto out_unlock;
}
/* If we're dealing with TCP or UDP, then extract the port information */
if (ip->protocol == IPPROTO_TCP || ip->protocol == IPPROTO_UDP) {
uint16 srcPort, dstPort; /* used to extract port information from packet */
if (packetLength < 4) {
HostFilterPrint(("PacketFilter: payload too short for "
"TCP or UDP: %d\n", packetLength));
verdict = NF_DROP;
DropPacket(VNET_FILTER_ACTION_DRP_SHORT,
packetHeader, packet, packetLength);
goto out_unlock;
}
/* Retrieve UDP/TCP port info */
srcPort = *((uint16*)&packet[0]);
dstPort = *((uint16*)&packet[2]);
if (transmit) { /* transmit */
localPort = ntohs(srcPort);
remotePort = ntohs(dstPort);
} else { /* receive */
localPort = ntohs(dstPort);
remotePort = ntohs(srcPort);
}
HostFilterPrint(("PacketFilter: got local port %d remote port %d\n",
localPort, remotePort));
} else {
/* these mostly exist to silence compiler warning about uninit variables */
localPort = 0;
remotePort = 0;
}
currRule = currRuleSet->list;
/* traverse all the rules in the rule set */
while (currRule != NULL) {
uint32 i;
Bool matchedAddress;
/* if direction doesn't match rule, then skip */
if ((currRule->direction == VNET_FILTER_DIRECTION_IN && transmit) ||
(currRule->direction == VNET_FILTER_DIRECTION_OUT && !transmit)) {
HostFilterPrint(("PacketFilter: didn't match direction\n"));
/* wrong direction */
goto skipRule;
}
/*
* Check if the packet's address matches the rule. If the list is empty
* then this means we don't care about address and it's considered a match.
*/
matchedAddress = (currRule->addressListLen == 0); /* empty list means don't care */
for (i = 0; i < currRule->addressListLen; ++i) {
if ((remoteAddr & currRule->addressList[i].ipv4Mask) ==
currRule->addressList[i].ipv4Addr) {
matchedAddress = TRUE;
HostFilterPrint(("PacketFilter: rule matched ip addr %u: "
"0x%08x == 0x%08x\n", i, remoteAddr,
currRule->addressList[i].ipv4Addr));
break;
} else {
HostFilterPrint(("PacketFilter: rule not match ip addr %u: "
"0x%08x != 0x%08x\n", i, remoteAddr,
currRule->addressList[i].ipv4Addr));
}
}
if (!matchedAddress) {
HostFilterPrint(("PacketFilter: rule didn't match ip addr 0x%08x\n",
remoteAddr));
/* ip addr doesn't match */
goto skipRule;
}
/*
* Check the protocol. ~0 (0xffff) means we don't care about the
* protocol and it's considered a match.
*/
if (currRule->proto != 0xffff && currRule->proto != ip->protocol) {
HostFilterPrint(("PacketFilter: didn't match protocol: %u != %u\n",
ip->protocol, currRule->proto));
/* protocol doesn't match */
goto skipRule;
}
/*
* If the protocol is TCP or UDP then check the port list. If the list is empty
* then this means we don't care about ports and it's considered a match.
*/
if (currRule->proto == IPPROTO_TCP || currRule->proto == IPPROTO_UDP) {
/* An empty list means the rule don't care about port numbers*/
Bool matchedPort = (currRule->portListLen == 0);
for (i = 0; i < currRule->portListLen; ++i) {
RulePort *portRule = currRule->portList + i;
Bool matchedLocal, matchedRemote; /* improves readability */
/*
* It's presumed that if portRule->localPortLow == ~0 then
* portRule->localPortHigh == ~0. Similiar story for the
* remote ports.
*/
matchedLocal = (localPort >= portRule->localPortLow &&
localPort <= portRule->localPortHigh) ||
portRule->localPortLow == ~0;
matchedRemote = (remotePort >= portRule->remotePortLow &&
remotePort <= portRule->remotePortHigh) ||
portRule->remotePortLow == ~0;
if (matchedLocal && matchedRemote) {
HostFilterPrint(("PacketFilter: matched rule's "
"port element %u\n", i));
matchedPort = TRUE;
break;
}
HostFilterPrint(("PacketFilter: didn't match rule's "
"port element %u\n", i));
HostFilterPrint(("-- local %4u not in range [%4u, %4u] or \n",
localPort, portRule->localPortLow,
portRule->localPortHigh));
HostFilterPrint(("-- remote %4u not in range [%4u, %4u]\n",
remotePort, portRule->remotePortLow,
portRule->remotePortHigh));
}
if (!matchedPort) {
HostFilterPrint(("PacketFilter: rule didn't match port "
"(local %u remote %u)\n", localPort, remotePort));
/* port doesn't match */
goto skipRule;
}
}
/* rule matches so follow orders */
if (currRule->action == VNET_FILTER_RULE_ALLOW) {
HostFilterPrint(("PacketFilter: found match, forwarding\n"));
ForwardPacket(VNET_FILTER_ACTION_FWD_MATCH,
packetHeader, packet, packetLength);
goto out_unlock;
} else {
HostFilterPrint(("PacketFilter: found match, dropping\n"));
verdict = NF_DROP;
DropPacket(VNET_FILTER_ACTION_DRP_MATCH,
packetHeader, packet, packetLength);
goto out_unlock;
}
skipRule:
currRule = currRule->next;
}
/* Forward or drop packet based on the default rule */
HostFilterPrint(("PacketFilter: Didn't find match for %s "
"%u.%u.%u.%u, %s packet\n",
transmit ? "outgoing" : "incoming",
remoteAddr & 0xff, (remoteAddr >> 8) & 0xff,
(remoteAddr >> 16) & 0xff, (remoteAddr >> 24) & 0xff,
blockByDefault ? "drop" : "forward"));
if (blockByDefault) {
verdict = NF_DROP;
DropPacket(VNET_FILTER_ACTION_DRP_DEFAULT,
packetHeader, packet, packetLength);
} else {
ForwardPacket(VNET_FILTER_ACTION_FWD_DEFAULT,
packetHeader, packet, packetLength);
}
out_unlock:
spin_unlock_irqrestore(&activeRuleLock, flags);
return verdict;
}
/*
*----------------------------------------------------------------------
*
* InsertHostFilterCallback --
*
* Function registers a hook in the host's IP stack.
*
* Results:
* 0 on success (or if hook already installed),
* errno on failure.
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
static int
InsertHostFilterCallback(void)
{
uint32 i;
int retval = 0;
LOG(2, (KERN_INFO "vnet filter inserting callback\n"));
if (installedFilterCallback) {
LOG(2, (KERN_INFO "vnet filter callback already registered\n"));
goto end;
}
/* Register netfilter hooks. */
for (i = 0; i < ARRAY_SIZE(vmnet_nf_ops); i++) {
if ((retval = nf_register_hook(&vmnet_nf_ops[i])) >= 0) {
continue;
}
/* Encountered an error, back out. */
LOG(2, (KERN_INFO "vnet filter failed to register callback %d: %d\n",
i, retval));
while (i--) {
nf_unregister_hook(&vmnet_nf_ops[i]);
}
goto end;
}
installedFilterCallback = TRUE;
LOG(2, (KERN_INFO "Successfully set packet filter function\n"));
end:
return retval;
}
/*
*----------------------------------------------------------------------
*
* RemoveHostFilterCallback --
*
* Function deregisters a hook in the host's IP stack.
*
* Results:
* void
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
static void
RemoveHostFilterCallback(void)
{
int i;
LOG(2, (KERN_INFO "vnet filter removing callback\n"));
if (installedFilterCallback) {
LOG(2, (KERN_INFO "filter callback was installed: removing filter\n"));
for (i = ARRAY_SIZE(vmnet_nf_ops) - 1; i >= 0; i--) {
nf_unregister_hook(&vmnet_nf_ops[i]);
}
installedFilterCallback = FALSE;
}
LOG(2, (KERN_INFO "vnet filter remove callback done\n"));
}
/*
*----------------------------------------------------------------------
*
* FindRuleSetById --
*
* Function is given an ID for a rule set, and returns a
* pointer to the ruleset with that ID. The function can
* optionally report what pointer is pointing to this item
* (suitable for removing the item from the linked list -- the
* result might be the prior item's next pointer, or the head).
*
* Results:
* NULL if rule set not found, otherwise pointer to rule set.
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
static RuleSet *
FindRuleSetById(uint32 id, // IN: id to locate
RuleSet ***prevPtr) // OUT: pointer to the ->next pointer
// (or head) that points to the
// returned item (optional)
{
RuleSet *curr;
RuleSet **prev = NULL;
// ASSERT(id != 0);
curr = ruleSetHead;
prev = &ruleSetHead;
while (curr != NULL) {
if (curr->id == id) {
LOG(2, (KERN_INFO "Found id %u at %p\n", id, curr));
if (prevPtr != NULL) {
*prevPtr = prev;
}
return curr;
}
prev = &curr->next;
curr = curr->next;
}
LOG(2, (KERN_INFO "Didn't find ruleset with id %u\n", id));
/* won't overwrite *prevPtr with NULL */
return NULL;
}
/*
*----------------------------------------------------------------------
*
* CreateRuleSet --
*
* Function creates a new rule set with a specified ID and
* default action. Call will fail if failed to alloc memory,
* or if ID is already in use, or if maximum number of
* rule sets have already been created.
*
* Results:
* Returns 0 on success, and otherwise returns errno.
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
static int
CreateRuleSet(uint32 id, // IN: requested ID for new rule set
uint32 defaultAction) // IN: default action for rule set
{
RuleSet *newRuleSet;
RuleSet *curr;
/* check if too many rule sets already exist */
if (numRuleSets >= MAX_RULE_SETS) {
LOG(2, (KERN_INFO "filter already has all rules (%u of %u) allocated\n",
numRuleSets, MAX_RULE_SETS));
return -EOVERFLOW;
}
/* check if ID is already in use */
curr = FindRuleSetById(id, NULL);
if (curr != NULL) {
LOG(2, (KERN_INFO "filter already has id %u\n", id));
return -EEXIST;
}
/* allocate and init new rule set */
newRuleSet = kmalloc(sizeof *newRuleSet, GFP_USER);
if (newRuleSet == NULL) {
LOG(2, (KERN_INFO "filter mem alloc failed\n"));
return -ENOMEM;
}
memset(newRuleSet, 0, sizeof *newRuleSet);
newRuleSet->next = ruleSetHead;
newRuleSet->id = id;
newRuleSet->enabled = FALSE;
newRuleSet->action = (uint16)defaultAction;
newRuleSet->list = NULL;
newRuleSet->numRules = 0;
newRuleSet->tail = &newRuleSet->list;
/* add new rule set to head of linked list */
numRuleSets++;
ruleSetHead = newRuleSet;
LOG(2, (KERN_INFO "filter created ruleset with id %u\n", id));
return 0;
}
/*
*----------------------------------------------------------------------
*
* DeleteRule --
*
* Function frees the memory in a Rule object. This function
* frees the arrays in the Rule, but not an elements that
* are chained on the linked-list via 'next'.
*
* Results:
* None.
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
static void
DeleteRule(Rule *rule) // IN: Rule to delete.
{
// ASSERT(rule);
if (!rule) {
return;
}
if (rule->addressList) {
kfree(rule->addressList);
rule->addressList = NULL;
}
if (rule->portList) {
kfree(rule->portList);
rule->portList = NULL;
}
kfree(rule);
}
/*
*----------------------------------------------------------------------
*
* DeleteRuleSet --
*
* Function deletes a rule set with a specified ID. Call will fail
* if ID not found or if the current rule set is being used for
* filtering.
*
* Results:
* Returns 0 on success, errno on failure.
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
static int
DeleteRuleSet(uint32 id) // IN: ID of new rule set to delete
{
RuleSet **prev = NULL;
RuleSet *curr;
Rule *currRule;
/* locate the ruleset with the specified ID */
curr = FindRuleSetById(id, &prev);
if (curr == NULL) {
LOG(2, (KERN_INFO "filter did not find id %u to delete\n", id));
return -ESRCH;
}
LOG(2, (KERN_INFO "found id %u\n", id));
/* check if in use */
if (curr->enabled) {
LOG(2, (KERN_INFO "Can't delete id %u since enabled\n", id));
return -EBUSY;
}
/* remove item from linked list */
*prev = curr->next;
/* free rules in rule set */
currRule = curr->list;
curr->list = NULL; /* help mitigate any bugs or races */
while (currRule) {
Rule *temp = currRule->next;
currRule->next = NULL; /* help mitigate any bugs or races */
DeleteRule(currRule);
currRule = temp;
}
kfree(curr);
numRuleSets--;
// ASSERT(numRuleSets >= 0);
return 0;
}
/*
*----------------------------------------------------------------------
*
* ChangeRuleSet --
*
* This function is used to specify which rule set is to be used
* for filtering (or stop using for filtering). If another
* rule set is currently used for filtering then the specified
* rule set will replace it. This funciton can also be used to
* change the default action for any rule set, but this option
* should not be used when disabling a rule set.
*
* Call will fail if ID can't be found, or when attempting to
* disable a rule set that's not enabled.
*
* Results:
* Returns 0 on success, errno on failure.
*
* Side effects:
* May add/remove filter callback.
*
*----------------------------------------------------------------------
*/
static int
ChangeRuleSet(uint32 id, // IN: requested ID of rule set
Bool enable, // IN: TRUE says start using this rule for filtering
Bool disable, // IN: TRUE says stop using this rule for filtering
uint32 action) // IN: default action for rule set
{
RuleSet *curr;
int retval;
unsigned long flags;
// ASSERT(!enable || !disable); /* at most one can be set */
LOG(2, (KERN_INFO "changeruleset %d enable %d disable %d action %x\n", id,
enable, disable, action));
/* locate the specified rule set */
curr = FindRuleSetById(id, NULL);
if (curr == NULL) {
LOG(2, (KERN_INFO "vnet filter can't find ruleset: %u\n", id));
return -ESRCH;
}
if (enable) {
RuleSet *oldActive;
if (action != VNET_FILTER_RULE_NO_CHANGE) {
LOG(2, (KERN_INFO "vnet filter changing default action "
"of active rule set: %u (id %u)\n", action, id));
curr->action = (uint16)action;
}
/* enable new rule */
curr->enabled = TRUE;
/* Grab activeRule spinlock. */
spin_lock_irqsave(&activeRuleLock, flags);
LOG(2, (KERN_INFO "changing active rule from "
"%p (%u) to %p (%u)\n", activeRule,
activeRule ? activeRule->id : 0,
curr, curr->id));
/* make rule active */
oldActive = activeRule;
activeRule = curr;
/* Safe to release activeRule spinlock now. */
spin_unlock_irqrestore(&activeRuleLock, flags);
/*
* Mark old rule as not enabled, except if it's the same
* as the newly enabled rule set.
*/
if (oldActive == NULL) {
// 1) activate (no current active)
LOG(2, (KERN_INFO "No prior rule was active\n"));
} else if (oldActive == curr) {
// 2) activate (current active, and same as this one)
LOG(2, (KERN_INFO "Activated rule that was already active\n"));
} else { /* oldActive != NULL && oldActive != curr */
// 3) activate (current active, and different than this one)
LOG(2, (KERN_INFO "Deactivating old rule: %p (id %u)\n",
oldActive, oldActive->id));
oldActive->enabled = FALSE;
}
if ((retval = InsertHostFilterCallback()) != 0) {
LOG(2, (KERN_INFO "Failed to insert filter in IP\n"));
}
} else if (disable) {
if (!curr->enabled) {
// 4) deactive (but not currently active)
LOG(2, (KERN_INFO "vnet filter tried to deactive a "
"non-active rule: %u\n", id));
if (activeRule) {
// ASSERT(activeRule != curr);
LOG(2, (KERN_INFO "-- current active is %p (id %u)\n",
activeRule, activeRule->id));
} else {
LOG(2, (KERN_INFO "-- no rule is currently active\n"));
}
/* in this case we'll also not change the default action */
return -EINVAL;
}
// 5) deactive (and currently active)
LOG(2, (KERN_INFO "vnet filter deactivating %p (id %u)\n",
curr, id));
RemoveHostFilterCallback();
// ASSERT(activeRule == curr);
/* Grab activeRule spinlock. */
spin_lock_irqsave(&activeRuleLock, flags);
activeRule = NULL;
/* Safe to release activeRule spinlock now. */
spin_unlock_irqrestore(&activeRuleLock, flags);
curr->enabled = FALSE;
if (action != VNET_FILTER_RULE_NO_CHANGE) {
LOG(2, (KERN_INFO "vnet filter changing default action: "
"%u (id %u)\n", action, id));
curr->action = (uint16)action;
}
retval = 0;
} else { /* !enable && !disable */
if (action == VNET_FILTER_RULE_NO_CHANGE) {
// 6) no activate change (and default not changed)
LOG(2, (KERN_INFO "vnet filter got nothing to change\n"));
retval = 0;
}
// 7) no activate change (but default action changed)
curr->action = (uint16)action;
LOG(2, (KERN_INFO "vnet filter changed action: %u\n", action));
retval = 0;
}
return retval;
}
/*
*----------------------------------------------------------------------
*
* AddIPv4Rule --
*
* Function is used to add an IPv4 rule to a rule set.
* Call will fail if failed to alloc memory, or if specified
* ID was not found. The actual rule is not sanity checked,
* as it's presumed the caller did this.
*
* Results:
* Returns 0 on success, errno on failure.
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
static int
AddIPv4Rule(uint32 id, // IN: requested ID of rule set
VNet_AddIPv4Rule *rule, // IN: rule to add
VNet_IPv4Address *addressList, // IN: list of addresses
VNet_IPv4Port *portList) // IN: list of ports
{
Rule *newRule;
RuleSet *curr;
// ASSERT(rule && addressList && portList);
/* locate the rule set with the specified ID */
curr = FindRuleSetById(id, NULL);
if (curr == NULL) {
LOG(2, (KERN_INFO "vnet filter can't find ruleset: %u\n", id));
return -ESRCH;
}
/* make sure that we don't have too many rules already */
if (curr->numRules >= MAX_RULES_PER_SET) {
LOG(2, (KERN_INFO "vnet filter has too many rules in ruleset: %u >= %u\n",
curr->numRules, MAX_RULES_PER_SET));
return -EOVERFLOW;
}
/* allocate and init rule */
newRule = kmalloc(sizeof *newRule, GFP_USER);
if (newRule == NULL) {
LOG(2, (KERN_INFO "vnet filter mem alloc failed for rule\n"));
return -ENOMEM;
}
memset(newRule, 0, sizeof *newRule);
newRule->action = (uint16)rule->action;
newRule->direction = (uint16)rule->direction;
newRule->proto = (uint16)rule->proto;
// ASSERT(rule->addressListLen <= 255); /* double-check for data truncation */
newRule->addressListLen = (uint8)rule->addressListLen;
if (newRule->addressListLen == 1 &&
addressList[0].ipv4RemoteAddr == 0 &&
addressList[0].ipv4RemoteMask == 0) {
newRule->addressListLen = 0;
LOG(2, (KERN_INFO "vnet filter address has single don't care rule\n"));
}
// ASSERT(rule->portListLen <= 255); /* double-check for data truncation */
newRule->portListLen = (uint8)rule->portListLen;
if (newRule->portListLen == 1 &&
portList[0].localPortLow == ~0 &&
portList[0].localPortHigh == ~0 &&
portList[0].remotePortLow == ~0 &&
portList[0].remotePortHigh == ~0) {
newRule->portListLen = 0;
LOG(2, (KERN_INFO "vnet filter port has single don't care rule\n"));
}
if (newRule->addressListLen > 0) {
uint32 i;
newRule->addressList =
kmalloc(sizeof(*newRule->addressList) * newRule->addressListLen,
GFP_USER);
if (newRule->addressList == NULL) {
LOG(2, (KERN_INFO "vnet filter mem alloc failed for rule address\n"));
DeleteRule(newRule);
return -ENOMEM;
}
/* could use memcpy(), but this insulates against API changes */
for (i = 0; i < newRule->addressListLen; ++i) {
newRule->addressList[i].ipv4Addr = addressList[i].ipv4RemoteAddr;
newRule->addressList[i].ipv4Mask = addressList[i].ipv4RemoteMask;
}
}
if (newRule->portListLen > 0) {
uint32 i;
newRule->portList =
kmalloc(sizeof(*newRule->portList) * newRule->portListLen, GFP_USER);
if (newRule->portList == NULL) {
LOG(2, (KERN_INFO "vnet filter mem alloc failed for rule port\n"));
DeleteRule(newRule);
return -ENOMEM;
}
/* could use memcpy(), but this insulates against API changes */
for (i = 0; i < newRule->portListLen; ++i) {
newRule->portList[i].localPortLow = portList[i].localPortLow;
newRule->portList[i].localPortHigh = portList[i].localPortHigh;
newRule->portList[i].remotePortLow = portList[i].remotePortLow;
newRule->portList[i].remotePortHigh = portList[i].remotePortHigh;
}
}
LOG(2, (KERN_INFO "adding rule with %u addresses and %u ports\n",
newRule->addressListLen, newRule->portListLen));
/* add rule to rule set */
newRule->next = NULL;
*(curr->tail) = newRule;
curr->tail = &(newRule->next);
++curr->numRules;
LOG(2, (KERN_INFO "Added rule %p to set %p, count now %u\n",
newRule, curr, curr->numRules));
return 0;
}
/*
*----------------------------------------------------------------------------
*
* VNetFilter_HandleUserCall --
*
* Handle the subcommands from the SIOCSFILTERRULES ioctl command.
* We end up copying the VNet_RuleHeader bytes twice from userland,
* once from the calling function, and once here after we've figured out
* what sub-command we are dealing with.
*
* Returns:
* 0 on success.
* errno on failure.
*
* Side effects:
* May add/remove filter callback.
*
*----------------------------------------------------------------------------
*/
int
VNetFilter_HandleUserCall(VNet_RuleHeader *ruleHeader, // IN: command header
unsigned long ioarg) // IN: ptr to user data
{
int retval = 0;
/* Serialize all ioctl()s. */
retval = mutex_lock_interruptible(&filterIoctlMutex);
if (retval != 0) {
return retval;
}
switch (ruleHeader->type) {
case VNET_FILTER_CMD_CREATE_RULE_SET: {
VNet_CreateRuleSet createRequest;
if (copy_from_user(&createRequest, (void *)ioarg, sizeof createRequest)) {
retval = -EFAULT;
goto out_unlock;
}
/* Validate size. */
if (createRequest.header.len != sizeof createRequest) {
LOG(2, (KERN_INFO "invalid length %d/%zd for create filter "
"request\n", createRequest.header.len,
sizeof createRequest));
retval = -EINVAL;
goto out_unlock;
}
if (createRequest.ruleSetId == 0) {
LOG(2, (KERN_INFO "invalid id %u for create filter request\n",
createRequest.ruleSetId));
retval = -EINVAL;
goto out_unlock;
}
if (createRequest.defaultAction != VNET_FILTER_RULE_BLOCK &&
createRequest.defaultAction != VNET_FILTER_RULE_ALLOW) {
LOG(2, (KERN_INFO "invalid action %u for create filter request\n",
createRequest.defaultAction));
retval = -EINVAL;
goto out_unlock;
}
retval = CreateRuleSet(createRequest.ruleSetId,
createRequest.defaultAction);
goto out_unlock;
}
case VNET_FILTER_CMD_DELETE_RULE_SET: {
VNet_DeleteRuleSet deleteRequest;
if (copy_from_user(&deleteRequest, (void *)ioarg, sizeof deleteRequest)) {
retval = -EFAULT;
goto out_unlock;
}
/* Validate size. */
if (deleteRequest.header.len != sizeof deleteRequest) {
LOG(2, (KERN_INFO "invalid length %d/%zd for delete filter "
"request\n", deleteRequest.header.len,
sizeof deleteRequest));
retval = -EINVAL;
goto out_unlock;
}
if (deleteRequest.ruleSetId == 0) {
LOG(2, (KERN_INFO "invalid id %u for delete filter request\n",
deleteRequest.ruleSetId));
retval = -EINVAL;
goto out_unlock;
}
retval = DeleteRuleSet(deleteRequest.ruleSetId);
goto out_unlock;
}
case VNET_FILTER_CMD_CHANGE_RULE_SET: {
VNet_ChangeRuleSet changeRequest;
if (copy_from_user(&changeRequest, (void *)ioarg, sizeof changeRequest)) {
retval = -EFAULT;
goto out_unlock;
}
/* Validate size. */
if (changeRequest.header.len != sizeof changeRequest) {
LOG(2, (KERN_INFO "invalid length %d/%zd for change filter "
"request\n", changeRequest.header.len,
sizeof changeRequest));
retval = -EINVAL;
goto out_unlock;
}
if (changeRequest.ruleSetId == 0) {
LOG(2, (KERN_INFO "invalid id %u for change filter request\n",
changeRequest.ruleSetId));
retval = -EINVAL;
goto out_unlock;
}
if (changeRequest.defaultAction != VNET_FILTER_RULE_NO_CHANGE &&
changeRequest.defaultAction != VNET_FILTER_RULE_BLOCK &&
changeRequest.defaultAction != VNET_FILTER_RULE_ALLOW) {
LOG(2, (KERN_INFO "invalid default action %u for change "
"filter request\n", changeRequest.defaultAction));
retval = -EINVAL;
goto out_unlock;
}
if (changeRequest.activate != VNET_FILTER_STATE_NO_CHANGE &&
changeRequest.activate != VNET_FILTER_STATE_ENABLE &&
changeRequest.activate != VNET_FILTER_STATE_DISABLE) {
LOG(2, (KERN_INFO "invalid activate %u for change filter "
"request\n", changeRequest.activate));
retval = -EINVAL;
goto out_unlock;
}
retval = ChangeRuleSet(changeRequest.ruleSetId,
changeRequest.activate == VNET_FILTER_STATE_ENABLE,
changeRequest.activate == VNET_FILTER_STATE_DISABLE,
changeRequest.defaultAction);
goto out_unlock;
}
case VNET_FILTER_CMD_ADD_IPV4_RULE: {
VNet_AddIPv4Rule *addRequest;
VNet_IPv4Address *addressList = NULL;
VNet_IPv4Port *portList = NULL;
int error = -EINVAL;
uint32 i;
/* Validate size. */
if (ruleHeader->len < sizeof *addRequest) {
LOG(2, (KERN_INFO "short length %d/%zd for add filter rule "
"request\n", ruleHeader->len,
sizeof *addRequest));
retval = -EINVAL;
goto out_unlock;
}
if (ruleHeader->len > (sizeof *addRequest +
(sizeof *addressList * MAX_ADDR_PER_RULE) +
(sizeof *portList * MAX_PORT_PER_RULE))) {
LOG(2, (KERN_INFO "long length %d for add filter rule "
"request\n", ruleHeader->len));
retval = -EINVAL;
goto out_unlock;
}
addRequest = kmalloc(ruleHeader->len, GFP_USER);
if (!addRequest) {
LOG(2, (KERN_INFO "couldn't allocate memory to add filter rule\n"));
retval = -ENOMEM;
goto out_unlock;
}
if (copy_from_user(addRequest, (void *)ioarg, ruleHeader->len)) {
error = -EFAULT;
goto out_error;
}
if (addRequest->addressListLen <= 0 ||
addRequest->addressListLen > MAX_ADDR_PER_RULE) {
LOG(2, (KERN_INFO "add filter rule: invalid addr list length: %u\n",
addRequest->addressListLen));
goto out_error;
}
if (addRequest->portListLen <= 0 ||
addRequest->portListLen > MAX_PORT_PER_RULE) {
LOG(2, (KERN_INFO "add filter rule: invalid port list length: %u\n",
addRequest->portListLen));
goto out_error;
}
if (addRequest->header.len !=
(sizeof *addRequest +
addRequest->addressListLen * sizeof(VNet_IPv4Address) +
addRequest->portListLen * sizeof(VNet_IPv4Port))) {
LOG(2, (KERN_INFO "add filter rule: invalid length: %u != %zu\n",
addRequest->header.len, sizeof *addRequest +
addRequest->addressListLen * sizeof(VNet_IPv4Address) +
addRequest->portListLen * sizeof(VNet_IPv4Port)));
goto out_error;
}
/*
* The address list comes after initial struct, and port
* list follows the address list.
*/
addressList = (VNet_IPv4Address *)(addRequest + 1);
portList = (VNet_IPv4Port *)(addressList + addRequest->addressListLen);
if (addRequest->ruleSetId == 0) {
LOG(2, (KERN_INFO "add filter rule: invalid request id %u\n",
addRequest->ruleSetId));
goto out_error;
}
if (addRequest->action != VNET_FILTER_RULE_BLOCK &&
addRequest->action != VNET_FILTER_RULE_ALLOW) {
LOG(2, (KERN_INFO "add filter rule: invalid action %u\n",
addRequest->action));
goto out_error;
}
if (addRequest->direction != VNET_FILTER_DIRECTION_IN &&
addRequest->direction != VNET_FILTER_DIRECTION_OUT &&
addRequest->direction != VNET_FILTER_DIRECTION_BOTH) {
LOG(2, (KERN_INFO "add filter rule: invalid direction %u\n",
addRequest->direction));
goto out_error;
}
/*
* Make sure addr is sane for given mask. Also verify that the address
* and mask, if both zero, are in the first element and the array only
* has one element. This also means that a 0 mask is not allowed in any
* element besides the first.
*/
for (i = 0; i < addRequest->addressListLen; i++) {
if (addressList[i].ipv4RemoteAddr !=
(addressList[i].ipv4RemoteAddr & addressList[i].ipv4RemoteMask)) {
LOG(2, (KERN_INFO "add filter rule got address 0x%08x mask "
"0x%08x for %u\n", addressList[i].ipv4RemoteAddr,
addressList[i].ipv4RemoteMask, i));
addressList[i].ipv4RemoteAddr &= addressList[i].ipv4RemoteMask;
LOG(2, (KERN_INFO "-- changed address to 0x%08x\n",
addressList[i].ipv4RemoteAddr));
}
/*
* If addr==mask==0, then it must be in the first element of the
* address list, and the address list should have only one element.
*/
if (addressList[i].ipv4RemoteAddr == 0 &&
addressList[i].ipv4RemoteMask == 0 &&
(i > 0 || addRequest->addressListLen > 1)) {
LOG(2, (KERN_INFO "add filter rule got violation for zero IP "
"addr/mask\n"));
goto out_error;
}
}
if (addRequest->proto > 0xFF && addRequest->proto != (uint16)~0) {
LOG(2, (KERN_INFO "add filter rule got invalid proto %u\n",
addRequest->proto));
goto out_error;
}
if (addRequest->proto == IPPROTO_TCP ||
addRequest->proto == IPPROTO_UDP) {
for (i = 0; i < addRequest->portListLen; i++) {
if (portList[i].localPortLow > 0xFFFF &&
portList[i].localPortLow != ~0) {
LOG(2, (KERN_INFO "add filter rule invalid localPortLow %u\n",
portList[i].localPortLow));
goto out_error;
}
if (portList[i].localPortHigh > 0xFFFF &&
portList[i].localPortHigh != ~0) {
LOG(2, (KERN_INFO "add filter rule invalid localPortHigh %u\n",
portList[i].localPortHigh));
goto out_error;
}
if (portList[i].remotePortLow > 0xFFFF &&
portList[i].remotePortLow != ~0) {
LOG(2, (KERN_INFO "add filter rule invalid remotePortLow %u\n",
portList[i].remotePortLow));
goto out_error;
}
if (portList[i].remotePortHigh > 0xFFFF &&
portList[i].remotePortHigh != ~0) {
LOG(2, (KERN_INFO "add filter rule invalid remotePortHigh %u\n",
portList[i].remotePortHigh));
goto out_error;
}
/*
* Make sure both low and high ports of a port range specify don't
* care ports.
*/
if ((portList[i].localPortLow == ~0 && portList[i].localPortHigh != ~0) ||
(portList[i].localPortLow != ~0 && portList[i].localPortHigh == ~0) ||
(portList[i].remotePortLow == ~0 && portList[i].remotePortHigh != ~0) ||
(portList[i].remotePortLow != ~0 && portList[i].remotePortHigh == ~0)) {
LOG(2, (KERN_INFO "add filter rule mismatch in don't care "
"status of ports\n"));
LOG(2, (KERN_INFO " -- srcLow %u srcHigh %u dstLow %u dstHigh %u\n",
portList[i].localPortLow, portList[i].localPortHigh,
portList[i].remotePortLow, portList[i].remotePortHigh));
goto out_error;
}
if (portList[i].localPortHigh < portList[i].localPortLow ||
portList[i].remotePortHigh < portList[i].remotePortLow) {
LOG(2, (KERN_INFO "add filter rule high < low on ports\n"));
LOG(2, (KERN_INFO " -- srcLow %u srcHigh %u dstLow %u dstHigh %u\n",
portList[i].localPortLow, portList[i].localPortHigh,
portList[i].remotePortLow, portList[i].remotePortHigh));
goto out_error;
}
/*
* Only allow a don't care on port ranges when it is the only port
* range specified.
*/
if (portList[i].localPortLow == ~0 && portList[i].localPortHigh == ~0 &&
portList[i].remotePortLow == ~0 && portList[i].remotePortHigh == ~0 &&
(i > 0 || addRequest->portListLen > 1)) {
LOG(2, (KERN_INFO "add filter rule incorrect don't "
"care on port list\n"));
goto out_error;
}
}
} else { // proto not TCP or UDP
if (addRequest->portListLen != 1 ||
(portList[0].localPortLow != 0 &&
portList[0].localPortLow != ~0) ||
(portList[0].localPortHigh != 0 &&
portList[0].localPortHigh != ~0) ||
(portList[0].remotePortLow != 0 &&
portList[0].remotePortLow != ~0) ||
(portList[0].remotePortHigh != 0 &&
portList[0].remotePortHigh != ~0)) {
LOG(2, (KERN_INFO "add filter rule missing/unnecessary port "
"information\n"));
for (i = 0; i < addRequest->portListLen; i++) {
LOG(2, (KERN_INFO " -- srcLow %u srcHigh %u dstLow %u dstHigh %u\n",
portList[i].localPortLow, portList[i].localPortHigh,
portList[i].remotePortLow, portList[i].remotePortHigh));
}
goto out_error;
}
}
retval = AddIPv4Rule(addRequest->ruleSetId, addRequest,
addressList, portList);
goto out_unlock;
out_error:
kfree(addRequest);
retval = error;
goto out_unlock;
}
case VNET_FILTER_CMD_ADD_IPV6_RULE:
LOG(2, (KERN_INFO "add filter rule IPv6 not supported\n"));
retval = -EPROTONOSUPPORT;
goto out_unlock;
case VNET_FILTER_CMD_SET_LOG_LEVEL: {
VNet_SetLogLevel setLogLevel;
if (copy_from_user(&setLogLevel, (void *)ioarg, sizeof setLogLevel)) {
retval = -EFAULT;
} else if (setLogLevel.header.len != sizeof setLogLevel) {
LOG(2, (KERN_INFO "set log level invalid header length %u\n",
setLogLevel.header.len));
retval = -EINVAL;
} else if (VNET_FILTER_LOGLEVEL_NONE > setLogLevel.logLevel ||
setLogLevel.logLevel > VNET_FILTER_LOGLEVEL_MAXIMUM) {
LOG(2, (KERN_INFO "set log level invalid value %u\n",
setLogLevel.logLevel));
retval = -EINVAL;
} else {
logLevel = setLogLevel.logLevel;
}
goto out_unlock;
}
default:
LOG(2, (KERN_INFO "add filter rule invalid command %u\n",
ruleHeader->type));
retval = -EINVAL;
goto out_unlock;
}
out_unlock:
mutex_unlock(&filterIoctlMutex);
return retval;
}
/*
*----------------------------------------------------------------------
*
* VNetFilter_Shutdown --
*
* Function is called when the driver is being unloaded.
* This function is responsible for removing the callback
* function from the IP stack and deallocating any remaining
* state.
*
* Results:
* None.
*
* Side effects:
*
*----------------------------------------------------------------------
*/
void
VNetFilter_Shutdown(void)
{
LOG(2, (KERN_INFO "shutting down vnet filter\n"));
RemoveHostFilterCallback();
if (activeRule != NULL) {
LOG(2, (KERN_INFO "disabling the active rule %u\n", activeRule->id));
ChangeRuleSet(activeRule->id, FALSE, TRUE, VNET_FILTER_RULE_NO_CHANGE);
// ASSERT(activeRule == NULL);
}
while (ruleSetHead != NULL) {
LOG(2, (KERN_INFO "Deleteing rule set %u\n", ruleSetHead->id));
DeleteRuleSet(ruleSetHead->id);
}
// ASSERT(numRuleSets == 0);
LOG(2, (KERN_INFO "shut down vnet filter\n"));
}
/*
*----------------------------------------------------------------------
*
* LogPacket --
*
* This function logs a dropped or forwarded packet.
*
* Results:
* None.
*
* Side effects:
* None.
*
*----------------------------------------------------------------------
*/
#define LOGPACKET_HEADER_LEN (20) /* presumed length of 'header': IP (20) */
#define LOGPACKET_DATA_LEN (28) /* TCP/UDP header (20) + 8 payload = 28 */
static void
LogPacket(uint16 action, // IN: reason for packet drop/forward
void *header, // IN: packet header
void *data, // IN: packet data
uint32 length, // IN: packet length (of 'data', not including 'header')
Bool drop) // IN: drop versus forward
{
char packet[(LOGPACKET_HEADER_LEN + LOGPACKET_DATA_LEN) * 3 + 1];
int i, n;
/* something to do? */
if (VNET_FILTER_LOGLEVEL_VERBOSE > logLevel) {
return;
}
/* cap packet length */
if (length > LOGPACKET_DATA_LEN) {
length = LOGPACKET_DATA_LEN;
}
/* build packet string */
n = 0;
if (header) {
for (i = 0; i < LOGPACKET_HEADER_LEN; i++) {
sprintf(&packet[n], "%02x ", ((uint8 *)header)[i]);
n += 3;
}
}
for (i = 0; i < length; i++) {
sprintf(&packet[n], "%02x ", ((uint8 *)data)[i]);
n += 3;
}
/* log packet */
printk(KERN_INFO "packet %s: %s\n", drop ? "dropped" : "forwarded", packet);
}
#endif // CONFIG_NETFILTER