Batch Reading and Writing Redis from Spark

First, add the Jedis dependency so that Spark can talk to Redis:

        <dependency>
            <groupId>redis.clients</groupId>
            <artifactId>jedis</artifactId>
            <version>3.0.1</version>
        </dependency>
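
If the project is built with sbt instead of Maven, the equivalent dependency (assumed here to be the same 3.0.1 version) is a single line in build.sbt:

        libraryDependencies += "redis.clients" % "jedis" % "3.0.1"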

Connecting a client

First you need a Redis client. To create one, Jedis needs the host IP address, the port, and your credentials (username/password); once it has these and authenticates, it is connected to Redis.

import redis.clients.jedis.Jedis

object GetRedisClient {
  def getRedisClient(redisUgi: String): Jedis = {
    val port = -1     // your Redis port
    val ip = "ip1"    // your Redis host / IP address
    val redisClient = new Jedis(ip, port)
    redisClient.auth(redisUgi)  // authenticate with the Redis credentials
    redisClient
  }
}
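
A quick way to verify that the client works is to ping the server and close the connection afterwards (a minimal sketch; the credential string is a placeholder):

val jedis = GetRedisClient.getRedisClient("your_password")  // placeholder credential
println(jedis.ping())  // prints "PONG" if the connection and auth succeed
jedis.close()          // release the connection when done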

The problem with batch deletes

Both the batch reads and the batch writes use the pipelined() method of the Jedis object. A pipeline sends many commands in a single round trip, which saves the network overhead of issuing them one by one. pipelined() returns a redis.clients.jedis.Pipeline object, and the current version of this class has a problem: through its multi-level inheritance, different parent classes implement the del methods declared by the MultiKeyCommandsPipeline and MultiKeyBinaryRedisPipeline interfaces. As a result, calling del on a Pipeline from Scala is ambiguous (the compiler cannot decide which overload to pick), and the call fails with the following error:

[ERROR] E:\XXX.scala:153: ambiguous reference to overloaded definition,
both method del in class MultiKeyPipelineBase of type (x$1: String*)redis.clients.jedis.Response[Long]
and  method del in class PipelineBase of type (x$1: String)redis.clients.jedis.Response[Long]
match argument types (String)
[ERROR]       pipeline.del("")

Because the interfaces implemented by its parent classes declare methods with the same name, Pipeline cannot be used to delete keys in bulk via del. This issue exists in the latest version (3.0.1) but not in 2.0.0; the old version, however, is very buggy and frequently runs into problems with parallel inserts from Spark, so it is not recommended.
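
As a side note, a possible workaround (not used in the code below; just a sketch of how Scala overload resolution can be forced) is to pass an explicit sequence so that only the varargs del(String*) overload matches:

// ": _*" expands the sequence as varargs, so only the del(String*) overload applies
pipeline.del(Seq("key1", "key2"): _*)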

Batch insert, batch read, and parallel delete

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.joda.time.DateTime
import org.joda.time.format.DateTimeFormat
import redis.clients.jedis.Response

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer


object RedisTest {
  var inputPath = ""
  var redisUgi = "" // Redis credentials (auth string)
  var keyList = ""
  var mode = ""

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("RedisTest")
      .getOrCreate()

    init(args)

    if ("delete".equals(mode)) {
      delData(spark)
    } else if ("insert".equals(mode)) {
      putData(spark)
    } else {
      // Default mode: look up the given keys from the driver, one hget per key
      val jedisWrite = GetRedisClient.getRedisClient(redisUgi)

      val arr = keyList.split(",")
      arr.foreach(x => println(x + ":" + jedisWrite.hget(x, "field_name")))
      jedisWrite.close()
    }
  }


  def init(args: Array[String]): Unit = {
    val parser = new MyArgParser(args)
    inputPath = parser.getStringValue("inputPath", "")
    redisUgi = parser.getStringValue("redisUgi", ":").trim
    keyList = parser.getStringValue("keyList", "")
    mode = parser.getStringValue("mode", "")
    println(s"inputPath:$inputPath")
    println(s"redisUgi:$redisUgi")
    println(s"keyList:$keyList")
    println(s"mode:$mode")
  }

  // Batch insert: write the input data into Redis through a pipeline
  def putData(spark: SparkSession): Unit = {
    val rdd = spark.sparkContext
      .textFile(inputPath)
      .persist(StorageLevel.MEMORY_AND_DISK)

    println("count:" + rdd.count())
    println(rdd.take(5).mkString("\n"))

    val redisUgiBc = spark.sparkContext.broadcast(redisUgi)
    rdd.repartition(400).
      foreachPartition(iter => {
        val users = iter.map(x => {
          val item = x.split("\t")  // each input line: "userid<TAB>userinfo"
          val userid = item(0)
          val userinfo = item(1)
          (userid, userinfo)
        })
        putDataPartition(users, redisUgiBc,  3600 * 24 * 2)
      })
  }

  // Parallel delete (not actually batched here, because Pipeline's del method cannot be used; see above)
  def delData(spark: SparkSession): Unit = {
    val rdd = spark.sparkContext
      .textFile(inputPath)
      .persist(StorageLevel.MEMORY_AND_DISK)

    println("count:" + rdd.count())
    println(rdd.take(5).mkString("\n"))

    val redisUgiBc = spark.sparkContext.broadcast(redisUgi)
    rdd.repartition(400).
      foreachPartition(iter => {
        val users = iter.map(x => {
          val item = x.split("\t")
          val userid = item(0)
          val userinfo = item(1)
          (userid, userinfo)
        })
        delDataPartition(users, redisUgiBc)
      })
  }

  /**
    * Reads one partition's records (originally from HDFS, inputPath) and writes them into Redis.
    *
    * @param dataIt     iterator over this partition's (key, value) records
    * @param redisUgiBc broadcast Redis credentials
    * @param expireTime TTL of the written data in seconds, default 2 days
    */
  def putDataPartition(dataIt: Iterator[(String, String)], redisUgiBc: Broadcast[String],
                       expireTime: Int = 3600 * 24 * 2): Unit = {
    val jedisClient = GetRedisClient.getRedisClient(redisUgiBc.value)
    val dataList = dataIt.toArray
    val batchNum = 30
    val nStep = math.ceil(dataList.size / batchNum.toDouble).toInt

    for (index <- 0 until nStep) {
      val lowerIndex = batchNum * index
      val upperIndex =
        if (lowerIndex + batchNum >= dataList.size) dataList.size
        else batchNum * (index + 1)
      val batchData = dataList.slice(lowerIndex, upperIndex)
      val pipeline = jedisClient.pipelined()

      batchData.foreach(data => {
        val dataKey = data._1
        val dataValue = data._2
        pipeline.hset(dataKey, "field_name", dataValue)
        pipeline.expire(dataKey, expireTime)  // set a TTL so stale data expires
      })

      pipeline.sync()  // flush the whole batch in one round trip
    }
    jedisClient.close()
  }

  def delDataPartition(dataIt: Iterator[(String, String)], redisUgiBc: Broadcast[String]): Unit = {
    val jedisClient = GetRedisClient.getRedisClient(redisUgiBc.value)
    val dataList = dataIt.toArray
    val batchNum = 50
    val nStep = math.ceil(dataList.size / batchNum.toDouble).toInt

    for (index <- 0 until nStep) {
      val lowerIndex = batchNum * index
      val upperIndex =
        if (lowerIndex + batchNum >= dataList.size) dataList.size
        else batchNum * (index + 1)
      val batchData = dataList.slice(lowerIndex, upperIndex)

      batchData.foreach(data => {
        val dataKey = data._1
        jedisClient.del(dataKey)  // delete one key at a time; see the pipeline del issue above
      })
    }
    jedisClient.close()
  }

  /**
    * Batch read test: fetch a sample of keys from Redis through a pipeline.
    */
  def getFromRedis(spark: SparkSession): Unit = {
    val keylist = ArrayBuffer[String]()
    spark.sparkContext.textFile(inputPath)
      .take(100)
      .foreach(key => keylist.append(key))

    val jedisWrite = GetRedisClient.getRedisClient(redisUgi)
    val response: scala.collection.mutable.Map[String, Response[String]] = mutable.HashMap()
    val pipeline = jedisWrite.pipelined()
    keylist.foreach(key => {
      val value = pipeline.hget(key, "field_name")
      response.put(key, value)
    })
    val res = pipeline.syncAndReturnAll()  // returns only the values, with no key information

    println("response:")
    response.foreach(x => println(x._1 + ":" + x._2.get()))
    println("res:")
    res.toArray().map(x => x.asInstanceOf[String]).foreach(println)
    jedisWrite.close()
  }

}


