How a Distributed Transaction Framework Works Under the Hood (2PC)

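A transaction requires that all of its operations either succeed together or fail together. Inside a single JVM this is fairly easy to achieve: for JDBC database operations, for example, Spring handles it for us. In a distributed environment, however, service A calls service B, and if an exception occurs somewhere along the way, how do we make sure the transactions in both A and B are rolled back?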

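There are three common approaches to distributed transactions:
1. 2PC (two-phase commit)
2. Eventual consistency via messaging
3. TCC (Try-Confirm-Cancel)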

This article covers the 2PC approach.
In the scenario used here, service A calls service B, and then an exception is thrown in A.
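Before the framework code, the two phases themselves can be sketched without Spring or Netty. This is a minimal illustration, not the framework's implementation: `Participant`, `TwoPhaseCoordinator` and `RecordingParticipant` are hypothetical names, and everything runs in one process. In phase one every participant does its work and votes; in phase two everyone commits only if all votes were yes.

```java
import java.util.List;

/** Hypothetical participant contract, used only for this illustration. */
interface Participant {
    boolean prepare();   // phase 1: do the work, keep it uncommitted, vote yes/no
    void commit();       // phase 2: make the work permanent
    void rollback();     // phase 2: undo the work
}

final class TwoPhaseCoordinator {
    /** Returns true if the whole group committed, false if it rolled back. */
    static boolean run(List<? extends Participant> participants) {
        boolean allYes = true;
        // Phase 1: collect votes; an exception counts as a "no" vote.
        for (Participant p : participants) {
            try {
                allYes &= p.prepare();
            } catch (RuntimeException e) {
                allYes = false;
            }
        }
        // Phase 2: the same decision is sent to every participant.
        for (Participant p : participants) {
            if (allYes) {
                p.commit();
            } else {
                p.rollback();
            }
        }
        return allYes;
    }
}

/** Test double that records which phase it ended in. */
final class RecordingParticipant implements Participant {
    private final boolean vote;
    String state = "idle";

    RecordingParticipant(boolean vote) { this.vote = vote; }

    public boolean prepare() { state = "prepared"; return vote; }
    public void commit() { state = "committed"; }
    public void rollback() { state = "rolled back"; }
}
```

In the scenario above, service A would be the participant that votes no, so B is rolled back as well even though B's own work succeeded.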

Service A:

@Service
public class DemoService {

    @Autowired
    private DemoDao demoDao;
    
    @SxmTransactional(start = true)
    public void test() {
        demoDao.insert("server1");
        HttpUtil.post("http://localhost:8082/server2/test");
        int i = 1/0;
    }
}

Service B:

@Service
public class DemoService {

    @Autowired
    private DemoDao demoDao;

    @SxmTransactional(end = true)
    public void test() {
        System.out.println("running server2 business logic");
        demoDao.insert("server2");
    }
}

The custom distributed-transaction annotation:

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Transactional
public @interface SxmTransactional {
    boolean start() default false;
    boolean end() default false;
}
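The `start` and `end` flags are plain annotation attributes, so an aspect can read them through reflection to decide whether a method opens the transaction group, closes it, or simply joins it. The sketch below shows that lookup in isolation. `GroupRole`, `RoleReader` and `Services` are hypothetical stand-ins for illustration; only the two attribute names are taken from the annotation above.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical stand-in for SxmTransactional, without the Spring meta-annotation.
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)   // RUNTIME retention is what makes reflection see it
@interface GroupRole {
    boolean start() default false;
    boolean end() default false;
}

final class RoleReader {

    /** Sample methods carrying the annotation in its three possible shapes. */
    static class Services {
        @GroupRole(start = true) public void first() {}
        @GroupRole public void middle() {}
        @GroupRole(end = true) public void last() {}
        public void plain() {}
    }

    /** Describes the role a method of Services plays in a transaction group. */
    static String describe(String methodName) {
        try {
            Method method = Services.class.getMethod(methodName);
            GroupRole role = method.getAnnotation(GroupRole.class);
            if (role == null) {
                return "not transactional";
            }
            if (role.start()) {
                return "starts group";
            }
            return role.end() ? "ends group" : "joins group";
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("unknown method: " + methodName, e);
        }
    }
}
```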

Before service A runs its method, some preparation is needed. Spring AOP is used for that here.

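// Make sure this aspect runs before ConnectionAspect; a smaller value means higher precedence
@Order(10000)
@Aspect
@Component
public class TransactionAspect {
    @Around("@annotation(com.su.annotation.SxmTransactional)")
    public Object invoke(ProceedingJoinPoint joinPoint){
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        SxmTransactional annotation = method.getAnnotation(SxmTransactional.class);
        String gid = "";
        if (annotation.start()){
            // Create the transaction group
            gid = TransactionManager.createGroup();
        }else {
            // Passed along from the calling service
            gid = TransactionManager.getCurrentGroup();
        }
        // Create the local transaction
        Transaction transaction = TransactionManager.createTransaction(gid);
        try {
            // Run the target method (this is where we need control over the real database connection)
            // Note: the target method must return quickly and must not block waiting for the
            // coordinator (that would deadlock), so the wrapped connection performs the real
            // commit or rollback on a new thread
            Object result = joinPoint.proceed();
            // Report the local transaction to the Netty server-side transaction manager: commit
            TransactionManager.commitTransaction(transaction,annotation.end(),TransactionType.COMMIT);
            return result;
        } catch (Throwable throwable) {
            // Report the local transaction to the Netty server-side transaction manager: rollback
            TransactionManager.commitTransaction(transaction,annotation.end(),TransactionType.ROLLBACK);
            throwable.printStackTrace();
            // The exception is swallowed here, as in the original demo; the caller sees null
            return null;
        }
    }
}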
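Stripped of Spring, the aspect boils down to "run the business method, then report COMMIT or ROLLBACK depending on whether it threw". The sketch below shows that shape with a hypothetical `VoteReporter` class whose `sent` list stands in for the messages that would go to the coordinator. Unlike the aspect above, it rethrows the exception after reporting, which is one possible design choice when callers should still see the failure.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

final class VoteReporter {
    enum Vote { COMMIT, ROLLBACK }

    // Stand-in for the messages that would be sent to the transaction coordinator.
    final List<Vote> sent = new ArrayList<>();

    /** Runs the body and reports a vote based on how it finished. */
    <T> T around(Supplier<T> body) {
        try {
            T result = body.get();
            sent.add(Vote.COMMIT);      // body finished normally
            return result;
        } catch (RuntimeException e) {
            sent.add(Vote.ROLLBACK);    // body failed: vote to roll the group back
            throw e;                    // the caller still sees the failure
        }
    }
}
```

Usage follows the same pattern as the aspect: `reporter.around(() -> service.test())`, where `service.test()` is whatever business call needs to join the group.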

The target method performs JDBC operations, so we need to take control of the database connection it uses. AOP is used for this as well.

@Aspect
@Component
public class ConnectionAspect {
    @Around("execution(* javax.sql.DataSource.getConnection(..))")
    public Connection getConnection(ProceedingJoinPoint joinPoint){
        try {
            // The real JDBC connection
            Connection connection = (Connection) joinPoint.proceed();
            // Fetch the current local Transaction from the ThreadLocal, which is why TransactionAspect must run first
            Transaction currentTransaction = TransactionManager.getCurrentTransaction();
            return new SxmConnection(connection,currentTransaction);
        } catch (Throwable throwable) {
            throwable.printStackTrace();
        }
        return null;
    }
}

The wrapped connection object that is returned:

public class SxmConnection implements Connection {

    // The real database connection
    private Connection connection;

    // The local custom transaction object bound to this connection
    private Transaction transaction;

    public SxmConnection(Connection connection,Transaction transaction) {
        this.connection = connection;
        this.transaction = transaction;
    }

    @Override
    public void commit() throws SQLException {
        new Thread(()->{
            // Wait for the notification from the Netty server-side transaction manager before committing
            transaction.await();
            // The Netty server may have changed this transaction's TransactionType
            try {
                // Null-safe comparison: anything other than COMMIT results in a rollback
                if (TransactionType.COMMIT.equals(transaction.getType())){
                    connection.commit();
                }else {
                    connection.rollback();
                }
                connection.close();
            }catch (Exception e){
                e.printStackTrace();
            }
        }).start();
    }

    @Override
    public void rollback() throws SQLException {
        new Thread(()->{
            // Wait for the notification from the Netty server-side transaction manager before rolling back
            transaction.await();
            try {
                connection.rollback();
                connection.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }).start();
    }

    // Note: this must not call connection.close(). If the real connection were closed here,
    // the deferred commit or rollback could no longer run.
    @Override
    public void close() throws SQLException {
    }

    @Override
    public Statement createStatement() throws SQLException {
        return connection.createStatement();
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return connection.prepareStatement(sql);
    }
    // Many more @Override methods follow
    ......
}
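Implementing every method of `java.sql.Connection` by hand is tedious. As an alternative sketch (not what the class above does), a JDK dynamic proxy can forward all calls to the real connection and intercept only `commit()`, `rollback()` and `close()`. `DeferredConnectionProxy` and `recordingStub` are hypothetical names; here the intercepted calls are merely recorded in a list, where a real wrapper would wait for the coordinator's decision instead.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class DeferredConnectionProxy {

    /** Forwards everything to target except no-arg commit/rollback/close, which are only recorded. */
    static Connection wrap(Connection target, List<String> deferred) {
        return newProxy((proxy, method, args) -> {
            String name = method.getName();
            boolean lifecycle = name.equals("commit") || name.equals("rollback") || name.equals("close");
            if (lifecycle && method.getParameterCount() == 0) {
                deferred.add(name);     // a real wrapper would defer this until the coordinator decides
                return null;
            }
            try {
                return method.invoke(target, args);
            } catch (InvocationTargetException e) {
                throw e.getCause();     // surface the original SQLException
            }
        });
    }

    /** Test double: records every call and returns a default value. */
    static Connection recordingStub(List<String> calls) {
        return newProxy((proxy, method, args) -> {
            calls.add(method.getName());
            Class<?> type = method.getReturnType();
            if (type == boolean.class) return false;
            if (type == int.class) return 0;
            return null;
        });
    }

    private static Connection newProxy(InvocationHandler handler) {
        return (Connection) Proxy.newProxyInstance(
                Connection.class.getClassLoader(), new Class<?>[] {Connection.class}, handler);
    }

    /** Returns [calls that reached the real connection, calls that were deferred]. */
    static List<List<String>> demo() {
        List<String> real = new ArrayList<>();
        List<String> deferred = new ArrayList<>();
        try {
            Connection connection = wrap(recordingStub(real), deferred);
            connection.setAutoCommit(false);
            connection.commit();
            connection.close();
        } catch (SQLException e) {
            throw new IllegalStateException(e);
        }
        return Arrays.asList(real, deferred);
    }
}
```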

The local transaction manager:

@Component("localTransactionManager")
public class TransactionManager {

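    // Holds every transaction group managed by the local transaction manager, keyed by groupId.
    // It is read by the Netty I/O thread and written by request threads, so it must be thread-safe.
    public static Map<String,Map<String,Transaction>> groupMap = new java.util.concurrent.ConcurrentHashMap<>();

    // Holds the groupId of the local transaction group
    public static ThreadLocal<String> groupThreadLocal = new ThreadLocal<>();

    // Holds the local transaction
    public static ThreadLocal<Transaction> transactionThreadLocal  =new ThreadLocal<>();

    // Records the current transaction number
    public static ThreadLocal<Integer> currentTransactionNum = new ThreadLocal<>();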

    public static NettyClient client;

    @Autowired
    public void setNettyClient(NettyClient nettyClient){
        client = nettyClient;
    }

    /**
     * Create a transaction group
     */
    public static String createGroup(){
        String gid = UUID.randomUUID().toString();
        groupMap.put(gid,new java.util.concurrent.ConcurrentHashMap<>());
        groupThreadLocal.set(gid);
        // Tell the Netty server to create the transaction group
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("groupId",gid);
        jsonObject.put("command","create");
        client.send(jsonObject);
        return gid;
    }
    /**
     * Create a local transaction
     * @param gid id of the transaction group the transaction belongs to
     */
    public static Transaction createTransaction(String gid){
        Transaction transaction = new Transaction();
        String tid = UUID.randomUUID().toString();
        transaction.setId(tid);
        transaction.setGid(gid);
        // Save it in the local transaction manager, creating the group entry if it is missing
        transactionThreadLocal.set(transaction);
        groupMap.computeIfAbsent(gid, k -> new java.util.concurrent.ConcurrentHashMap<>()).put(tid,transaction);
        // Record the local transaction number
        saveCurrentTransactionNum();
        return transaction;
    }

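    /**
     * Save the local transaction number:
     * the number received from the calling service plus 1.
     * With no upstream number the first transaction gets 1, the same value it gets
     * when the request interceptor has initialised the counter to 0.
     */
    private static void saveCurrentTransactionNum() {
        Integer current = currentTransactionNum.get();
        currentTransactionNum.set(current == null ? 1 : current + 1);
    }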

    /**
     * Get the local transaction number
     */
    public static Integer getCurrentTransactionNum(){
        return currentTransactionNum.get();
    }

    /**
     * Set the local transaction number
     */
    public static void setCurrentTransactionNum(Integer num){
        currentTransactionNum.set(num);
    }

    /**
     * Set the transaction group tracked by the current thread
     * @param gid
     */
    public static void setCurrentGroup(String gid){
        groupThreadLocal.set(gid);
    }

    /**
     * Get the transaction group tracked by the current thread
     * @return
     */
    public static String getCurrentGroup(){
        return groupThreadLocal.get();
    }


    /**
     * Get the transaction tracked by the current thread
     */
    public static Transaction getCurrentTransaction(){
        return transactionThreadLocal.get();
    }


    /**
     * Report the transaction to the Netty server
     */
    public static void commitTransaction(Transaction transaction,boolean end,TransactionType type){
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("groupId",transaction.getGid());
        jsonObject.put("transactionId",transaction.getId());
        jsonObject.put("command","add");
        jsonObject.put("end",end);
        jsonObject.put("transactionType",type);
        jsonObject.put("transactionNum",getCurrentTransactionNum());
        client.send(jsonObject);
    }
}
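The manager above only ever adds entries to `groupMap`, so finished groups stay in memory. A small thread-safe registry that also removes finished transactions could look like the sketch below. `GroupRegistry` is a hypothetical helper, not part of the framework, and it relies on `ConcurrentHashMap.compute` and `computeIfPresent` so that each registration and removal is atomic per group.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

final class GroupRegistry<T> {
    // groupId -> (transactionId -> transaction)
    private final Map<String, Map<String, T>> groups = new ConcurrentHashMap<>();

    /** Adds a transaction to its group, creating the group on first use. */
    void register(String groupId, String transactionId, T transaction) {
        groups.compute(groupId, (key, group) -> {
            Map<String, T> target = group == null ? new ConcurrentHashMap<>() : group;
            target.put(transactionId, transaction);
            return target;
        });
    }

    /** Returns the transaction, or null if the group or the transaction is unknown. */
    T find(String groupId, String transactionId) {
        Map<String, T> group = groups.get(groupId);
        return group == null ? null : group.get(transactionId);
    }

    /** Removes a finished transaction and drops the group once it is empty. */
    void remove(String groupId, String transactionId) {
        groups.computeIfPresent(groupId, (key, group) -> {
            group.remove(transactionId);
            return group.isEmpty() ? null : group;   // returning null deletes the mapping
        });
    }

    int groupCount() {
        return groups.size();
    }
}
```

A natural place to call `remove` would be right after the deferred commit or rollback has run.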

The local transaction:

public class Transaction {
    // Transaction id
    private String id;
    // Transaction group id
    private String gid;
    // Transaction type (commit or rollback)
    private TransactionType type;
    // Lock object for this transaction
    private Lock lock = new ReentrantLock();
    private Condition condition = lock.newCondition();

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getGid() {
        return gid;
    }

    public void setGid(String gid) {
        this.gid = gid;
    }

    public TransactionType getType() {
        return type;
    }

    public void setType(TransactionType type) {
        this.type = type;
    }

    public Condition getCondition() {
        return condition;
    }

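    // Set once the coordinator's notification has arrived (guarded by lock)
    private boolean signaled = false;

    // Wait: the calling thread parks on this condition until signal() has been called
    public void await(){
        lock.lock();
        try {
            // The loop covers spurious wakeups and a signal that arrived before await()
            while (!signaled) {
                condition.await();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } finally {
            lock.unlock();
        }
    }
    // Wake up the waiting thread
    public void signal(){
        lock.lock();
        try {
            signaled = true;
            condition.signalAll();
        }finally {
            lock.unlock();
        }
    }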
}
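Because the notification is a one-shot event, a `CountDownLatch` is another way to model it: a decision that arrives before the waiting thread starts to wait is not lost, and the wait can be bounded. The sketch below is an alternative illustration, not the class above; `DecisionLatch` is a hypothetical name. Falling back to ROLLBACK on timeout is an assumption made to keep the example small. A participant that has already voted to commit cannot, in general, safely decide on its own.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

final class DecisionLatch {
    enum Decision { COMMIT, ROLLBACK }

    private final CountDownLatch latch = new CountDownLatch(1);
    private final AtomicReference<Decision> decision = new AtomicReference<>();

    /** Called by the network thread when the coordinator's verdict arrives; the first verdict wins. */
    void decide(Decision verdict) {
        decision.compareAndSet(null, verdict);
        latch.countDown();
    }

    /** Waits up to timeoutMillis; a timeout or interrupt is treated as ROLLBACK (sketch assumption). */
    Decision await(long timeoutMillis) {
        try {
            if (latch.await(timeoutMillis, TimeUnit.MILLISECONDS)) {
                return decision.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();   // preserve the interrupt for the caller
        }
        return Decision.ROLLBACK;
    }
}
```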

Transaction type:

public enum TransactionType {
    COMMIT,ROLLBACK
}

The local Netty client:

@Component
public class NettyClient implements InitializingBean {

    public NettyClientHandler client = null;

    private static ExecutorService executorService = Executors.newCachedThreadPool();

    @Override
    public void afterPropertiesSet() throws Exception {
        start("localhost", 8080);
    }

    public void start(String hostName, Integer port) {
        client = new NettyClientHandler();

        Bootstrap b = new Bootstrap();
        EventLoopGroup group = new NioEventLoopGroup();
        b.group(group)
                .channel(NioSocketChannel.class)
                .option(ChannelOption.TCP_NODELAY, true)
                .handler(new ChannelInitializer<SocketChannel>() {
                    protected void initChannel(SocketChannel socketChannel) throws Exception {
                        ChannelPipeline pipeline = socketChannel.pipeline();
                        pipeline.addLast("decoder", new StringDecoder());
                        pipeline.addLast("encoder", new StringEncoder());
                        pipeline.addLast("handler", client);
                    }
                });

        try {
            b.connect(hostName, port).sync();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    public void send(JSONObject jsonObject) {
        try {
            client.call(jsonObject);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
public class NettyClientHandler extends ChannelInboundHandlerAdapter {

    private ChannelHandlerContext context;

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        context = ctx;
    }

    /**
     * Receive notifications from the server
     */
    @Override
    public synchronized void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        JSONObject jo = JSON.parseObject((String) msg);
        String groupId = jo.getString("groupId");
        String noticeCommand = jo.getString("noticeCommand");
        String transactionId = jo.getString("transactionId");
        System.out.println("client receive command:"+noticeCommand);
        // Ignore notifications for groups or transactions this process does not know about
        Map<String, Transaction> group = TransactionManager.groupMap.get(groupId);
        Transaction transaction = group == null ? null : group.get(transactionId);
        if (transaction == null) {
            return;
        }
        if ("commit".equals(noticeCommand)) {
            transaction.setType(TransactionType.COMMIT);
        }else {
            transaction.setType(TransactionType.ROLLBACK);
        }
        transaction.signal();
    }

    public synchronized Object call(JSONObject data) throws Exception {
        // Fire-and-forget write; the result of the write is not checked in this demo
        context.writeAndFlush(data.toJSONString());
        return null;
    }
}
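One thing to keep in mind with a plain `StringDecoder` pipeline is that TCP delivers a byte stream, so two JSON messages can arrive merged into one read, or one message can be split across two reads. Netty ships frame decoders such as `LineBasedFrameDecoder` for this purpose. The plain-Java sketch below only illustrates the idea behind newline framing: `LineFramer` is a hypothetical class, and it assumes the sender appends `\n` to every message and that the serialized JSON contains no raw newline.

```java
import java.util.ArrayList;
import java.util.List;

final class LineFramer {
    // Bytes (here: characters) received so far that do not yet form a complete message.
    private final StringBuilder buffer = new StringBuilder();

    /** Feeds one received chunk and returns every complete message it finishes. */
    List<String> feed(String chunk) {
        buffer.append(chunk);
        List<String> messages = new ArrayList<>();
        int newline;
        while ((newline = buffer.indexOf("\n")) >= 0) {
            messages.add(buffer.substring(0, newline));
            buffer.delete(0, newline + 1);   // drop the message and its delimiter
        }
        return messages;
    }
}
```

Each string returned by `feed` is then safe to hand to a JSON parser, whereas a raw chunk may not be.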

HttpUtil:

@Component
public class HttpUtil {

    private static RestTemplate restTemplate = new RestTemplate();

    public  static Object post(String url){
        HttpHeaders header = new HttpHeaders();
        header.set("groupId", TransactionManager.getCurrentGroup());
        header.set("transactionNum",String.valueOf(TransactionManager.getCurrentTransactionNum()));
        HttpEntity<MultiValueMap<String, String>> httpEntity = new HttpEntity<>(null, header);
        return  restTemplate.postForObject(url,httpEntity,Object.class);
    }
}

The request interceptor:

@Configuration
public class WebAppConfig extends WebMvcConfigurerAdapter {

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new RequestInterceptor());
    }
}
public class RequestInterceptor implements HandlerInterceptor {

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // Read the values passed along by the calling service
        String groupId = request.getHeader("groupId");
        String transactionNum = request.getHeader("transactionNum");
        TransactionManager.setCurrentGroup(groupId);
        TransactionManager.setCurrentTransactionNum(Integer.valueOf(transactionNum==null? "0":transactionNum));
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {

    }
}
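Taken together, `HttpUtil` and the interceptor propagate two values along the call chain: the group id, and a running number that each service increments by one when it creates its own transaction. The sketch below replays that bookkeeping with plain maps instead of HTTP. `TxHeaders` is a hypothetical helper; only the two header names come from the code above.

```java
import java.util.HashMap;
import java.util.Map;

final class TxHeaders {
    static final String GROUP = "groupId";
    static final String NUM = "transactionNum";

    /** Caller side: the headers attached to the outgoing request. */
    static Map<String, String> outgoing(String groupId, int callerNum) {
        Map<String, String> headers = new HashMap<>();
        headers.put(GROUP, groupId);
        headers.put(NUM, String.valueOf(callerNum));
        return headers;
    }

    /** Callee side: the number this service's own transaction receives. */
    static int nextNum(Map<String, String> incomingHeaders) {
        String raw = incomingHeaders.get(NUM);
        int upstream = raw == null ? 0 : Integer.parseInt(raw);   // no header: start of the chain
        return upstream + 1;
    }
}
```

For a chain A → B the numbers come out as 1 and 2, so the service that ends the group reports 2 as the number of transactions the coordinator should expect.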

The server-side transaction manager:

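/**
 * As the distributed transaction manager, it needs to:
 * 1. Create and store transaction groups
 * 2. Store each sub-transaction in its transaction group
 * 3. Collect the status of every sub-transaction in a group and work out the group's outcome (commit or rollback)
 * 4. Notify every sub-transaction to commit or roll back
 */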
public class NettyServerHandler extends ChannelInboundHandlerAdapter {

    //private static ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

    // All channels belonging to each transaction group
    private static Map<String,Map<String,Channel>> channelGroupMap = new ConcurrentHashMap<>();
    // All transactions in each transaction group
    private static Map<String,List<JSONObject>> groupTransactions = new ConcurrentHashMap<>();
    // The status of every transaction in each transaction group
    private static Map<String,List<String>> groupStatus = new ConcurrentHashMap<>();
    // Whether each transaction group has received its "end" marker
    private static Map<String,Boolean> endGroupMap = new ConcurrentHashMap<>();
    // The number of transactions each transaction group is expected to contain
    private static Map<String,Integer> countTransGroupMap = new ConcurrentHashMap<>();

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        System.out.println("received data:" + msg.toString());
        JSONObject jsonObject = JSON.parseObject((String) msg);
        String groupId = jsonObject.getString("groupId");
        String command = jsonObject.getString("command");
        System.out.println("groupId:"+groupId);
        System.out.println("command:"+command);
        if ("create".equals(command)){
            createGroup(groupId);
        }else if ("add".equals(command)){
            boolean end = jsonObject.getBoolean("end");
            String transactionId = jsonObject.getString("transactionId");
            String transactionType = jsonObject.getString("transactionType");
            Integer transactionNum = jsonObject.getInteger("transactionNum");
            // Remember the channel for this transaction of the group
            addChannelMap(groupId,transactionId,ctx.channel());
            // Remember the transaction and its status in the group
            addGroupTransactions(groupId,jsonObject);
            addGroupStatus(groupId,transactionType);
            if (end){
                System.out.println("------end marker received-----");
                endGroupMap.put(groupId,Boolean.TRUE);
                countTransGroupMap.put(groupId,transactionNum);
            }
            // Once the group has received "end" and the number of transactions received equals
            // the expected number, work out the group's outcome. The checks are null-safe because
            // a non-end transaction may arrive before the end marker.
            Integer expected = countTransGroupMap.get(groupId);
            if (Boolean.TRUE.equals(endGroupMap.get(groupId)) && expected != null
                    && expected == groupTransactions.get(groupId).size()){
                // Work out the group's outcome (commit or rollback): one ROLLBACK vote rolls back everyone
                String noticeCommand = "";
                if (groupStatus.get(groupId).contains("ROLLBACK")){
                    noticeCommand = "rollback";
                }else {
                    noticeCommand = "commit";
                }
                sendResult(groupId,noticeCommand);
            }
        }
    }

    private void createGroup(String groupId) {
        groupTransactions.put(groupId,new LinkedList<>());
        groupStatus.put(groupId,new LinkedList<>());
    }

    private void addGroupTransactions(String groupId, JSONObject jsonObject) {
        if (groupTransactions.get(groupId)==null){
            groupTransactions.put(groupId,new LinkedList<>());
        }
        groupTransactions.get(groupId).add(jsonObject);
    }

    private void addGroupStatus(String groupId, String transactionType) {
        if (groupStatus.get(groupId)==null){
            groupStatus.put(groupId,new LinkedList<>());
        }
        groupStatus.get(groupId).add(transactionType);
    }

    private void addChannelMap(String groupId,String transactionId, Channel channel) {
        // computeIfAbsent creates the per-group map atomically; the inner map is thread-safe too
        channelGroupMap.computeIfAbsent(groupId, k -> new ConcurrentHashMap<>()).put(transactionId,channel);
    }

    /**
     * Notify the local transactions.
     * Failure cases such as a failed send are not handled here.
     */
    private void sendResult(String groupId, String noticeCommand) {
        Map<String, Channel> channels = channelGroupMap.get(groupId);
        for (Map.Entry<String, Channel> entry : channels.entrySet()) {
            JSONObject jo = new JSONObject();
            jo.put("groupId",groupId);
            jo.put("noticeCommand",noticeCommand);
            jo.put("transactionId",entry.getKey());
            ChannelFuture channelFuture = entry.getValue().writeAndFlush(jo.toJSONString());
            System.out.println(channelFuture);
        }
        // Release resources
        channelGroupMap.remove(groupId);
        groupTransactions.remove(groupId);
        groupStatus.remove(groupId);
        endGroupMap.remove(groupId);
        countTransGroupMap.remove(groupId);
    }
}
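The decision rule inside `channelRead` can be isolated from Netty and tested on its own: votes are collected, the expected count becomes known only when the participant marked `end` reports, and a verdict is produced once all expected votes are in. The sketch below is a simplified model under those assumptions; `GroupState` is a hypothetical class that tracks a single group.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

final class GroupState {
    private final List<String> votes = new ArrayList<>();
    private Integer expected;   // unknown until the "end" participant has reported

    /** Records one vote; returns the verdict once every expected vote has arrived. */
    synchronized Optional<String> add(String vote, boolean end, int transactionNum) {
        votes.add(vote);
        if (end) {
            expected = transactionNum;   // the last service's number equals the group size
        }
        if (expected == null || votes.size() < expected) {
            return Optional.empty();     // still waiting for more participants
        }
        // A single ROLLBACK vote is enough to roll back the whole group.
        return Optional.of(votes.contains("ROLLBACK") ? "rollback" : "commit");
    }
}
```

In the A → B scenario, B (marked `end`, number 2) usually reports first and yields no verdict; A's ROLLBACK vote then completes the group and the verdict is `rollback`.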

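That completes the distributed transaction framework: if either service A or service B throws an exception, both databases roll back. It is not suitable for production, though. Many parts still need work, and the goal here is only to illustrate how the mechanism operates.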
