基于MQ的2PC分布式事务

​​​​
在这里插入图片描述

上图阐释了如何基于mq实现2pc的分布式事务

  • 一阶段为红线部分。
  • 二阶段为蓝线部分。

图中展示了较为复杂的调用方式,S1调用S2、S3,S3又调用了S4。
感谢seata开源社区大佬的帮助。虽然2pc本身存在很多问题,但是自己手动实现一遍还是学习到很多。
本文仅做参考,不具备生产意义。
seata社区陈建斌大佬指正的问题列表如下:
问题
第一:tm需要有事务记录表,来恢复事务,而且要考虑到rm没任何异常,只是因为tm宕机导致tm的二阶段提交没有入库,但是由于这样,rm本身应该提交的事务变成了回滚。
第二:需要把connection换为xaconnection,使用xa协议来保证rm宕机后事务数据可恢复。
第三:要保证消息队列中间件的高可用。
第四:要防止资源悬挂问题,因为没有了分支事务注册,很可能因为网络或者其它因素,出现先发后至的情况,导致tm没感知到这个rm的存在,这个rm就可能因为用了xa协议导致死锁。

show your code

根据上图我们可以很好的实现代码如下:此处基于rocketmq方式实现。

引入以下包


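        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.dubbo</groupId>
            <artifactId>dubbo</artifactId>
            <version>2.7.2</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.rocketmq</groupId>
            <artifactId>rocketmq-spring-boot-starter</artifactId>
            <version>2.1.1</version>
        </dependency>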

全局事务注解:此注解用于开启全局事务,真正的本地事务还是交给Transactional注解去执行。

package com.xxx.mq.trx.config;

import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a method as the boundary of a global (distributed) transaction.
 *
 * <p>The annotation declares no attributes; it is a pure marker that is matched by
 * the {@code @annotation(com.xxx.mq.trx.config.GlobalTransaction)} pointcut. The
 * local database transaction is still expected to be driven by
 * {@code @Transactional}.
 *
 * <p>Note: {@code @Inherited} only applies to annotations placed on classes, so
 * with a {@code METHOD} target it has no effect here.
 *
 * @Author 姚仲杰
 * @Date 2021/1/2 21:36
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface GlobalTransaction {

}

全局事务切面

package com.xxx.mq.trx.aspect;

import com.xxx.mq.trx.config.TransactionConst;
import com.xxx.mq.trx.core.TrxContextHolder;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.apache.rocketmq.spring.core.RocketMQTemplate;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.StringUtils;

/**
 * Around-advice for methods annotated with {@code @GlobalTransaction}.
 *
 * <p>The first annotated method entered on a call chain finds no transaction id in
 * {@link TrxContextHolder}; it becomes the transaction manager, generates the id
 * and, once the business method finishes, publishes the commit/rollback decision
 * to {@link TransactionConst#TRX_TOPIC}. Methods that already see an id are
 * participants and only run the business logic.
 *
 * <p>NOTE(review): this listing carries no {@code @Aspect}/{@code @Component}
 * annotation, unlike {@code DataSourceAspect}. The class has to be registered as
 * an aspect bean somewhere for the advice to run - confirm.
 *
 * @Author 姚仲杰
 * @Date 2021/1/2 21:38
 */
public class GlobalTrxAspect {

    @Autowired
    RocketMQTemplate rocketMQTemplate;

    @Pointcut("@annotation(com.xxx.mq.trx.config.GlobalTransaction)")
    public void pointcut(){}

    /**
     * Runs the annotated method and, when this call is the transaction manager,
     * broadcasts the outcome.
     *
     * <p>The payload has two entries: {@code TransactionConst.TRX_ID -> <id>} and
     * {@code <id> -> COMMIT|ROLLBACK}.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returned
     * @throws Throwable anything the intercepted method throws (rethrown after the
     *                   rollback decision has been recorded)
     */
    @Around("pointcut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        // An id already bound to this thread means we are a participant; otherwise
        // this call starts the global transaction and owns the final decision.
        String trxId = TrxContextHolder.getTrxId();
        boolean isManager = false;
        if (StringUtils.isEmpty(trxId)) {
            // Keep the generated id in the local variable too: it is used below as
            // both a payload value and a payload key.
            trxId = UUID.randomUUID().toString();
            TrxContextHolder.setTrxId(trxId);
            isManager = true;
        }
        Map<String, Object> map = new HashMap<>(2);
        map.put(TransactionConst.TRX_ID, trxId);
        try {
            // Hand the result back so the advised method's return value is not lost.
            Object result = joinPoint.proceed();
            map.put(trxId, TransactionConst.COMMIT);
            return result;
        } catch (Throwable throwable) {
            map.put(trxId, TransactionConst.ROLLBACK);
            throw throwable;
        } finally {
            // Only the manager tells the participants whether to commit or roll back.
            if (isManager) {
                try {
                    Message<Map<String, Object>> msg = MessageBuilder.withPayload(map).build();
                    rocketMQTemplate.send(TransactionConst.TRX_TOPIC, msg);
                } finally {
                    // The manager bound the id, so the manager unbinds it; otherwise a
                    // pooled thread would look like a participant on its next request.
                    TrxContextHolder.removeTrxId();
                }
            }
        }
    }
}

事务常量定义

package com.xxx.mq.trx.config;

/**
 * @Description TODO
 * @Author 姚仲杰
 * @Date 2021/1/4 9:28
 */
public interface TransactionConst {
    int COMMIT=1;
    int ROLLBACK=0;
    String TRX_ID="trx_id";
    String TRX_TOPIC="global_trx_topic";
    String TRX_GROUP="global_trx_group";
}

package com.xxx.mq.trx.aspect;

import com.xxx.mq.trx.core.ConnectionProxy;
import com.xxx.mq.trx.core.TrxContextHolder;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Intercepts {@code DataSource.getConnection(..)} so that connections obtained
 * inside a global transaction are committed manually later on.
 *
 * <p>When a transaction id is bound to the current thread, the real connection is
 * wrapped in a {@link ConnectionProxy} and registered under that id in
 * {@link TrxContextHolder}; the caller receives the proxy. Outside a global
 * transaction the real connection is returned untouched.
 *
 * @Author 姚仲杰
 * @Date 2021/01/04 11:46
 */
@Aspect
@Component
public class DataSourceAspect {
    /** Factory for {@link ConnectionProxy} instances (the proxy is a prototype-scoped bean). */
    @Autowired
    ObjectFactory<ConnectionProxy> bean;

    /** Serialises the read-modify-write of the per-transaction connection list. */
    final ReentrantLock lock = new ReentrantLock();

    /**
     * Wraps and registers the connection when a global transaction is active.
     *
     * @param joinPoint the intercepted {@code getConnection} call
     * @return the real connection, or a {@link ConnectionProxy} around it
     * @throws Throwable anything thrown by the underlying {@code getConnection}
     */
    @Around("execution(* javax.sql.DataSource.getConnection(..))")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        Connection conn = (Connection) joinPoint.proceed();
        String trxId = TrxContextHolder.getTrxId();
        if (StringUtils.isBlank(trxId)) {
            // Not part of a global transaction: leave the connection alone.
            return conn;
        }

        ConnectionProxy connectionProxy = bean.getObject();
        connectionProxy.setConnection(conn);

        lock.lock();
        try {
            List<ConnectionProxy> list = TrxContextHolder.getConnections(trxId);
            if (list == null) {
                list = new ArrayList<>();
            }
            list.add(connectionProxy);
            TrxContextHolder.setConnections(trxId, list);
        } finally {
            lock.unlock();
        }
        return connectionProxy;
    }

}

连接代理让Transactional注解的事务提交执行个寂寞,然后转交由我们自己mq通知提交。

package com.xxx.mq.trx.core;

import java.sql.Array;
import java.sql.Blob;
import java.sql.CallableStatement;
import java.sql.Clob;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.NClob;
import java.sql.PreparedStatement;
import java.sql.SQLClientInfoException;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.SQLXML;
import java.sql.Savepoint;
import java.sql.Statement;
import java.sql.Struct;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.context.annotation.Scope;
import org.springframework.stereotype.Component;

/**
 * @Description TODO
 * @Author 姚仲杰
 * @Date 2021/1/4 10:48
 */
@Component
@Scope(value = ConfigurableBeanFactory.SCOPE_PROTOTYPE)
public class ConnectionProxy implements Connection {
    private static final Logger LOGGER = LoggerFactory.getLogger(ConnectionProxy.class);

    private Connection connection;
    //mq收到事务通知之后调用此方法执行提交或回滚
    public void notify(int state) {
        try {
            if (state == 1) {
                connection.commit();
            } else {
                connection.rollback();
            }
            connection.close();
        } catch (Exception e) {
            LOGGER.error(e.getLocalizedMessage(), e);
        }
    }

    @Override
    public void setAutoCommit(boolean autoCommit) throws SQLException {
        connection.setAutoCommit(false);
    }

    @Override
    public void commit() throws SQLException {
        // connection.commit();
    }

    @Override
    public void rollback() throws SQLException {
        // connection.rollback();
    }

    @Override
    public void close() throws SQLException {
        // connection.close();
    }

    @Override
    public boolean getAutoCommit() throws SQLException {
        return connection.getAutoCommit();
    }

    @Override
    public Statement createStatement() throws SQLException {
        return connection.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return connection.prepareStatement(sql);
    }

    @Override
    public CallableStatement prepareCall(String sql) throws SQLException {
        return connection.prepareCall(sql);
    }

    @Override
    public String nativeSQL(String sql) throws SQLException {
        return connection.nativeSQL(sql);
    }

    @Override
    public boolean isClosed() throws SQLException {
        return connection.isClosed();
    }

    @Override
    public DatabaseMetaData getMetaData() throws SQLException {
        return connection.getMetaData();
    }

    @Override
    public void setReadOnly(boolean readOnly) throws SQLException {
        connection.setReadOnly(readOnly);
    }

    @Override
    public boolean isReadOnly() throws SQLException {
        return connection.isReadOnly();
    }

    @Override
    public void setCatalog(String catalog) throws SQLException {
        connection.setCatalog(catalog);
    }

    @Override
    public String getCatalog() throws SQLException {
        return connection.getCatalog();
    }

    @Override
    public void setTransactionIsolation(int level) throws SQLException {
        connection.setTransactionIsolation(level);
    }

    @Override
    public int getTransactionIsolation() throws SQLException {
        return connection.getTransactionIsolation();
    }

    @Override
    public SQLWarning getWarnings() throws SQLException {
        return connection.getWarnings();
    }

    @Override
    public void clearWarnings() throws SQLException {
        connection.clearWarnings();
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.createStatement(resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency)
        throws SQLException {
        return connection.prepareStatement(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return connection.prepareCall(sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public Map> getTypeMap() throws SQLException {
        return connection.getTypeMap();
    }

    @Override
    public void setTypeMap(Map> map) throws SQLException {
        connection.setTypeMap(map);
    }

    @Override
    public void setHoldability(int holdability) throws SQLException {
        connection.setHoldability(holdability);
    }

    @Override
    public int getHoldability() throws SQLException {
        return connection.getHoldability();
    }

    @Override
    public Savepoint setSavepoint() throws SQLException {
        return connection.setSavepoint();
    }

    @Override
    public Savepoint setSavepoint(String name) throws SQLException {
        return connection.setSavepoint(name);
    }

    @Override
    public void rollback(Savepoint savepoint) throws SQLException {
        connection.rollback(savepoint);
    }

    @Override
    public void releaseSavepoint(Savepoint savepoint) throws SQLException {
        connection.releaseSavepoint(savepoint);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability)
        throws SQLException {
        return connection.createStatement(resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency,
        int resultSetHoldability) throws SQLException {
        return connection.prepareStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency,
        int resultSetHoldability) throws SQLException {
        return connection.prepareCall(sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        return connection.prepareStatement(sql, autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        return connection.prepareStatement(sql, columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        return connection.prepareStatement(sql, columnNames);
    }

    @Override
    public Clob createClob() throws SQLException {
        return connection.createClob();
    }

    @Override
    public Blob createBlob() throws SQLException {
        return connection.createBlob();
    }

    @Override
    public NClob createNClob() throws SQLException {
        return connection.createNClob();
    }

    @Override
    public SQLXML createSQLXML() throws SQLException {
        return connection.createSQLXML();
    }

    @Override
    public boolean isValid(int timeout) throws SQLException {
        return connection.isValid(timeout);
    }

    @Override
    public void setClientInfo(String name, String value) throws SQLClientInfoException {
        connection.setClientInfo(name, value);
    }

    @Override
    public void setClientInfo(Properties properties) throws SQLClientInfoException {
        connection.setClientInfo(properties);
    }

    @Override
    public String getClientInfo(String name) throws SQLException {
        return connection.getClientInfo(name);
    }

    @Override
    public Properties getClientInfo() throws SQLException {
        return connection.getClientInfo();
    }

    @Override
    public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
        return connection.createArrayOf(typeName, elements);
    }

    @Override
    public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
        return connection.createStruct(typeName, attributes);
    }

    @Override
    public void setSchema(String schema) throws SQLException {
        connection.setSchema(schema);
    }

    @Override
    public String getSchema() throws SQLException {
        return connection.getSchema();
    }

    @Override
    public void abort(Executor executor) throws SQLException {
        connection.abort(executor);
    }

    @Override
    public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
        connection.setNetworkTimeout(executor, milliseconds);
    }

    @Override
    public int getNetworkTimeout() throws SQLException {
        return connection.getNetworkTimeout();
    }

    @Override
    public  T unwrap(Class iface) throws SQLException {
        return connection.unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class iface) throws SQLException {
        return connection.isWrapperFor(iface);
    }

    public Connection getConnection() {
        return connection;
    }

    public void setConnection(Connection connection) {
        this.connection = connection;
    }
}

事务上下文

package com.xxx.mq.trx.core;

import java.util.HashMap;
import java.util.Map;

/**
 * Per-thread string key/value store used to carry transaction attributes.
 *
 * <p>Each thread lazily gets its own {@link HashMap}; values written by one thread
 * are not visible to another.
 *
 * @Author 姚仲杰
 * @Date 2020/12/28 11:42
 */
public class TrxContext {

    /** One map per thread, created on first access. */
    private final ThreadLocal<Map<String, String>> threadLocal = ThreadLocal.withInitial(HashMap::new);

    /**
     * Stores a value for the current thread.
     *
     * @return the value previously stored under {@code key}, or {@code null}
     */
    public String put(String key, String value) {
        return threadLocal.get().put(key, value);
    }

    /** @return the current thread's value for {@code key}, or {@code null} if absent */
    public String get(String key) {
        return threadLocal.get().get(key);
    }

    /**
     * Removes the current thread's value for {@code key}.
     *
     * @return the removed value, or {@code null} if there was none
     */
    public String remove(String key) {
        return threadLocal.get().remove(key);
    }

    /** @return the current thread's live backing map (not a copy) */
    public Map<String, String> entries() {
        return threadLocal.get();
    }
}

事务上下文持有者缓存了trxId,以及全局事务连接列表等属性。

package com.xxx.mq.trx.core;

import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * @Description 
 * @Author 姚仲杰
 * @Date 2020/12/28 11:46
 */
public class TrxContextHolder {
    private static final Logger LOGGER = LoggerFactory.getLogger(TrxContextHolder.class);

    public static final TrxContext TRX_CONTEXT_HOLDER=new TrxContext();

    private static volatile ConcurrentHashMap> connectionsMap =
        new ConcurrentHashMap<>();

    public static final String TRX_ID="TRX_ID";

    public static String getTrxId(){
        return TRX_CONTEXT_HOLDER.get(TRX_ID);
    }

    public static void setTrxId(String trxId){
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("set trx_id:[{}]", trxId);
        }
        TRX_CONTEXT_HOLDER.put(TRX_ID, trxId);

    }

    public static String removeTrxId() {
        String trxId = TRX_CONTEXT_HOLDER.remove(TRX_ID);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("remove trx_id:[{}] ", trxId);
        }
        return trxId;
    }

    public static List getConnections(String trxId){
        if (StringUtils.isEmpty(trxId)){
            LOGGER.error("trx_id can not be empty");
            throw new IllegalArgumentException();
        }
        return connectionsMap.get(trxId);
    }

    public static void setConnections(String trxId,List connections){
        if (StringUtils.isEmpty(trxId)){
            LOGGER.error("trx_id can not be empty");
            throw new IllegalArgumentException();
        }
        if (CollectionUtils.isEmpty(connections)){
            LOGGER.error("connections can not be empty,require at least one connection");
            throw new IllegalArgumentException();
        }
        connectionsMap.put(trxId,connections);
    }

    public static void removeConnections(String trxId){
        if (StringUtils.isEmpty(trxId)){
            LOGGER.error("trx_id can not be empty");
            throw new IllegalArgumentException();
        }
        connectionsMap.remove(trxId);
    }
}

二阶段提交mq监听器

package com.xxx.mq.trx.core;

import com.alibaba.fastjson.JSON;
import com.xxx.mq.trx.config.TransactionConst;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.CollectionUtils;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
import org.apache.rocketmq.common.message.MessageExt;
import org.apache.rocketmq.spring.annotation.ConsumeMode;
import org.apache.rocketmq.spring.annotation.RocketMQMessageListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @Description TODO
 * @Author 姚仲杰
 * @Date 2021/1/4 11:18
 */
@RocketMQMessageListener(consumeMode = ConsumeMode.CONCURRENTLY,topic = TransactionConst.TRX_TOPIC,consumerGroup = TransactionConst.TRX_GROUP)
public class TransactionMassageListener implements MessageListenerConcurrently {
    public static final Logger LOGGER= LoggerFactory.getLogger(TransactionMassageListener.class);
    @Override
    public ConsumeConcurrentlyStatus consumeMessage(List msgs,
        ConsumeConcurrentlyContext context) {
        LOGGER.info("receive global transaction message: {}",msgs);
        MessageExt messageExt = msgs.get(0);
        //如果本地获取不到事务等待连接直接返回消费成功,因为这是广播模式。
        try {
            String s = new String(messageExt.getBody(), "utf-8");
            Map map = JSON.parseObject(s, HashMap.class);
            String trxId= (String) map.get(TransactionConst.TRX_ID);
            int state= (int) map.get(trxId);
            List connections = TrxContextHolder.getConnections(trxId);
            if (!CollectionUtils.isEmpty(connections)){
                try {
                    connections.forEach(cp -> cp.notify(state));
                }finally {
                    TrxContextHolder.removeConnections(trxId);
                }
            }
        }catch (Throwable e){
           return  ConsumeConcurrentlyStatus.RECONSUME_LATER;
        }
        return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
    }
}

dubbo事务id传播过滤器

package com.xxx.mq.trx.integration.dubbo;

import com.xxx.mq.trx.config.TransactionConst;
import com.xxx.mq.trx.core.TrxContextHolder;
import org.apache.dubbo.common.extension.Activate;
import org.apache.dubbo.rpc.Filter;
import org.apache.dubbo.rpc.Invocation;
import org.apache.dubbo.rpc.Invoker;
import org.apache.dubbo.rpc.Result;
import org.apache.dubbo.rpc.RpcContext;
import org.apache.dubbo.rpc.RpcException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Dubbo filter that carries the transaction id across RPC calls.
 *
 * <p>If the current thread already has a transaction id, it is attached to the RPC
 * context. Otherwise, if the RPC context carries one, it is bound to the current
 * thread for the duration of the invocation and unbound afterwards.
 *
 * @Author 姚仲杰
 * @Date 2021/01/04 11:46
 */
@Activate(group = {"provider", "consumer"}, order = 100)
public class DubboTrxPropagationFilter implements Filter {

    private static final Logger LOGGER = LoggerFactory.getLogger(DubboTrxPropagationFilter.class);

    /**
     * Propagates or adopts the transaction id, then continues the filter chain.
     *
     * @param invoker    next element of the invocation chain
     * @param invocation the call being made
     * @return the result produced by {@code invoker}
     * @throws RpcException if the downstream invocation fails
     */
    @Override
    public Result invoke(Invoker invoker, Invocation invocation) throws RpcException {
        final String localTrxId = TrxContextHolder.getTrxId();
        final String remoteTrxId = RpcContext.getContext().getAttachment(TransactionConst.TRX_ID);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("trxId in TrxContext[{}] trxId in RpcContext[{}]", localTrxId, remoteTrxId);
        }

        boolean boundHere = false;
        if (localTrxId != null) {
            // Pass the local id on to the remote side.
            RpcContext.getContext().setAttachment(TransactionConst.TRX_ID, localTrxId);
        } else if (remoteTrxId != null) {
            // Adopt the id that arrived with the call.
            TrxContextHolder.setTrxId(remoteTrxId);
            boundHere = true;
        }

        if (!boundHere) {
            return invoker.invoke(invocation);
        }
        try {
            return invoker.invoke(invocation);
        } finally {
            // Unbind only what this filter bound.
            TrxContextHolder.removeTrxId();
        }
    }

}

你可能感兴趣的:(基于MQ的2PC分布式事务)