Java版的Redis


Redis是一个基于Key-Value结构的NoSQL数据库,它支持各种常见的数据结构以及非常方便的操作,
与其说它是一个数据库,不如说它是一个保存各种数据结构的服务器。今天闲来无事,用Java集合类
实现了Redis的一些基本功能,顺便温习一下Java。

1.Redis入门

Redis的键(Key)是字符串,而值(Value)支持多种类型,如String字符串、List链表、Set无序集合、
SortedSet有序集合,甚至是Hash哈希表。

不同的数据结构通过不同的存取命令来区分。如Set/Get直接将值存为String,LPush/LPop/LRange将
值存到一个链表中,SAdd/ZAdd则分别对应无序集合和有序集合。

下面我们来看下在Java中使用基本的集合类如何实现这些简单而方便的操作。


2.Java版的Redis

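代码的组织结构如下(原文此处为一张结构图):RedisDB对外提供类似Redis的命令,Serializer和
BytesWrapper负责值的序列化与反序列化,Persistence负责存储;RedisServer和RedisClient则通过
Socket提供简单的远程访问。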

package com.cdai.studio.redis;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.TreeSet;

/**
 * In-memory store offering a subset of Redis commands on top of the Java collections.
 *
 * <p>Values are kept in serialized form: every command deserializes the stored value
 * through {@code serializer}, works on that private copy and, for writes, serializes the
 * result back into {@code persistence}. A {@code null} value is treated as a missing key.
 * Applying a command to a key that holds a different kind of value (for example LPush on
 * a key written with Set) fails with a {@code ClassCastException}.
 *
 * <p>No synchronization is performed, so concurrent read-modify-write commands on the
 * same key may lose updates.
 */
public class RedisDB {

	private final Persistence persistence = new Persistence();

	private final Serializer serializer = new Serializer();

	/** Shared "no elements" result; a zero-length array cannot be modified by callers. */
	private static final Object[] EMPTY = new Object[0];


	// =================================================
	//					String value
	// =================================================

	/** Stores value under key, replacing whatever was there. */
	public void Set(String key, Object value) {
		store(key, value);
	}

	/** Returns a deserialized copy of the value at key, or null if the key is absent. */
	public Object Get(String key) {
		return serializer.unmarshal(persistence.get(key));
	}

	/** Returns Get(key) for every key, in argument order (null for absent keys). */
	public Object[] MGet(String... keys) {
		Object[] values = new Object[keys.length];
		for (int i = 0; i < keys.length; i++) {
			values[i] = Get(keys[i]);
		}
		return values;
	}

	/**
	 * Increments the integer stored at key and returns the new value, as Redis INCR does.
	 * A missing key counts as 0, so the first call returns 1 and stores 1.
	 *
	 * @throws ClassCastException if the key holds something other than an Integer
	 */
	public int Incr(String key) {
		Object value = Get(key);
		// The stored value and the returned value must agree; storing value + 1 while
		// returning value would leave Get(key) one ahead of the last id handed out.
		int next = (value == null) ? 1 : ((Integer) value) + 1;
		store(key, next);
		return next;
	}


	// =================================================
	//					List value
	// =================================================

	/** Prepends each value in turn, so the last argument ends up at the head (like LPUSH). */
	public void LPush(String key, Object... values) {
		LinkedList<Object> list = loadList(key);
		for (Object value : values) {
			list.addFirst(value);
		}
		store(key, list);
	}

	/** Appends each value in turn to the tail of the list (like RPUSH). */
	public void RPush(String key, Object... values) {
		LinkedList<Object> list = loadList(key);
		for (Object value : values) {
			list.addLast(value);
		}
		store(key, list);
	}

	/**
	 * Returns the list elements from index start to index end, both inclusive, as in Redis
	 * LRANGE. Negative indices count from the tail (-1 is the last element). Out-of-range
	 * indices are clamped. An empty array is returned when the range selects nothing or the
	 * key does not exist.
	 */
	public Object[] LRange(String key, int start, int end) {
		LinkedList<Object> list = loadList(key);
		int size = list.size();
		int from = (start < 0) ? Math.max(size + start, 0) : start;
		int to = (end < 0) ? size + end : Math.min(end, size - 1);
		if (from > to) {
			return EMPTY;
		}
		return list.subList(from, to + 1).toArray();
	}


	// =================================================
	//					Unsorted Set value
	// =================================================

	/** Adds every value to the unordered set at key; duplicates are ignored. */
	public void SAdd(String key, Object... values) {
		HashSet<Object> set = loadSet(key);
		for (Object value : values) {
			set.add(value);
		}
		store(key, set);
	}

	/** Returns all members of the set at key (empty if absent), in unspecified order. */
	public Object[] SMembers(String key) {
		return loadSet(key).toArray();
	}

	/** Returns the members present in both sets; a missing key counts as an empty set. */
	public Object[] SInter(String key1, String key2) {
		HashSet<Object> result = loadSet(key1);
		result.retainAll(loadSet(key2));
		return result.toArray();
	}

	/**
	 * Returns the members of key1 that are not in key2. A missing key counts as an empty
	 * set, so a missing key2 yields every member of key1 (as Redis SDIFF does).
	 */
	public Object[] SDiff(String key1, String key2) {
		HashSet<Object> result = loadSet(key1);
		result.removeAll(loadSet(key2));
		return result.toArray();
	}


	// =================================================
	//					Sorted Set value
	// =================================================

	/** Adds every value to the naturally-ordered set at key; values must be mutually Comparable. */
	public void ZAdd(String key, Object... values) {
		TreeSet<Object> set = loadSortedSet(key);
		for (Object value : values) {
			set.add(value);
		}
		store(key, set);
	}

	/** Returns the members of the sorted set that are greater than or equal to from, in ascending order. */
	public Object[] SRange(String key, Object from) {
		// Checked up front so a missing key never compares `from` (which may be null).
		if (persistence.get(key) == null) {
			return EMPTY;
		}
		return loadSortedSet(key).tailSet(from).toArray();
	}


	// =================================================
	//					Helpers
	// =================================================

	/** Serializes value and stores it under key. */
	private void store(String key, Object value) {
		persistence.put(key, serializer.marshal(value));
	}

	/** Returns a deserialized copy of the list at key, or a new empty list if absent. */
	@SuppressWarnings("unchecked")
	private LinkedList<Object> loadList(String key) {
		Object stored = persistence.get(key);
		if (stored == null) {
			return new LinkedList<Object>();
		}
		return (LinkedList<Object>) serializer.unmarshal(stored);
	}

	/** Returns a deserialized copy of the set at key, or a new empty set if absent. */
	@SuppressWarnings("unchecked")
	private HashSet<Object> loadSet(String key) {
		Object stored = persistence.get(key);
		if (stored == null) {
			return new HashSet<Object>();
		}
		return (HashSet<Object>) serializer.unmarshal(stored);
	}

	/** Returns a deserialized copy of the sorted set at key, or a new empty one if absent. */
	@SuppressWarnings("unchecked")
	private TreeSet<Object> loadSortedSet(String key) {
		Object stored = persistence.get(key);
		if (stored == null) {
			return new TreeSet<Object>();
		}
		return (TreeSet<Object>) serializer.unmarshal(stored);
	}

}
package com.cdai.studio.redis;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

/**
 * Converts values to and from the wrapped, serialized form that RedisDB keeps in its
 * Persistence map. {@code null} is passed through unchanged in both directions.
 */
class Serializer {

	/** Wraps the value in a BytesWrapper; non-null values must implement Serializable. */
	Object marshal(Object object) {
		return (object != null) ? new BytesWrapper((Serializable) object) : null;
	}

	/** Inverse of marshal: expects a BytesWrapper (or null) and returns a fresh copy of its value. */
	Object unmarshal(Object object) {
		if (object != null) {
			BytesWrapper wrapper = (BytesWrapper) object;
			return wrapper.readObject();
		}
		return null;
	}

}


class BytesWrapper {
	
	private byte[] bytes;
	
	<T extends Serializable> BytesWrapper(T object) {
		writeBytes(object);
	}
	
	<T extends Serializable> void writeBytes(T object) {
		try {
			ByteArrayOutputStream buffer = new ByteArrayOutputStream();
			ObjectOutputStream output = new ObjectOutputStream(buffer);
			output.writeObject(object);
			output.flush();
			bytes = buffer.toByteArray();
			output.close();
		}
		catch (IOException e) {
			e.printStackTrace();
			throw new IllegalStateException(e);
		}
	}
	
	Object readObject() {
		try {
			ObjectInputStream input = new ObjectInputStream(new ByteArrayInputStream(bytes));
			Object object = input.readObject();
			input.close();
			return object;
		}
		catch (Exception e) {
			e.printStackTrace();
			throw new IllegalStateException(e);
		}
	}

	@Override
	public int hashCode() {
		final int prime = 31;
		int result = 1;
		result = prime * result + Arrays.hashCode(bytes);
		return result;
	}

	@Override
	public boolean equals(Object obj) {
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		BytesWrapper other = (BytesWrapper) obj;
		if (!Arrays.equals(bytes, other.bytes))
			return false;
		return true;
	}
	
}
package com.cdai.studio.redis;

import java.util.HashMap;

/**
 * Backing store used by RedisDB: a plain in-memory HashMap from key to stored value.
 * It is not synchronized. A key mapped to null is indistinguishable from an absent key on get.
 */
class Persistence {

	private final HashMap<String, Object> storage = new HashMap<String, Object>();

	/** Associates value with key, replacing any previous mapping. */
	void put(String key, Object value) {
		storage.put(key, value);
	}

	/** Returns the value stored under key, or null when there is none. */
	Object get(String key) {
		Object found = storage.get(key);
		return found;
	}

}


3.简单的服务端与客户端
package com.cdai.studio.redis;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.List;

/**
 * Minimal TCP front end for a RedisDB.
 *
 * <p>Each connection carries exactly one request. The request is a Java-serialized List
 * whose first element is the command name ("Set" or "Get"), followed by its arguments.
 * The server replies with one serialized object and closes the connection. The reply is
 * the Get result, or null for Set and for unrecognized commands. Connections are handled
 * one at a time on the thread that calls start().
 *
 * <p>NOTE(review): requests are decoded with Java native deserialization, so this server
 * must not be exposed to untrusted clients.
 */
public class RedisServer {

	/** Port used by the single-argument constructor. */
	public static final int DEFAULT_PORT = 1234;

	private final RedisDB redis;

	private final int port;

	public RedisServer(RedisDB redis) {
		this(redis, DEFAULT_PORT);
	}

	public RedisServer(RedisDB redis, int port) {
		this.redis = redis;
		this.port = port;
	}

	/**
	 * Binds the server socket and serves requests forever. A failure while handling one
	 * connection is reported and that connection is closed, but the server keeps accepting.
	 * The method returns only if the server socket itself fails.
	 */
	public void start() {
		ServerSocket serverSocket = null;
		try {
			serverSocket = new ServerSocket(port);
			while (true) {
				Socket socket = serverSocket.accept();
				try {
					handle(socket);
				}
				catch (Exception e) {
					// A malformed request or a dropped client must not take the whole server down.
					e.printStackTrace();
				}
				finally {
					closeQuietly(socket);
				}
			}
		}
		catch (Exception e) {
			e.printStackTrace();
		}
		finally {
			if (serverSocket != null) {
				try {
					serverSocket.close();
				} catch (IOException ignored) {
					// Shutting down anyway; nothing more to do.
				}
			}
		}
	}

	/** Reads one request from the socket, executes it and writes the reply. */
	@SuppressWarnings("unchecked")
	private void handle(Socket socket) throws IOException, ClassNotFoundException {
		ObjectInputStream input = new ObjectInputStream(socket.getInputStream());
		List<Object> request = (List<Object>) input.readObject();

		Object response = dispatch(request);

		ObjectOutputStream output = new ObjectOutputStream(socket.getOutputStream());
		output.writeObject(response);
		// Flush explicitly: the socket (not this stream) is closed by the caller.
		output.flush();
	}

	/** Executes a decoded request against the database and returns the reply object. */
	private Object dispatch(List<Object> request) {
		Object command = request.get(0);
		if ("Set".equals(command)) {
			redis.Set((String) request.get(1), request.get(2));
			return null;
		}
		if ("Get".equals(command)) {
			return redis.Get((String) request.get(1));
		}
		// Unrecognized commands get a null reply.
		return null;
	}

	/** Closes a client socket, ignoring failures since the connection is finished either way. */
	private static void closeQuietly(Socket socket) {
		try {
			socket.close();
		} catch (IOException ignored) {
			// Nothing useful to do if closing a finished connection fails.
		}
	}

}
package com.cdai.studio.redis;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.Socket;
import java.util.Arrays;
import java.util.List;

/**
 * Client for RedisServer. Each command opens a new TCP connection, sends one serialized
 * request List (command name followed by arguments) and reads back one serialized reply.
 */
public class RedisClient {

	private final String host;

	private final int port;

	/** Connects to localhost:1234. */
	public RedisClient() {
		this("localhost", 1234);
	}

	public RedisClient(String host, int port) {
		this.host = host;
		this.port = port;
	}

	/**
	 * Stores value under key on the server. The value must be Serializable.
	 * The type parameter is unused and is kept only for source compatibility.
	 *
	 * @throws IllegalStateException if the request cannot be sent or answered
	 */
	public <T extends Serializable> void Set(String key, Object value) {
		sendRequest(Arrays.asList("Set", key, value));
	}

	/**
	 * Returns the value stored under key on the server, or null if the key is absent.
	 *
	 * @throws IllegalStateException if the request cannot be sent or answered, so a
	 *         failure is never mistaken for a missing key
	 */
	public Object Get(String key) {
		return sendRequest(Arrays.<Object>asList("Get", key));
	}

	/** Sends one request over a fresh connection and returns the deserialized reply. */
	private Object sendRequest(List<Object> payload) {
		Socket socket = null;
		try {
			socket = new Socket(host, port);

			// Write and flush the request (including the stream header) before creating the
			// input stream: ObjectInputStream's constructor blocks until it reads the peer's header.
			ObjectOutputStream output = new ObjectOutputStream(socket.getOutputStream());
			output.writeObject(payload);
			output.flush();

			ObjectInputStream input = new ObjectInputStream(socket.getInputStream());
			return input.readObject();
		}
		catch (java.io.IOException e) {
			throw new IllegalStateException(
					"Request " + payload.get(0) + " to " + host + ":" + port + " failed", e);
		}
		catch (ClassNotFoundException e) {
			throw new IllegalStateException(
					"Cannot decode reply to " + payload.get(0) + " from " + host + ":" + port, e);
		}
		finally {
			if (socket != null) {
				try {
					// Closing the socket also closes both object streams.
					socket.close();
				} catch (java.io.IOException ignored) {
					// The exchange is complete (or already failed); nothing more to do.
				}
			}
		}
	}

}


4.实现简单的Twitter
package com.cdai.studio.redis;

import java.util.Arrays;

/**
 * Demo: a tiny Twitter-like timeline built on RedisDB. Users follow each other, B and C
 * each post a tweet, and the text of the tweets on the global timeline is printed.
 */
public class RedisTest {

	public static void main(String[] args) {

		RedisDB redis = new RedisDB();

		// 1. Register users and follow relationships.
		redis.SAdd("users", "A", "B", "C");

		// A follows B and C.
		redis.SAdd("users:A:following", "B", "C");
		redis.SAdd("users:B:followers", "A");
		redis.SAdd("users:C:followers", "A");

		// C follows B.
		redis.SAdd("users:C:following", "B");
		redis.SAdd("users:B:followers", "C");

		// 2. B, then C, publish a tweet.
		postTweet(redis, "B", "B publish hello");
		postTweet(redis, "C", "C publish world");

		// 3. Look up the tweet texts for the ids at the head (newest end) of the global timeline.
		Object[] ids = redis.LRange("global:timeline", 0, 9);
		String[] keys = new String[ids.length];
		int idx = 0;
		for (Object id : ids) {
			keys[idx++] = "tweets:" + id;
		}
		System.out.println(Arrays.toString(redis.MGet(keys)));
	}

	/**
	 * Publishes a tweet. It allocates an id from "tweets:next_id" and stores the text under
	 * "tweets:<id>". The id is then pushed onto the global timeline, the author's own timeline
	 * and the timeline of every follower of the author.
	 */
	private static void postTweet(RedisDB redis, String author, String text) {
		int tid = redis.Incr("tweets:next_id");
		redis.Set("tweets:" + tid, text);
		redis.LPush("global:timeline", tid);
		redis.LPush("users:" + author + ":timeline", tid);
		for (Object follower : redis.SMembers("users:" + author + ":followers")) {
			redis.LPush("users:" + follower + ":timeline", tid);
		}
	}

}


5.需要注意的问题

byte数组的equals和hashCode默认基于对象引用(地址)比较,而不是比较数组内容,因此要借助Arrays.equals和Arrays.hashCode方法。

如果自行在String和byte数组之间转换(如getBytes/new String),编码和解码时必须使用相同的字符集(如UTF-8);本文使用Java对象序列化,由序列化机制处理字符串编码,因此无需关心这一问题。

HashSet上的集合运算:removeAll求差集,retainAll求交集,addAll求并集(注意它们都会直接修改调用方的集合)。


6.更加强大的Redis

Redis自己实现了各种数据结构,可以非常方便地增删改查,并且效率很高。这里我们只是用
Java简单地学习了一下Redis的基本功能,其实Redis还支持很多其他的高级功能,如消息订阅、
数据过期设置、事务、数据持久化。想要进一步学习的话可以试着用Java实现它们。


你可能感兴趣的:(redis)