kotlin实现内网穿透(可以放在安卓应用中使用)

相关内容

node.js实现内网穿透: https://www.jianshu.com/p/d2d4f8bff599
可以与 node.js 版混用。
使用方式参见 node.js 版,两者大同小异。

实现代码

服务端:

import java.io.Closeable
import java.io.InputStream
import java.io.OutputStream
import java.net.InetAddress
import java.net.ServerSocket
import java.net.Socket
import java.security.KeyFactory
import java.security.interfaces.RSAPrivateKey
import java.util.*
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import javax.crypto.Cipher
import javax.crypto.EncryptedPrivateKeyInfo
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.PBEKeySpec
import kotlin.collections.HashMap

/**
 * 作者:yzh
 *
 * 创建时间:2022/7/16 14:47
 *
 * 描述:
 *
 * 修订历史:
 */
object NatServer {
    /**
     * PEM text of the server's RSA private key, encrypted (PBES2) with the algorithm named by
     * [aes256cbc]. It is decrypted once, lazily, by [privateKey].
     */
    private const val privateKeyStr =
        "-----BEGIN ENCRYPTED PRIVATE KEY-----\nMIIC3TBXBgkqhkiG9w0BBQ0wSjApBgkqhkiG9w0BBQwwHAQIl1wFXFOitsACAggA\nMAwGCCqGSIb3DQIJBQAwHQYJYIZIAWUDBAEqBBA0oPbhhGLdPgtmSrVbZciIBIIC\ngArc9A6lpkDo6P+Mo75UfU5EzkJhRrR69V1+iwLodTiMYbJK5VVyO6FyqTBTvNJs\nfJG55VijUESxPcJ5I3hNwjZqhNzDXR58ZkNaMKcuIkgCR0Vt5bo8GSsx/4dLYppo\n/pvGsSQ9MYtCsGCZKy/dpwm7BgDQ2GiYWcNL382c16NT80Bt/qq6/oY0jRCp5l9q\npCg0+Mh3o1w/ozKc1HX/3zmVX9KmsJFmLfH8WcB49YJmFSNvu9hSBXypnqvGodQd\n8yNxuEN1/7AwSnBZ/LOYGzittNwxfZE+LJHakGF6MVyuPbJI0s82B8MCQvlBpTeB\n0qKXlfiXIhS7KYIXaMCt5kJPRJQpDv1dmQIhWRjgaqldHbv0E3INf/AObuAXjoN3\nBLi3TQCm4e1Cde0RP4JNCdiLTT/MJAgFSIf7WHRteS22qmF9BR9EhBPJlWJ2GqlJ\nAi8JrX16WT9lTWIMFAH4NDbpaIpn61fjtBR05XCse2ZN3HOTgVxk0nbwZpFtQ+cQ\nsl7cLTx+GFgg6nhamXGHuvSSSsWqtTHdhvmDOeP8rq0bWwz0zVUlSIazd0jBbnw7\nJHNYBTpp7OOqg6lPw3J4dTi6NvRkqJ9oCuQBwdzyaNPOmtkVRsTn4xy2L8H56G3u\nFupF+kO1BAQJEJi/lm5oqOXjtj+O6R9LtjoLwbVtxLvJMsU0/Q1qxYU4z397k8QN\ntmTGo5I6s6UKYYgZK5dSrhwTTPVheI13hZmL994H5zzmd+E8wMCQiiibRUs+qsHF\ncKmJ+lZSG0VLKMlGmvfac9o+mTv8+C7miu/mQq1akrVLGRt4GTHR4lgQDNi20YyP\nhzEWbzdsDegNCIQSLTPkIl0=\n-----END ENCRYPTED PRIVATE KEY-----\n"

    // RSA/ECB/OAEPPadding corresponds to node's default RSA_PKCS1_OAEP_PADDING; must match the client.
    private const val ALGORITHM = "RSA/ECB/OAEPPadding"
    private const val aes256cbc = "PBEWithHmacSHA256AndAES_256"
    private val keyFactory by lazy { KeyFactory.getInstance("RSA") }

    /** Read timeout in milliseconds applied to every accepted socket. */
    private const val socketOutTime = 30 * 1000

    /** Size in bytes of the encrypted handshake block a NAT client sends first. */
    private const val contentLength = 128

    /** The decrypted private key, or null if the decoded key is not an RSA private key. */
    private val privateKey by lazy {
        val keyStr = privateKeyStr.replace("-----BEGIN ENCRYPTED PRIVATE KEY-----", "")
            .replace("-----END ENCRYPTED PRIVATE KEY-----", "")
        val keySpec =
            EncryptedPrivateKeyInfo(base64Decoder(keyStr)).run {
                // NOTE: the key passphrase is hard-coded; move it to configuration for real use.
                val secretKey = SecretKeyFactory.getInstance(aes256cbc)
                    .generateSecret(PBEKeySpec("yzh".toCharArray()))
                getKeySpec(
                    Cipher.getInstance(aes256cbc)
                        .apply { init(Cipher.DECRYPT_MODE, secretKey, algParameters) })
            }
        keyFactory.generatePrivate(keySpec) as? RSAPrivateKey
    }

    /** Passwords accepted in a NAT client's handshake. NOTE: hard-coded; move to configuration. */
    private val passwords = arrayOf("yzh")

    /** Runs NAT-connection handshakes and the control writes that start a tunnel. */
    private val executor by lazy {
        Executors.newCachedThreadPool {
            Thread(
                it,
                "NatServerSocket=="
            )
        }
    }

    /** Drives the periodic idle check of every [Server]. */
    private val checkFree by lazy {
        Executors.newSingleThreadScheduledExecutor {
            Thread(
                it,
                "CheckServerFree=="
            )
        }
    }

    /**
     * Accepts NAT-client connections on [dispatcherPort] and blocks the calling thread until
     * accepting fails.
     *
     * Every accepted connection first sends a [contentLength]-byte RSA block that decrypts to
     * "ip:port-password". The connection is then parked as an idle tunnel on the public [Server]
     * for ip:port, which is created on demand. A request is refused when it would clash with an
     * existing binding on the same port: a wildcard request vs any bound ip, or a specific ip
     * vs an existing wildcard binding.
     */
    fun run(dispatcherPort: Int) {
        // "ip:port" -> public listener; every access is guarded by serverMap's monitor.
        val serverMap = HashMap<String, Server>()
        // "ip:port" -> lock object serializing registrations for the same address.
        val synchronizedMap = HashMap<String, Any>()

        fun safeGet(key: String): Server? = synchronized(serverMap) {
            serverMap[key]?.takeUnless { it.natServerSocket.isClosed }
        }

        fun safeRemove(key: String) {
            synchronized(serverMap) {
                // Only drop a dead listener: a live replacement for the same key may already be registered.
                if (serverMap[key]?.natServerSocket?.isClosed != false) {
                    serverMap.remove(key)
                }
            }
        }

        // Throws if the port cannot be bound; the caller reports that to the NAT client.
        fun safeSet(key: String, ip: String?, port: Int): Server = synchronized(serverMap) {
            safeGet(key) ?: Server(ip, port) {
                safeRemove(key)
            }.also { serverMap[key] = it }
        }

        fun getSynKey(key: String): Any = synchronized(synchronizedMap) {
            synchronizedMap.getOrPut(key) { Any() }
        }

        fun resolveServer(ip: String?, port: Int): Server? {
            val address = "$ip:$port"
            // Registrations for the same address are serialized; an existing live listener is reused directly.
            synchronized(getSynKey(address)) {
                safeGet(address)?.let { return it }
                // The clash check and the registration happen under one lock so they cannot interleave.
                synchronized(serverMap) {
                    val natIps = "-${serverMap.keys.joinToString("-")}-"
                    val reg = if (ip.isNullOrBlank()) {
                        ":$port-"
                    } else {
                        "-:$port-"
                    }
                    if (natIps.indexOf(reg) != -1) {
                        println("$address===$natIps")
                        return null
                    }
                    return safeSet(address, ip, port)
                }
            }
        }

        val dispatcherSocket = ServerSocket(dispatcherPort)
        while (!dispatcherSocket.isClosed) {
            val natSocket = try {
                dispatcherSocket.accept().apply { soTimeout = socketOutTime }
            } catch (e: Exception) {
                println("dispatcher server error $e")
                dispatcherSocket.safeClose()
                return
            }
            NatServer(natSocket).run { ip, port -> resolveServer(ip, port) }
        }
        println("dispatcher server closed")
    }

    /**
     * A public listener on ip:port (all interfaces when `ip` is null or blank).
     *
     * Users connecting to it are paired first-come-first-served with idle NAT tunnels. Whichever
     * side arrives first waits in [paddingSocket] / [natServers]. Every 60 seconds [checkDied]
     * runs; after five consecutive checks with no idle tunnel the listener closes, invokes
     * [freeCallBack] once, and closes everything still waiting.
     *
     * @throws java.io.IOException from the constructor if the address cannot be resolved or bound.
     */
    class Server(ip: String?, port: Int, var freeCallBack: (() -> Unit)?) {
        // Bound eagerly so a bind failure reaches the caller instead of killing the accept thread.
        val natServerSocket: ServerSocket =
            if (ip.isNullOrBlank()) {
                ServerSocket(port)
            } else {
                ServerSocket(port, 50, InetAddress.getByName(ip))
            }
        private val natServers = arrayListOf<NatServer>()
        private val paddingSocket = arrayListOf<Socket>()
        private val scheduledFuture = checkFree.scheduleAtFixedRate({
            checkDied()
        }, 60, 60, TimeUnit.SECONDS)

        @Volatile
        private var emptyCount = 0

        init {
            Thread {
                while (!natServerSocket.isClosed) {
                    val clientSocket = try {
                        natServerSocket.accept().apply { soTimeout = socketOutTime }
                    } catch (e: Exception) {
                        println("accept error $e")
                        break
                    }
                    dispatcherNatServer(clientSocket)
                }
                // Release the port before unregistering, so the registry only drops closed listeners.
                natServerSocket.safeClose()
                free()
                destroy()
            }.start()
        }

        /** Invokes [freeCallBack] at most once and stops the idle check. */
        private fun free() {
            freeCallBack?.run {
                freeCallBack = null
                invoke()
                scheduledFuture.cancel(false)
            }
        }

        /** Closes the listener after five consecutive checks without an idle tunnel. */
        @Synchronized
        private fun checkDied() {
            if (natServers.isEmpty()) {
                emptyCount++
                if (emptyCount >= 5) {
                    natServerSocket.safeClose()
                }
            } else {
                emptyCount = 0
            }
        }

        /** Hands a newly connected user to the oldest idle tunnel, or parks it until one arrives. */
        @Synchronized
        private fun dispatcherNatServer(clientSocket: Socket) {
            if (natServers.isEmpty()) {
                paddingSocket.add(clientSocket)
            } else {
                natServers.removeAt(0).startNat(clientSocket)
            }
        }

        /** Registers an idle tunnel, immediately pairing it with a waiting user if there is one. */
        @Synchronized
        fun addNatServer(natServer: NatServer) {
            if (natServerSocket.isClosed) {
                // Listener is gone: drop the tunnel so the client notices and reconnects.
                natServer.natSocket.safeClose()
                return
            }
            emptyCount = 0
            if (paddingSocket.isEmpty()) {
                natServers.add(natServer)
            } else {
                natServer.startNat(paddingSocket.removeAt(0))
            }
        }

        @Synchronized
        fun removeNatServer(natServer: NatServer) {
            natServers.remove(natServer)
        }

        @Synchronized
        private fun removePaddingSocket(clientSocket: Socket) {
            paddingSocket.remove(clientSocket)
        }

        /** Closes every idle tunnel and every waiting user socket. */
        @Synchronized
        fun destroy() {
            natServers.forEach { it.natSocket.safeClose() }
            paddingSocket.forEach { it.safeClose() }
            natServers.clear()
            paddingSocket.clear()
        }
    }

    /**
     * One connection from a NAT client.
     *
     * Wire protocol, client -> server: a [contentLength]-byte handshake, then heartbeat bytes, and
     * finally 0x00 followed by tunnelled data. Server -> client: 0x00 pongs while idle, 0x01 when a
     * user is attached, 0x02 when the handshake is rejected.
     */
    class NatServer(val natSocket: Socket) {
        private val natInStream by lazy { natSocket.getInputStream() }
        private val natOutStream by lazy { natSocket.getOutputStream() }

        // Serializes the pong write with the attach-and-start write, so no pong can follow 0x01.
        private val controlLock = Any()

        @Volatile
        private var userSocket: Socket? = null

        @Volatile
        private var serverCallBack: Server? = null

        /**
         * Reads and validates the handshake on [executor], then registers with the [Server] that
         * [getServerCallBack] returns for the requested ip:port (null = refused).
         */
        fun run(getServerCallBack: (ip: String?, port: Int) -> Server?) {
            executor.execute {
                val natInfoArray = ByteArray(contentLength)
                var readCount = 0
                val dataArray = ByteArray(contentLength)
                while (!natSocket.isClosed) {
                    var length = try {
                        natInStream.read(dataArray)
                    } catch (e: Exception) {
                        exception("natInStream error $e")
                        return@execute
                    }
                    if (length == -1) {
                        exception("natInStream closed")
                        return@execute
                    }
                    if (readCount != contentLength) {
                        // Still collecting the fixed-size handshake block.
                        val taken = minOf(contentLength - readCount, length)
                        System.arraycopy(dataArray, 0, natInfoArray, readCount, taken)
                        readCount += taken
                        length -= taken
                        // Shift the bytes that followed the handshake to the buffer start
                        // (arraycopy handles overlapping ranges within one array).
                        if (length > 0) {
                            System.arraycopy(dataArray, taken, dataArray, 0, length)
                        }
                        if (readCount != contentLength) {
                            continue
                        }
                        val server = handshake(natInfoArray, getServerCallBack) ?: return@execute
                        serverCallBack = server
                        server.addNatServer(this)
                        if (natSocket.isClosed) {
                            return@execute
                        }
                    }
                    if (length != 0 && !handlePayload(dataArray, length)) {
                        return@execute
                    }
                }
            }
        }

        /**
         * Decrypts and validates the handshake "ip:port-password".
         * Returns the target [Server], or null after replying 0x02 and closing the connection.
         */
        private fun handshake(
            natInfo: ByteArray,
            getServerCallBack: (ip: String?, port: Int) -> Server?
        ): Server? {
            val info = try {
                decrypt(natInfo).also { println("info: $it") }
            } catch (e: Exception) {
                error("privateDecrypt error $e")
                return null
            }
            if (info == null) {
                error("info is null")
                return null
            }
            val infos = info.split("-").map { it.trim() }
            if (infos.size != 2) {
                error("infos error ${infos.size}")
                return null
            }
            if (infos[1] !in passwords) {
                error("passwords error ${infos[1]}")
                return null
            }
            val address = infos[0].split(":").map { it.trim() }
            val port = try {
                address[1].toInt()
            } catch (e: Exception) {
                error("port error $e")
                return null
            }
            if (port < 0 || port > 65535) {
                error("port2 error $port")
                return null
            }
            val server = try {
                getServerCallBack(address[0], port)
            } catch (e: Exception) {
                // e.g. the port could not be bound
                error("port3 error $e")
                return null
            }
            if (server == null) {
                error("port3 error used")
                return null
            }
            return server
        }

        /**
         * Handles the first [length] post-handshake bytes of [data].
         * Returns false when the reader must stop (failure, or the tunnel has been handed off).
         */
        private fun handlePayload(data: ByteArray, length: Int): Boolean {
            // 0x00 marks the start of tunnelled data; any other byte is a heartbeat.
            val zeroIndex = (0 until length).firstOrNull { data[it] == 0.toByte() }
            if (zeroIndex == null) {
                val pongError = synchronized(controlLock) {
                    if (userSocket != null) {
                        null
                    } else {
                        try {
                            natOutStream.write(ByteArray(1) { 0 })
                            null
                        } catch (e: Exception) {
                            e
                        }
                    }
                }
                if (pongError != null) {
                    exception("natOutStream error $pongError")
                    return false
                }
                return true
            }
            val user = userSocket
            if (user == null) {
                // The client only sends 0x00 after our 0x01, so this is a protocol violation.
                exception("start marker received without user")
                return false
            }
            val userOutputStream = try {
                user.getOutputStream().also {
                    it.write(data, zeroIndex + 1, length - zeroIndex - 1)
                }
            } catch (e: Exception) {
                exception("userOutputStream error $e")
                return false
            }
            val natDataSwitch = Switch(natSocket, user, natInStream, userOutputStream)
            Thread(natDataSwitch, "natServer端Data==").start()
            return false
        }

        /** Attaches [userSocket] to this tunnel: sends 0x01 and starts the user -> NAT pump. */
        fun startNat(userSocket: Socket) {
            executor.execute {
                val startError = synchronized(controlLock) {
                    this.userSocket = userSocket
                    try {
                        natOutStream.write(ByteArray(1) { 1 })
                        null
                    } catch (e: Exception) {
                        e
                    }
                }
                if (startError != null) {
                    exception("natOutStream error $startError")
                    return@execute
                }
                val userInStream = try {
                    userSocket.getInputStream()
                } catch (e: Exception) {
                    exception("userInStream error $e")
                    return@execute
                }
                val userDataSwitch = Switch(userSocket, natSocket, userInStream, natOutStream)
                Thread(userDataSwitch, "user端Data==").start()
            }
        }

        /** Tells the client its handshake was rejected (0x02), then tears the connection down. */
        private fun error(message: String) {
            try {
                natOutStream.write(ByteArray(1) { 2 })
            } catch (e: Exception) {
                exception("natOutStream error $e")
            }
            exception(message)
        }

        /** Logs [message], closes both sockets and unregisters from the owning [Server]. */
        private fun exception(message: String) {
            println(message)
            natSocket.safeClose()
            userSocket?.safeClose()
            serverCallBack?.removeNatServer(this)
        }
    }

    /** One-directional pump from [inStream] to [outStream]; closes both sockets when either side ends. */
    class Switch(
        private val inSocket: Socket, private val outSocket: Socket,
        private val inStream: InputStream, private val outStream: OutputStream
    ) : Runnable {
        private val buffer = ByteArray(1024 * 8)

        override fun run() {
            try {
                while (!inSocket.isClosed && !outSocket.isClosed) {
                    val length = inStream.read(buffer)
                    if (length == -1) {
                        break
                    }
                    outStream.write(buffer, 0, length)
                    outStream.flush()
                }
            } catch (e: Exception) {
                // A read/write failure on either side simply ends the tunnel.
            }
            inSocket.safeClose()
            outSocket.safeClose()
        }
    }

    private fun Closeable?.safeClose() {
        try {
            this?.close()
        } catch (e: Exception) {
            e.printStackTrace()
        }
    }

    /** Decodes a PEM body after removing newlines and spaces. */
    private fun base64Decoder(keyStr: String) = Base64.getDecoder().decode(
        keyStr.replace("\n", "")
            .replace(" ", "")
            .toByteArray(Charsets.UTF_8)
    )

    /**
     * Decrypts one RSA-OAEP block with the server private key.
     * Returns null when no RSA key is available; throws on a malformed block.
     */
    fun decrypt(data: ByteArray): String? {
        val key = privateKey ?: return null
        return Cipher.getInstance(ALGORITHM).run {
            init(Cipher.DECRYPT_MODE, key)
            String(doFinal(data), Charsets.UTF_8)
        }
    }
}

客户端:

import android.util.Base64
import java.io.Closeable
import java.io.InputStream
import java.io.OutputStream
import java.net.InetSocketAddress
import java.net.Proxy
import java.net.Socket
import java.security.KeyFactory
import java.security.interfaces.RSAPublicKey
import java.security.spec.X509EncodedKeySpec
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import javax.crypto.Cipher
import kotlin.system.exitProcess


/**
 * 作者:yzh
 *
 * 创建时间:2022/7/15 21:56
 *
 * 描述:
 *
 * 修订历史:
 */
object NatClient {
    private const val publicKeyStr =
        "-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCWAX9+7stLFV8sW2zA470M8b/5\nHt1FgkpGIVHfHvjIxh3k/APVfWlXpoN6lKIDQ/z4LZc+m03faeR/qjgl562W0sHQ\nDezv/cd84Uc2hDh/vTifL6RfNA7mrW3aqiVxT4gzvp327nzck/J/mzfVFyEgFb+z\nWsvr0xMkg+NNXMww8wIDAQAB\n-----END PUBLIC KEY-----\n"

    // RSA/ECB/OAEPPadding = RSA_PKCS1_OAEP_PADDING (node default);
    // RSA/ECB/PKCS1Padding = RSA_PKCS1_PADDING (Java RSA default).
    private const val ALGORITHM = "RSA/ECB/OAEPPadding"
    private val keyFactory by lazy { KeyFactory.getInstance("RSA") }

    private val publicKey by lazy {
        val keyStr = publicKeyStr.replace("-----BEGIN PUBLIC KEY-----", "")
            .replace("-----END PUBLIC KEY-----", "")
        keyFactory.generatePublic(X509EncodedKeySpec(base64Decoder(keyStr))) as? RSAPublicKey
    }

    /** Read timeout in milliseconds for both the NAT socket and the local-service socket. */
    private const val socketOutTime = 30 * 1000

    /** Heartbeat interval in milliseconds while a tunnel waits for a user. */
    private const val rateInterval = 3 * 1000L

    /** Runs tunnel (re)starts and each idle tunnel's blocking control-byte reader. */
    private val threadPool by lazy {
        Executors.newScheduledThreadPool(3) {
            Thread(it, "NatClientSocket==")
        }
    }

    /**
     * Opens three idle tunnels (ids "0", "1", "2", one per second). Each tunnel connects the
     * local service at [localServerIp]:[localServerPort] to the NAT server's dispatcher at
     * [natDispatchIp]:[natDispatchPort], which is asked to expose it on [natServerUseAddr].
     *
     * A tunnel is re-opened immediately once a user takes it. After a failure it is re-opened
     * after N seconds, where N is its number of consecutive failures.
     *
     * Exactly three tunnels are opened because each idle tunnel parks one of [threadPool]'s three
     * threads in a blocking read.
     */
    fun run(
        localServerIp: String = "127.0.0.1",
        localServerPort: Int = 8080,
        natServerUseAddr: String? = ":9001",
        natDispatchIp: String = "192.168.2.6",
        natDispatchPort: Int = 8989,
        password: String = "yzh"
    ) {
        // clientId -> consecutive failures; touched from several threadPool threads.
        val errorCounts = HashMap<String, Int>()

        fun start(clientId: String) {
            NatImpClient(
                localServerIp, localServerPort, natServerUseAddr,
                natDispatchIp, natDispatchPort, password, clientId
            ) { errorMessage ->
                if (errorMessage.isNullOrBlank()) {
                    // A user took the tunnel: the link works, so reset back-off and open a fresh one.
                    synchronized(errorCounts) { errorCounts.remove(clientId) }
                    threadPool.execute { start(clientId) }
                } else {
                    val failures = synchronized(errorCounts) {
                        ((errorCounts[clientId] ?: 0) + 1).also { errorCounts[clientId] = it }
                    }
                    println("error=$clientId=$failures=$errorMessage")
                    threadPool.schedule(Runnable { start(clientId) }, failures * 1000L, TimeUnit.MILLISECONDS)
                }
            }.run(rateInterval)
        }

        repeat(3) { index ->
            threadPool.schedule(Runnable { start(index.toString()) }, index * 1000L, TimeUnit.MILLISECONDS)
        }
    }

    /**
     * One tunnel between the local service and the NAT server.
     *
     * [usedCallBack] fires exactly once: with null when a user took the tunnel, or with a
     * non-blank error message when the tunnel failed.
     */
    class NatImpClient(
        localServerIp: String, localServerPort: Int, natServerUseAddr: String?,
        natDispatchIp: String, natDispatchPort: Int, password: String,
        private val clientId: String, private var usedCallBack: ((errorMessage: String?) -> Unit)?
    ) {
        /** Encrypted handshake "addr-password", sent once right after connecting. */
        private val natInfo by lazy { encrypt("$natServerUseAddr-$password") }

        /**
         * Connection to the server's dispatcher. On failure the error is reported and the closed
         * socket is still returned, so readers observe `isClosed == true`.
         * To go through a local SOCKS proxy, build it with
         * `Socket(Proxy(Proxy.Type.SOCKS, InetSocketAddress("127.0.0.1", 1080)))` instead.
         */
        private val natSocket by lazy {
            Socket().apply {
                soTimeout = socketOutTime
                try {
                    connect(InetSocketAddress(natDispatchIp, natDispatchPort))
                } catch (e: Exception) {
                    report(describe(e))
                    safeClose()
                }
                println("$clientId-connectNat地址:${remoteSocketAddress}")
            }
        }

        private val natInStream by lazy { natSocket.getInputStream() }
        private val natOutStream by lazy { natSocket.getOutputStream() }

        /** Connection to the local service; on failure both sockets are closed. */
        private val localServerSocket by lazy {
            Socket().apply {
                soTimeout = socketOutTime
                try {
                    connect(InetSocketAddress(localServerIp, localServerPort))
                } catch (e: Exception) {
                    report(describe(e))
                    safeClose()
                    natSocket.safeClose()
                }
                println("$clientId-connectLocalServer地址:${remoteSocketAddress}")
            }
        }

        /** Single thread that performs every control write (handshake, heartbeat, start marker). */
        private val executor by lazy {
            Executors.newSingleThreadScheduledExecutor {
                Thread(it, "NatClientSocket==")
            }
        }

        /** Sends the handshake, heartbeats every [rate] ms, and waits for the server's verdict. */
        fun run(rate: Long = 3000) {
            val info = natInfo ?: return
            executor.execute { writeMockData(info) }
            val heartbeat = try {
                executor.scheduleWithFixedDelay({
                    writeMockData(ByteArray(1) { 1 })
                }, rate, rate, TimeUnit.MILLISECONDS)
            } catch (e: Exception) {
                // The handshake already failed and shut the executor down; that failure was reported.
                if (executor.isShutdown) return else throw e
            }
            threadPool.execute {
                while (!natSocket.isClosed) {
                    val read = try {
                        natInStream.read()
                    } catch (e: Exception) {
                        fail(describe(e))
                        return@execute
                    }
                    when (read) {
                        -1 -> {
                            // The server hung up; report it so the tunnel is re-opened (instead of spinning).
                            fail("nat socket closed by server")
                            return@execute
                        }
                        1 -> {
                            heartbeat.cancel(false)
                            report(null)
                            switchData(localServerSocket)
                            return@execute
                        }
                        2 -> {
                            // Rejected handshake (bad password, port clash, ...). Report instead of
                            // exitProcess(), which would kill the whole host app.
                            fail("nat server refused the tunnel")
                            return@execute
                        }
                        else -> {
                            // 0x00: pong for a heartbeat.
                        }
                    }
                }
                executor.shutdown()
            }
        }

        /** Sends the 0x00 start marker and starts both pumps between [socket] and the NAT socket. */
        private fun switchData(socket: Socket) {
            if (socket.isClosed) {
                executor.shutdown()
                return
            }
            try {
                executor.execute {
                    if (!writeMockData(ByteArray(1) { 0 })) {
                        socket.safeClose()
                        return@execute
                    }
                    val localServerDataSwitch = Switch(
                        socket, natSocket, socket.getInputStream(), natOutStream
                    )
                    Thread(localServerDataSwitch, "$clientId-localServer端Data==").start()
                }
            } catch (e: Exception) {
                // The executor was already shut down by a concurrent heartbeat failure.
                socket.safeClose()
                return
            }
            // The start marker is the last control write; the executor thread exits after it.
            executor.shutdown()
            val natDataSwitch = Switch(
                natSocket, socket, natInStream, socket.getOutputStream()
            )
            Thread(natDataSwitch, "$clientId-natClient端Data==").start()
        }

        /** Writes [data] to the NAT socket; on failure tears the tunnel down and returns false. */
        private fun writeMockData(data: ByteArray): Boolean {
            return try {
                natOutStream.write(data)
                true
            } catch (e: Exception) {
                fail(describe(e))
                false
            }
        }

        /** Reports [errorMessage], closes the NAT socket and stops the control executor. */
        private fun fail(errorMessage: String) {
            report(errorMessage)
            natSocket.safeClose()
            executor.shutdown()
        }

        /** Fires [usedCallBack] at most once, however many threads report an outcome. */
        private fun report(errorMessage: String?) {
            val callBack = synchronized(this) {
                usedCallBack.also { usedCallBack = null }
            } ?: return
            callBack(errorMessage)
        }
    }

    /** One-directional pump from [inStream] to [outStream]; closes both sockets when either side ends. */
    class Switch(
        private val inSocket: Socket, private val outSocket: Socket,
        private val inStream: InputStream, private val outStream: OutputStream
    ) : Runnable {
        private val buffer = ByteArray(1024 * 8)

        override fun run() {
            try {
                while (!inSocket.isClosed && !outSocket.isClosed) {
                    val length = inStream.read(buffer)
                    if (length == -1) {
                        break
                    }
                    outStream.write(buffer, 0, length)
                    outStream.flush()
                }
            } catch (e: Exception) {
                // A read/write failure on either side simply ends the tunnel.
            }
            inSocket.safeClose()
            outSocket.safeClose()
        }
    }

    private fun Closeable?.safeClose() {
        try {
            this?.close()
        } catch (e: Exception) {
            e.printStackTrace()
        }
    }

    /** Non-blank description of [e]; a null/blank message would otherwise read as "success". */
    private fun describe(e: Exception): String = e.message?.takeIf { it.isNotBlank() } ?: e.toString()

    private fun base64Decoder(keyStr: String) = Base64.decode(
        keyStr.replace("\n", "")
            .replace(" ", "")
            .toByteArray(Charsets.UTF_8), Base64.DEFAULT
    )

    private fun encrypt(message: String): ByteArray? {
        return Cipher.getInstance(ALGORITHM).run {
            init(Cipher.ENCRYPT_MODE, publicKey)
            doFinal(message.toByteArray(Charsets.UTF_8))
        }
    }
}

比较遗憾的是,在安卓端解密时会报错:“PBEWithHmacSHA256AndAES_256 SecretKeyFactory not available”。可以改用未加密的私钥来解决这个问题,私钥的加载代码可改为:


    // Alternative `privateKey` initializer. It accepts either an unencrypted PKCS#8 PEM key
    // (when `pwd` is null or blank) or a PBE-encrypted one, so an unencrypted key can be used
    // where SecretKeyFactory "PBEWithHmacSHA256AndAES_256" is unavailable.
    // NOTE(review): inside this initializer, `privateKey` (PEM text) and `pwd` (passphrase)
    // presumably are constructor parameters of an enclosing class that is not shown here; as
    // constructor parameters they shadow this property. If they are not, the `privateKey`
    // reference is self-recursive. TODO confirm against the enclosing class.
    // NOTE(review): requires `import java.security.spec.PKCS8EncodedKeySpec`, which the server
    // listing's import block does not include.
    private val privateKey by lazy {
        // No key text supplied: the property resolves to null.
        privateKey ?: return@lazy null
        val keySpec = if (pwd.isNullOrBlank()) {
            // Unencrypted PEM: strip the armor lines and decode the base64 body as PKCS#8.
            val keyStr = privateKey.replace("-----BEGIN PRIVATE KEY-----", "")
                .replace("-----END PRIVATE KEY-----", "")
            PKCS8EncodedKeySpec(base64Decoder(keyStr))
        } else {
            // Encrypted PEM: derive the PBE key from `pwd` and decrypt the EncryptedPrivateKeyInfo.
            val keyStr = privateKey.replace("-----BEGIN ENCRYPTED PRIVATE KEY-----", "")
                .replace("-----END ENCRYPTED PRIVATE KEY-----", "")
            EncryptedPrivateKeyInfo(base64Decoder(keyStr)).run {
                val secretKey = SecretKeyFactory.getInstance(aes256cbc).generateSecret(PBEKeySpec(pwd.toCharArray()))
                getKeySpec(Cipher.getInstance(aes256cbc).apply { init(Cipher.DECRYPT_MODE, secretKey, algParameters) })
            }
        }
        // `as?` yields null if the generated key is not an RSA private key.
        keyFactory.generatePrivate(keySpec) as? RSAPrivateKey
    }

你可能感兴趣的:(kotlin实现内网穿透(可以放在安卓应用中使用))