Redis实现分布式读写锁(Java基于Lua实现)

Redis实现分布式读写锁

前言

使用Jedis构建redis连接池,使用lua脚本命令保证redis的事务,以实现分布式的读写锁。项目中需要用到分布式的读写锁,开始使用Redisson的读写锁实现,压测的时候时不时会抛异常获取锁超时,初步判断是Redisson中redis连接池设置的太小。由于项目中还自己另外维护着一个redis的连接池JedisPool,故决定自己来实现分布式的可重入读写锁。

设计思路

目标:读锁可以被多个线程获取,同一个线程可以重入读锁,如果获取读锁的时候存在写锁,则需要等待写锁被释放才能获取;写锁的获取的时候,若存在读锁,则需要等待所有的读锁释放之后才能被获取,如果有线程正在获取写锁,其他获取读锁的线程将等待写锁被获取并释放之后才能获取读锁,写锁只能被一个线程持有,可以重入。

方案
假设我们要操作“abc”这个读写锁

  1. 一个读锁在redis中的存在形式为一个hash结构,key为read_lock_abc,值为键值对组成的hash,键值对的键为thread_id(这个id由连接池id+获取锁的线程id,来区分分布式中的不同线程),值为该线程重入读锁的次数。
  2. 一个写锁分为两部分,一个是key为write_lock_abc,值为thread_id(同上);另外一个key为reentrant_write_lock_abc,值为被重入的次数。

实现

RedisReadWriteLock.java 该类用于维护读写锁的单例

public class RedisReadWriteLock {
    //读锁
    private static volatile RedisReadLock redisReadLock;
    //写锁
    private static volatile RedisWriteLock redisWriteLock;

    //双重检查锁实现单例
    public static RedisReadLock readLock(){
        if(redisReadLock == null){
            synchronized (RedisReadLock.class){
                if (redisReadLock == null){
                    redisReadLock = new RedisReadLock();
                }
            }
        }
        return redisReadLock;
    }

    public static RedisWriteLock writeLock(){
        if(redisWriteLock == null){
            synchronized (RedisWriteLock.class){
                if (redisWriteLock == null){
                    redisWriteLock = new RedisWriteLock();
                }
            }
        }
        return redisWriteLock;
    }

    // 构建锁的key
    public static String getReadLockKey(String name){
        return RedisLockConf.READ_LOCK_PREFIX + name;
    }

    public static String getWriteLockKey(String name){
        return RedisLockConf.WRITE_LOCK_PREFIX + name;
    }

    public static String getReentrantWriteLockKey(String name){
        return RedisLockConf.REENTRANT_WRITE_LOCK_PREFIX + name;
    }

    //由连接池id+获取锁的线程id,来区分分布式中的不同线程
    public static String getThreadUid(){
        return JedisConnectPoll.JEDIS_CONNECT_POLL_UUID.toString() + ":" + Thread.currentThread().getId();
    }
}


public class RedisLockConf {
    public static final String READ_LOCK_PREFIX = "read_lock_";
    public static final String WRITE_LOCK_PREFIX = "write_lock_";
    public static final String REENTRANT_WRITE_LOCK_PREFIX = "reentrant_write_lock_";
}

RedisReadLock.java 读锁的实现

@Slf4j
public class RedisReadLock {
    public void lock(String name){
        tryLock(name, Long.MAX_VALUE, 30, TimeUnit.SECONDS);
    }
    
    public void lock(String name, long leaseTime, TimeUnit unit){
        tryLock(name, Long.MAX_VALUE, leaseTime, unit);
    }

    public boolean tryLock(String name, long waitTime, long leaseTime, TimeUnit unit){
        Long waitUntilTime = unit.toMillis(waitTime) + System.currentTimeMillis();
        if(waitUntilTime < 0){
            waitUntilTime = Long.MAX_VALUE;
        }
        Long leastTimeLong = unit.toMillis(leaseTime);
        StringBuilder sctipt = new StringBuilder();
        
        // write-lock read-lock uuid leaseTime,后面会专门说这段脚本
        sctipt.append("if not redis.call('GET',KEYS[1]) then ")
                     //redis.call('GET',KEYS[1])之类的命令,若没有值返回的布尔类型的false,不是nil
                    .append("local count = redis.call('HGET',KEYS[2],KEYS[3]);")
                    .append("if count then ")
                        .append("count = tonumber(count) + 1;")
                        .append("redis.call('HSET',KEYS[2],KEYS[3],count);")
                    .append("else ")
                        .append("redis.call('HSET',KEYS[2],KEYS[3],1);")
                    .append("end;")
                    .append("local t = redis.call('PTTL', KEYS[2]);")
                    .append("redis.call('PEXPIRE', KEYS[2], math.max(t, ARGV[1]));")
                    .append("return 1;")
                .append("else ")
                     .append("return 0;")
                .append("end;");
        for(;;){
            if(System.currentTimeMillis() > waitUntilTime){
                return false;
            }
            Long res = (Long) JedisTemplate.operate().eval(sctipt.toString(), 3, RedisReadWriteLock.getWriteLockKey(name), RedisReadWriteLock.getReadLockKey(name), RedisReadWriteLock.getThreadUid(), leastTimeLong.toString());
            if(res.equals(1L)){
                //successGetReadLock
                log.debug("success get read lock,  readLock={}", RedisReadWriteLock.getReadLockKey(name));
                break;
            }else {
                //need to wait write lock to be released
                log.debug("wait write lock release,  writeLock={}", RedisReadWriteLock.getWriteLockKey(name));
                try {
                    TimeUnit.MILLISECONDS.sleep(50);
                } catch (InterruptedException e) {
                    log.error("wait write lock release exception", e);
                }
            }
        }
        return true;
    }

    public void unlock(String name){
        StringBuilder sctipt = new StringBuilder();
        sctipt.append("local count = redis.call('HGET',KEYS[1],KEYS[2]);")
                .append("if count then ")
                    .append("if (tonumber(count) > 1) then ")
                        .append("count = tonumber(count) - 1;")
                        .append("redis.call('HSET',KEYS[1],KEYS[2],count);")
                    .append("else ")
                      .append("redis.call('HDEL',KEYS[1],KEYS[2]);")
                    .append("end;")
                .append("end;")
                .append("return;");
        JedisTemplate.operate().eval(sctipt.toString(), 2, RedisReadWriteLock.getReadLockKey(name), RedisReadWriteLock.getThreadUid());
        log.debug("success unlock read lock, readLock={}", RedisReadWriteLock.getReadLockKey(name));
    }

}

redis执行lua脚本的命令格式为:EVAL script numkeys key [key …] arg [arg …]

redis 127.0.0.1:6379> EVAL "return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}" 2 key1 key2 first second

1) "key1"
2) "key2"
3) "first"
4) "second"
-- 读锁获取的lua脚本
-- 判断不存在写锁
if not redis.call('GET',KEYS[1]) then 
    local count = redis.call('HGET',KEYS[2],KEYS[3])
    -- 如果该线程已经获取了该读锁,就重入,重入次数加1
    if count then 
        count = tonumber(count) + 1
        redis.call('HSET',KEYS[2],KEYS[3],count)
    else 
        redis.call('HSET',KEYS[2],KEYS[3],1)
    end
    -- 检查之前读锁的过期时间,和当前加的读锁的过期时间做对比,更新过期时间
    local t = redis.call('PTTL', KEYS[2])
    redis.call('PEXPIRE', KEYS[2], math.max(t, ARGV[1]))
    return 1
else 
    -- 若存在写锁,返回获取失败,外层的代码做轮询尝试加锁
    return 0
end;


-- 读锁释放的lua脚本
-- 获取锁被当前线程重入的次数
local count = redis.call('HGET',KEYS[1],KEYS[2])
if count then 
    if (tonumber(count) > 1) then 
        count = tonumber(count) - 1
        redis.call('HSET',KEYS[1],KEYS[2],count)
    else 
        redis.call('HDEL',KEYS[1],KEYS[2])
    end
end
return

RedisWriteLock.java 写锁的实现

@Slf4j
public class RedisWriteLock {
    public void lock(String name){
        tryLock(name, Long.MAX_VALUE, 30, TimeUnit.SECONDS);
    }
    
    public void lock(String name, long leaseTime, TimeUnit unit){
        tryLock(name, Long.MAX_VALUE, leaseTime, unit);
    }

    public boolean tryLock(String name, long waitTime, long leaseTime, TimeUnit unit){
        Long waitUntilTime = unit.toMillis(waitTime) + System.currentTimeMillis();
        if(waitUntilTime < 0){
            waitUntilTime = Long.MAX_VALUE;
        }
        Long leastTimeLong = unit.toMillis(leaseTime);
        StringBuilder sctipt = new StringBuilder();
        // write-lock reentrant-write-lock uuid leaseTime
        sctipt.append("if redis.call('SET',KEYS[1],ARGV[1],'NX','PX',ARGV[2]) then ")
                .append("redis.call('SET',KEYS[2],1,'PX',ARGV[2]);")
                .append("return 1;")
                .append("else ")
                    .append("if (redis.call('GET',KEYS[1])== ARGV[1]) then ")
                        .append("local count = redis.call('GET',KEYS[2]);")
                        .append("if not count then ")
                            .append("redis.call('SET',KEYS[2],1,'PX',ARGV[2]);")
                            .append("return 1;")
                        .append("else ")
                            .append("count = tonumber(count) + 1;")
                            .append("redis.call('SET',KEYS[2],count,'PX',ARGV[2]);")
                            .append("return count;")
                        .append("end;")
                    .append("else ")
                        .append("return 0;")
                    .append("end;")
                .append("end;");
        for(;;){
            if(System.currentTimeMillis() > waitUntilTime){
                return false;
            }
            Long res = (Long) JedisTemplate.operate().eval(sctipt.toString(), 2, RedisReadWriteLock.getWriteLockKey(name), RedisReadWriteLock.getReentrantWriteLockKey(name), RedisReadWriteLock.getThreadUid(), leastTimeLong.toString());
            if(res.equals(1L)){
                //successGetWriteLock
                log.debug("success get write lock,  writeLock = {}", RedisReadWriteLock.getWriteLockKey(name));
                for(;;){
                    if(JedisTemplate.operate().exists(RedisReadWriteLock.getReadLockKey(name))){
                        log.debug("wait read lock release,  readLock = {}", RedisReadWriteLock.getReadLockKey(name));
                        try {
                            TimeUnit.MILLISECONDS.sleep(100);
                        } catch (InterruptedException e) {
                            log.error("wait read lock release exception", e);
                        }
                    }else{
                        break;
                    }
                }
                break;
            }else if(res.equals(0L)){
                //need to wait write lock to be released
                log.debug("wait write lock release,  writeLock = {}", RedisReadWriteLock.getWriteLockKey(name));
                try {
                    TimeUnit.MILLISECONDS.sleep(100);
                } catch (InterruptedException e) {
                    log.error("wait write lock release exception", e);
                }
            }else{
                log.debug("success in reentrant write lock,  reentrantWriteLock = {}, count now = {}", RedisReadWriteLock.getReentrantWriteLockKey(name), res);
                break;
            }
        }
        return true;
    }

    public void unlock(String name){
        StringBuilder sctipt = new StringBuilder();
        //write-lock reentrant-write-lock uuid
        sctipt.append("if (redis.call('GET',KEYS[1])== ARGV[1]) then ")
                    .append("local count = redis.call('GET',KEYS[2]);")
                    .append("if count then ")
                        .append("if (tonumber(count) > 1) then ")
                            .append("count = tonumber(count) - 1;")
                            .append("local live = redis.call('PTTL',KEYS[2]);")
                            .append("redis.call('SET',KEYS[2],count,'PX',live);")
                            //success unlock reentrant-write-lock
                            .append("return count;")
                        .append("else ")
                            .append("redis.call('DEL',KEYS[2]);")
                            .append("redis.call('DEL',KEYS[1]);")
                            //success unlock
                            .append("return 0;")
                        .append("end;")
                    .append("else ")
                        .append("redis.call('DEL',KEYS[1]);")
                        .append("return 0;")
                    .append("end;")
                .append("else ")
                    //fail unlock, thread not get the lock
                    .append("return -1;")
                .append("end;");
        Long res = (Long) JedisTemplate.operate().eval(sctipt.toString(), 2, RedisReadWriteLock.getWriteLockKey(name), RedisReadWriteLock.getReentrantWriteLockKey(name), RedisReadWriteLock.getThreadUid());
        if(res.equals(0L)){
            log.debug("success unlock write lock,  writeLock = {}", RedisReadWriteLock.getWriteLockKey(name));
        }else if(res.equals(-1L)){
            log.debug("fail unlock, thread not get the lock,  writeLock = {}, thread = {}", RedisReadWriteLock.getReentrantWriteLockKey(name), RedisReadWriteLock.getThreadUid());
        }else {
            log.debug("success unlock reentrant write lock,  reentrantWriteLock = {}, count left = {}", RedisReadWriteLock.getReentrantWriteLockKey(name), res);
        }
    }
}
-- 写锁获取的lua脚本 
-- 没有写锁,则成功设置
if redis.call('SET',KEYS[1],ARGV[1],'NX','PX',ARGV[2]) then 
    -- 重入数记为1
    redis.call('SET',KEYS[2],1,'PX',ARGV[2])
    return 1
else 
    -- 写锁已经被获取,判断已经获取写锁的线程是不是当前线程
    if (redis.call('GET',KEYS[1])== ARGV[1]) then 
        -- 若是当前线程,则重入数+1
        local count = redis.call('GET',KEYS[2])
        if not count then 
            redis.call('SET',KEYS[2],1,'PX',ARGV[2])
            return 1;
        else 
            count = tonumber(count) + 1;
            redis.call('SET',KEYS[2],count,'PX',ARGV[2])
            return count
        end
    else 
        return 0
    end
end

-- 写锁释放的lua脚本 
-- 判断是否是当前线程获取的写锁
if (redis.call('GET',KEYS[1])== ARGV[1]) then 
    -- 是当前线程获取的写锁,判断锁是否被重入,重入数减到1后释放锁
    local count = redis.call('GET',KEYS[2])
    if count then 
        if (tonumber(count) > 1) then 
            count = tonumber(count) - 1
            local live = redis.call('PTTL',KEYS[2])
            redis.call('SET',KEYS[2],count,'PX',live)
            -- 返回该线程对该锁剩余的重入次数
            return count
        else 
            redis.call('DEL',KEYS[2])
            redis.call('DEL',KEYS[1])
            return 0
        end
    else
        redis.call('DEL',KEYS[1])
        return 0
    end
else 
    -- 其他线程获取的该写锁,该线程解锁失败
    return -1
end

测试如下:

public static void main(String[] args) {
    final int[] num = {0};
    for (int i = 0; i < 10 ; i++) {
        Thread thread = new Thread(() -> {
            RedisReadWriteLock.writeLock().tryLock("ccc", 30,300, TimeUnit.SECONDS);
            num[0]++;
            System.out.println("【写】:" + num[0]);
            RedisReadWriteLock.writeLock().unlock("ccc");
        });
        thread.start();
    }
    for (int i = 0; i < 100 ; i++) {
        Thread thread = new Thread(() -> {
            RedisReadWriteLock.readLock().tryLock("ccc", 30,300, TimeUnit.SECONDS);

            System.out.println("读:" + num[0]);

            RedisReadWriteLock.readLock().unlock("ccc");
        });
        thread.start();
        if(i % 3 == 0){
            try {
                TimeUnit.MILLISECONDS.sleep(50);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
}

运行结果

【写】:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
读:1
【写】:2
【写】:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
读:3
【写】:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
读:4
【写】:5
读:5
读:5
读:5
读:5
读:5
读:5
【写】:6
读:6
读:6
读:6
读:6
读:6
读:6
【写】:7
读:7
读:7
读:7
读:7
读:7
读:7
读:7
读:7
读:7
【写】:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
读:8
【写】:9
读:9
读:9
读:9
读:9
读:9
读:9
【写】:10
读:10
读:10
读:10

JedisTemplate其实是对从JedisPool获取的Jedis连接做的一个cglib的代理,用于使用完之后自动调用Jedis实例的close方法将连接归还到JedisPool。代理肯定会存在一点性能损耗,但是也简化了编码,以及避免了忘记归还连接,导致池的连接被耗空。

需要引入cglib的依赖

<dependency>
    <groupId>cglibgroupId>
    <artifactId>cglibartifactId>
    <version>3.2.10version>
dependency>
public class JedisTemplate {
    public static Jedis operate(){
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(Jedis.class);
        enhancer.setCallback(new JedisCglibProxyIntercepter());
        return (Jedis) enhancer.create();
    }
}

public class JedisCglibProxyIntercepter implements MethodInterceptor {
    @Override
    public Object intercept(Object o, Method method, Object[] objects, MethodProxy methodProxy) throws Throwable {
        //try后会自动调用jedis的close方法释放资源
        try(Jedis jedis = JedisConnectPoll.getJedis()){
            return method.invoke(jedis, objects);
        }
    }
}

JedisPool连接池的代码也贴一下:

需要添加下面的Jedis依赖

<dependency>
    <groupId>redis.clientsgroupId>
    <artifactId>jedisartifactId>
    <version>3.0.1version>
dependency>

@Slf4j的注解是需要添加lombok依赖,然后还要对应的日志依赖,这个按个人需要选择吧

<dependency>
    <groupId>org.projectlombokgroupId>
    <artifactId>lombokartifactId>
    <version>1.18.6version>
    <scope>providedscope>
dependency>

 
<dependency>
    <groupId>org.slf4jgroupId>
    <artifactId>slf4j-apiartifactId>
    <version>1.7.16version>
dependency>

<dependency>
    <groupId>org.slf4jgroupId>
    <artifactId>jcl-over-slf4jartifactId>
    <version>1.7.16version>
dependency>

<dependency>
    <groupId>org.slf4jgroupId>
    <artifactId>log4j-over-slf4jartifactId>
    <version>1.7.16version>
dependency>

<dependency>
    <groupId>ch.qos.logbackgroupId>
    <artifactId>logback-coreartifactId>
    <version>1.1.6version>
dependency>

<dependency>
    <groupId>ch.qos.logbackgroupId>
    <artifactId>logback-classicartifactId>
    <version>1.1.6version>
    <exclusions>
        <exclusion>
            <artifactId>slf4j-apiartifactId>
            <groupId>org.slf4jgroupId>
        exclusion>
    exclusions>
dependency>

@Slf4j
public class JedisConnectPoll {
    public static final UUID JEDIS_CONNECT_POLL_UUID = UUID.randomUUID();
    //连接redis实例的ip
    private static final String REDIS_ADDRESS = "127.0.0.1";
    //连接redis实例的端口
    private static final int PORT = "6379";
    //密码
    private static final String PASSWORD = "";
    //多线程环境中,连接实例的最大数,如果设为-1则无上线,建议设置,否则有可能导致资源耗尽
    private static final int MAX_ACTIVE = 160;
    //在多线程环境中,连接池中最大空闲连接数,单线程环境没有实际意义
    private static final int MAX_OLDE = 128;
    //在多线程环境中,连接池中最小空闲连接数
    private static final int MIN_OLDE = 8;
    //多长时间将空闲线程进行回收,单位毫秒
    private static final int METM = 2000;
    //对象空闲多久后逐出, 当空闲时间>该值 且 空闲连接>最大空闲数 时直接逐出,不再根据MinEvictableIdleTimeMillis判断 (默认逐出策略)
    private static final int SMETM = 2000;
    //逐出扫描的时间间隔(毫秒) 如果为负数,则不运行逐出线程, 默认-1,只有运行了此线程,MIN_OLDE METM/SMETM才会起作用
    private static final int TBERM = 1000;
    //当连接池中连接不够用时,等待可用连接的最大时间,单位毫秒,默认值为-1,表示永不超时。如果超过等待时间,则直接抛出JedisConnectionException;
    private static final int MAX_WAIT = 1000;
    //超时时间,单位毫秒
    private static final int TIME_OUT = 10000;
    //在借用一个jedis连接实例时,是否提前进行有效性确认操作;如果为true,则得到的jedis实例均是可用的;
    private static final boolean TEST_ON_BORROW = false;

    //连接池实例
    private static JedisPool jedisPool = null;

    static {
        initPoll();
    }

    private static void initPoll() {
        try {
            JedisPoolConfig config = new JedisPoolConfig();
            config.setMaxTotal(MAX_ACTIVE);
            config.setMaxIdle(MAX_OLDE);
            config.setMaxWaitMillis(MAX_WAIT);
            config.setTestOnBorrow(TEST_ON_BORROW);
            config.setMinIdle(MIN_OLDE);
            config.setMinEvictableIdleTimeMillis(METM);
            config.setSoftMinEvictableIdleTimeMillis(SMETM);
            config.setTimeBetweenEvictionRunsMillis(TBERM);

            if(!"".equals(PASSWORD)){
                jedisPool = new JedisPool(config, REDIS_ADDRESS, PORT, TIME_OUT, PASSWORD);
            }else {
                jedisPool = new JedisPool(config, REDIS_ADDRESS, PORT, TIME_OUT);
            }

        } catch (Exception e) {
            log.error("initial JedisPoll fail: {}",e);
        }
    }

    public static Jedis getJedis(){
        return jedisPool.getResource();
    }
}

你可能感兴趣的:(Redis实现分布式读写锁(Java基于Lua实现))