学习笔记1:c++实现LRU算法

LeetCode OJ上有一道题目,要编程实现cache的LRU算法,刚看到这道题的时候,我想到了用队列来做,但是若用单链表来做,必须要保存尾节点的上一个节点指针,才能实现快速增加一条数据,编程起来很不方便,所以我采用了双端队列实现,为了处理方便,保存两个带节点指针,一个节点的next指向头结点,另一个节点next指向尾节点。

LRU算法中还要求将已存在的节点放入队尾或者插入节点时要判断此节点是不是存在,所以我使用hash表来快速的定位到此节点,所以就有节点定义

struct node{
	int key;
	int val;
	node *pre;
	node *next;
	node(int k,int v):key(k),val(v),pre(NULL),next(NULL){}
};

hash定义:

#define MAX 10000//hash表最大长度
#define MOD 9991//大质数

node *Hash[MAX][20];
int len[MAX];//对应每个键值的长度

hash函数为:

int _hash(int key){
	return key%MOD;
}

双端队列(DQueue)要提供插入一个节点(insert)、删除头结点(delHead)、获取队列大小(size)、根据key值获取节点的value值功能(find);

其中插入节点时需要判断该节点是不是存在,若存在,需要将其放到队尾,并给其赋新值;根据key值获取节点value值时也要判断节点是否存在,若存在也需将其放入队尾,并返回该节点value值。所以我增加了一个search函数,

函数声明为:

node * search(int key);

该函数功能为,在hash表中查找键值为key的节点,若找不到返回NULL,找到时将该节点拿出来插入到队尾,并返回该节点,具体实现为:

node * search(int key){
	int pos=_hash(key);//找到hash表位置
	for(int i=0;i<len[pos];i++){
		if(key==Hash[pos][i]->key){
			node *t=Hash[pos][i];
			if(t->pre==NULL&&t->next==NULL)return t;//如果该队列中只有一个元素时返回该节点
			if(t->pre==NULL){//若为头结点,还需要对head指针进行操作
				head->next=t->next;
				t->next->pre=NULL;
				t->pre=tail->next;
				tail->next->next=t;
				tail->next=t;
				return t;
			}
			if(t->next==NULL){//若为尾节点
				return t;
			}
			t->pre->next=t->next;//既不是头结点也不是尾节点,只需要对tail指针进行操作
			t->next->pre=t->pre;
			t->pre=tail->next;
			tail->next->next=t;
			t->next=NULL;
			tail->next=t;
			return t;
		}
	}
	return NULL;
}
所以整个DQueue双端队列类实现如下:

class DQueue{
	node *head;
	node *tail;
	int Size;
	node *Hash[MAX][20];
	int len[MAX];
	int _hash(int key){
		return key%MOD;
	}
	node * search(int key){
		int pos=_hash(key);
		for(int i=0;i<len[pos];i++){
			if(key==Hash[pos][i]->key){
				node *t=Hash[pos][i];
				if(t->pre==NULL&&t->next==NULL)return t;
				if(t->pre==NULL){
					head->next=t->next;
					t->next->pre=NULL;
					t->pre=tail->next;
					tail->next->next=t;
					tail->next=t;
					return t;
				}
				if(t->next==NULL){
					return t;
				}
				t->pre->next=t->next;
				t->next->pre=t->pre;

				t->pre=tail->next;
				tail->next->next=t;
				t->next=NULL;
				tail->next=t;
				return t;
			}
		}
		return NULL;
	}
public:
	DQueue(){
		head=new node(0,0);
		tail=new node(0,0);
		memset(len,0,sizeof(len));
		Size=0;
	}
	int insert(int key,int val){
		
		node *i=search(key);
		if(i!=NULL){
			i->val=val;
			return 1;
		}
		
		node *t=new node(key,val);
		if(!Size){
			head->next=t;
		}
		t->pre=tail->next;
		if(tail->next!=NULL)tail->next->next=t;
		int pos=_hash(t->key);
		Hash[pos][len[pos]++]=t;
		tail->next=t;
		Size++;
		return 1;
	}
	int delHead(){
		node *p=head->next;
		if(p!=NULL){
			head->next=p->next;
			if(p->next!=NULL)p->next->pre=NULL;
			delete[] p;
			Size--;
			return 1;
		}
		return 0;
	}
	int find(int key){
		node *f=search(key);
		if(f!=NULL)return f->val;
		return -1;
	}
	int size(){
		return Size;
	}
};
LRU类实现就比较简单了:

class LRUCache{
	DQueue *dq;
	int capa;
public:
	LRUCache(int capacity){
		capa=capacity;
		dq=new DQueue();
	}
	int get(int key){
		return dq->find(key);
	}
	void set(int key,int value){
		dq->insert(key,value);//先插入一个节点
		if(dq->size()>capa){//若队列满则删掉头结点
			dq->delHead();
		}
	}
};

整个代码实现加测试程序如下:

#include <iostream>
#include <windows.h>
#include <random>
#include <stdlib.h>
#include <time.h>
#include <algorithm>
#include <vector>
#include <string>
#include <map>
#include <cmath>
using namespace std;

#define MAX 10000
#define MOD 9991
struct node{
	int key;
	int val;
	node *pre;
	node *next;
	node(int k,int v):key(k),val(v),pre(NULL),next(NULL){}
};


class DQueue{
	node *head;
	node *tail;
	int Size;
	node *Hash[MAX][20];
	int len[MAX];
	int _hash(int key){
		return key%MOD;
	}
	node * search(int key){
		int pos=_hash(key);
		for(int i=0;i<len[pos];i++){
			if(key==Hash[pos][i]->key){
				node *t=Hash[pos][i];
				if(t->pre==NULL&&t->next==NULL)return t;
				if(t->pre==NULL){
					head->next=t->next;
					t->next->pre=NULL;
					t->pre=tail->next;
					tail->next->next=t;
					tail->next=t;
					return t;
				}
				if(t->next==NULL){
					return t;
				}
				t->pre->next=t->next;
				t->next->pre=t->pre;

				t->pre=tail->next;
				tail->next->next=t;
				t->next=NULL;
				tail->next=t;
				return t;
			}
		}
		return NULL;
	}
public:
	DQueue(){
		head=new node(0,0);
		tail=new node(0,0);
		memset(len,0,sizeof(len));
		Size=0;
	}
	int insert(int key,int val){
		
		node *i=search(key);
		if(i!=NULL){
			i->val=val;
			return 1;
		}
		
		node *t=new node(key,val);
		if(!Size){
			head->next=t;
		}
		t->pre=tail->next;
		if(tail->next!=NULL)tail->next->next=t;
		int pos=_hash(t->key);
		Hash[pos][len[pos]++]=t;
		tail->next=t;
		Size++;
		return 1;
	}
	int delHead(){
		node *p=head->next;
		if(p!=NULL){
			head->next=p->next;
			if(p->next!=NULL)p->next->pre=NULL;
			delete[] p;
			Size--;
			return 1;
		}
		return 0;
	}
	int find(int key){
		node *f=search(key);
		if(f!=NULL)return f->val;
		return -1;
	}
	int size(){
		return Size;
	}
};

class LRUCache{
	DQueue *dq;
	int capa;
public:
	LRUCache(int capacity){
		capa=capacity;
		dq=new DQueue();
	}
	int get(int key){
		return dq->find(key);
	}
	void set(int key,int value){
		dq->insert(key,value);
		if(dq->size()>capa){
			dq->delHead();
		}
	}
};
void Test()//测试程序
{
	//freopen("C:\\in.txt","r",stdin);
	LRUCache p(10);
	p.set(10,13),p.set(3,17),p.set(6,11),p.set(10,5),p.set(9,10);
	cout<<p.get(13);
	p.set(2,19);
	cout<<p.get(2);
	cout<<p.get(3);
	p.set(5,25);
	cout<<p.get(8);
	p.set(9,22),p.set(5,5),p.set(1,30);
	cout<<p.get(11);
	p.set(9,12);
	cout<<p.get(7);
	cout<<p.get(5);
	cout<<p.get(8);
	cout<<p.get(9);
	p.set(4,30),p.set(9,3);
	cout<<p.get(9);
	cout<<p.get(10);
	cout<<p.get(10);
	p.set(6,14);p.set(3,1);
	cout<<p.get(3);
	p.set(10,11);
	cout<<p.get(8);
	p.set(2,14);
	cout<<p.get(1);
	cout<<p.get(5);
	cout<<p.get(4);
	p.set(11,4),p.set(12,24),p.set(5,18);
	cout<<p.get(13);
	p.set(7,23);
	cout<<p.get(8);
	cout<<p.get(12);
	p.set(3,27),p.set(2,12);
	cout<<p.get(5);
	p.set(2,9),p.set(13,4);p.set(8,18),p.set(1,7);
	cout<<p.get(6);
	p.set(9,29),p.set(8,21);
	cout<<p.get(5);
	p.set(6,30),p.set(1,12);
	cout<<p.get(10);
	p.set(4,15),p.set(7,22),p.set(11,26),p.set(8,17),p.set(9,29);
	cout<<p.get(5);
	p.set(3,4),p.set(11,30);
	cout<<p.get(12);
	p.set(4,29);
	cout<<p.get(3);
	cout<<p.get(9);
	cout<<p.get(6);
	p.set(3,4);
	cout<<p.get(1);
	cout<<p.get(10);
	p.set(3,29),p.set(10,28),p.set(1,20),p.set(11,13);
	cout<<p.get(3);
	p.set(3,12),p.set(3,8),p.set(10,9),p.set(3,26);
	cout<<p.get(8);
	cout<<p.get(7);
	cout<<p.get(5);
	p.set(13,17),p.set(2,27),p.set(11,15);
	cout<<p.get(12);
	p.set(9,19),p.set(2,15),p.set(3,16);
	cout<<p.get(1);
	p.set(12,17),p.set(9,1),p.set(6,19);
	cout<<p.get(4);
	cout<<p.get(5);
	cout<<p.get(5);
	p.set(8,1),p.set(11,7),p.set(5,2),p.set(9,28);
	cout<<p.get(1);
	p.set(2,2),p.set(7,4),p.set(4,22),p.set(7,24),p.set(9,26),p.set(13,28),p.set(11,26);
	
	printf("\n");
}
     
int main(void)    
{    
    LARGE_INTEGER BegainTime ;    
    LARGE_INTEGER EndTime ;    
    LARGE_INTEGER Frequency ;    
    QueryPerformanceFrequency(&Frequency);    
    QueryPerformanceCounter(&BegainTime) ;    
    
    //要测试的代码放在这里   
    Test();   
     
    QueryPerformanceCounter(&EndTime);   
    
    //输出运行时间(单位:s)   
    cout << "运行时间(单位:s):" <<(double)( EndTime.QuadPart - BegainTime.QuadPart )/ Frequency.QuadPart <<endl;    
    
    //system("pause") ;    
    return 0 ;    
}

然后我发现,如果用双向循环队列来实现,就不需要对head和tail指针进行维护了,head的pre指针指向队尾,也就是最新添加的元素,head的next指针指向头结点,也就是最早添加的元素,因为使用循环链表,我也不用在队空的时候对head指针进行单独讨论了,因为head的前一个结点和后一个结点都是head本身,所以结点定义稍微修改下,有:

struct node{
	int key;
	int val;
	node *pre;
	node *next;
	node(int k,int v):key(k),val(v),pre(this),next(this){}
};

整个实现起来就轻松了许多,略作修改后的全部代码贴上:

#include <iostream>
#include <windows.h>
#include <random>
#include <stdlib.h>
#include <time.h>
#include <algorithm>
#include <vector>
#include <string>
#include <map>
#include <cmath>
using namespace std;

#define MAX 10000
#define MOD 9991
struct node{
	int key;
	int val;
	node *pre;
	node *next;
	node(int k,int v):key(k),val(v),pre(this),next(this){}
};


class DQueue{
	node *head;
	int Size;
	node *Hash[MAX][20];
	int len[MAX];
	int _hash(int key){
		return key%MOD;
	}
	node * search(int key){
		int pos=_hash(key);
		for(int i=0;i<len[pos];i++){
			if(key==Hash[pos][i]->key){
				node *t=Hash[pos][i];
				t->pre->next=t->next;
				t->next->pre=t->pre;
				
				t->pre=head->pre;
				head->pre->next=t;

				t->next=head;
				head->pre=t;
				return t;
			}
		}
		return NULL;
	}
public:
	DQueue(){
		head=new node(0,0);
		memset(len,0,sizeof(len));
		Size=0;
	}
	int insert(int key,int val){
		
		node *i=search(key);
		if(i!=NULL){
			i->val=val;
			return 1;
		}
		
		node *t=new node(key,val);
		
		t->next=head;
		t->pre=head->pre;
		head->pre->next=t;
		head->pre=t;

		int pos=_hash(key);
		Hash[pos][len[pos]++]=t;
		Size++;
		return 1;
	}
	int delHead(){
		node *p=head->next;
		head->next=p->next;
		p->next->pre=head;
		delete[] p;
		Size--;
		return 1;
	}
	int find(int key){
		node *f=search(key);
		if(f!=NULL)return f->val;
		return -1;
	}
	int size(){
		return Size;
	}
};

class LRUCache{
	DQueue *dq;
	int capa;
public:
	LRUCache(int capacity){
		capa=capacity;
		dq=new DQueue();
	}
	int get(int key){
		return dq->find(key);
	}
	void set(int key,int value){
		dq->insert(key,value);
		if(dq->size()>capa){
			dq->delHead();
		}
	}
};
void Test()//测试程序
{
	//freopen("C:\\in.txt","r",stdin);
	LRUCache p(10);
	p.set(10,13),p.set(3,17),p.set(6,11),p.set(10,5),p.set(9,10);
	cout<<p.get(13);
	p.set(2,19);
	cout<<p.get(2);
	cout<<p.get(3);
	p.set(5,25);
	cout<<p.get(8);
	p.set(9,22),p.set(5,5),p.set(1,30);
	cout<<p.get(11);
	p.set(9,12);
	cout<<p.get(7);
	cout<<p.get(5);
	cout<<p.get(8);
	cout<<p.get(9);
	p.set(4,30),p.set(9,3);
	cout<<p.get(9);
	cout<<p.get(10);
	cout<<p.get(10);
	p.set(6,14);p.set(3,1);
	cout<<p.get(3);
	p.set(10,11);
	cout<<p.get(8);
	p.set(2,14);
	cout<<p.get(1);
	cout<<p.get(5);
	cout<<p.get(4);
	p.set(11,4),p.set(12,24),p.set(5,18);
	cout<<p.get(13);
	p.set(7,23);
	cout<<p.get(8);
	cout<<p.get(12);
	p.set(3,27),p.set(2,12);
	cout<<p.get(5);
	p.set(2,9),p.set(13,4);p.set(8,18),p.set(1,7);
	cout<<p.get(6);
	p.set(9,29),p.set(8,21);
	cout<<p.get(5);
	p.set(6,30),p.set(1,12);
	cout<<p.get(10);
	p.set(4,15),p.set(7,22),p.set(11,26),p.set(8,17),p.set(9,29);
	cout<<p.get(5);
	p.set(3,4),p.set(11,30);
	cout<<p.get(12);
	p.set(4,29);
	cout<<p.get(3);
	cout<<p.get(9);
	cout<<p.get(6);
	p.set(3,4);
	cout<<p.get(1);
	cout<<p.get(10);
	p.set(3,29),p.set(10,28),p.set(1,20),p.set(11,13);
	cout<<p.get(3);
	p.set(3,12),p.set(3,8),p.set(10,9),p.set(3,26);
	cout<<p.get(8);
	cout<<p.get(7);
	cout<<p.get(5);
	p.set(13,17),p.set(2,27),p.set(11,15);
	cout<<p.get(12);
	p.set(9,19),p.set(2,15),p.set(3,16);
	cout<<p.get(1);
	p.set(12,17),p.set(9,1),p.set(6,19);
	cout<<p.get(4);
	cout<<p.get(5);
	cout<<p.get(5);
	p.set(8,1),p.set(11,7),p.set(5,2),p.set(9,28);
	cout<<p.get(1);
	p.set(2,2),p.set(7,4),p.set(4,22),p.set(7,24),p.set(9,26),p.set(13,28),p.set(11,26);
	
	printf("\n");
}
     
int main(void)    
{    
    LARGE_INTEGER BegainTime ;    
    LARGE_INTEGER EndTime ;    
    LARGE_INTEGER Frequency ;    
    QueryPerformanceFrequency(&Frequency);    
    QueryPerformanceCounter(&BegainTime) ;    
    
    //要测试的代码放在这里   
    Test();   
     
    QueryPerformanceCounter(&EndTime);   
    
    //输出运行时间(单位:s)   
    cout << "运行时间(单位:s):" <<(double)( EndTime.QuadPart - BegainTime.QuadPart )/ Frequency.QuadPart <<endl;    
    
    //system("pause") ;    
    return 0 ;    
}

第一篇博客就这么写完了,写的不好还请见谅

你可能感兴趣的:(LRU,hash,循环队列,双向队列)