websocket校验token:使用threadlocal存放和获取当前登录用户

都知道ThreadLocal可以用于线程之间的变量隔离:在登录(鉴权)时可以把当前用户放入其中,之后在同一线程内再取出当前登录用户。下面是一个使用实例。
用户实体类:(MyBatis-Plus框架,@TableName、@TableLogic 均为 MyBatis-Plus 注解)

/**
 * User entity bound to the {@code sys_user} table through {@code @TableName}.
 *
 * <p>Getters, setters, {@code toString}, {@code equals} and {@code hashCode} are generated by
 * Lombok's {@code @Data}; {@code callSuper = false} keeps the {@code SuperEntity} state out of
 * the generated {@code equals}/{@code hashCode}.
 */
@Data
@EqualsAndHashCode(callSuper = false)
@TableName("sys_user")
public class SysUser extends SuperEntity {
	private static final long serialVersionUID = -5886012896705137070L;

	// Login name; read by callers through getUsername().
	private String username;
	private String password;
	private String nickname;
	private String headImgUrl;
	private String mobile;
	// NOTE(review): the integer encoding of sex is not visible here — confirm against the schema.
	private Integer sex;
	private Boolean enabled;
	private String type;
	private String openId;
	// Annotated with @TableLogic, so presumably the logical-delete flag.
	// NOTE(review): for a primitive boolean named isDel, Lombok generates isDel()/setDel(),
	// i.e. the bean property is "del" — confirm the column mapping behaves as intended.
	@TableLogic
	private boolean isDel;
}

threadlocal类:

/**
 * Static, per-thread holder for the currently authenticated {@link SysUser}.
 *
 * <p>The value lives in a thread-local, so every caller that invokes {@link #setUser(SysUser)}
 * is responsible for calling {@link #clear()} once the unit of work is finished; otherwise the
 * user stays attached to the thread and can be observed by whatever runs on it next.
 *
 * <p>NOTE(review): the backing store is a {@code TransmittableThreadLocal}, presumably so the
 * value can follow tasks handed to thread pools — confirm this is the intended reason.
 */
public final class LoginUserContextHolder {
    private static final ThreadLocal<SysUser> CONTEXT = new TransmittableThreadLocal<>();

    /** Static utility holder; not meant to be instantiated. */
    private LoginUserContextHolder() {
    }

    /**
     * Binds the given user to the current thread, replacing any previous value.
     *
     * @param user the user to expose through {@link #getUser()}
     */
    public static void setUser(SysUser user) {
        CONTEXT.set(user);
    }

    /**
     * Returns the user bound to the current thread.
     *
     * @return the bound user, or {@code null} when nothing has been set
     */
    public static SysUser getUser() {
        return CONTEXT.get();
    }

    /** Removes the user bound to the current thread, if any. */
    public static void clear() {
        CONTEXT.remove();
    }
}

放置用户到threadlocal中:

import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.common.exceptions.UnapprovedClientAuthenticationException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.TokenStore;

import javax.servlet.http.HttpServletRequest;
import java.nio.charset.StandardCharsets;
import java.util.*;

public class AuthUtils {
   /**
     * Validates the access token carried by an HTTP request.
     *
     * <p>The token value is obtained through {@code extractToken(request)} and then checked by
     * {@link #checkAccessToken(String)}.
     *
     * @param request the request the token is extracted from
     * @return the user resolved from the token
     */
    public static SysUser checkAccessToken(HttpServletRequest request) {
        return checkAccessToken(extractToken(request));
    }

    /**
     * Validates a raw access-token value against the {@link TokenStore} bean and, on success,
     * publishes the resulting authentication and user through {@link #setContext(Authentication)}.
     *
     * <p>An expired token is removed from the store before the exception is raised.
     *
     * @param accessTokenValue the token value to look up
     * @return the user resolved from the stored authentication
     * @throws InvalidTokenException if the token is unknown, has no value, is expired, or has no
     *         authentication associated with it
     */
    public static SysUser checkAccessToken(String accessTokenValue) {
        TokenStore store = SpringUtil.getBean(TokenStore.class);
        OAuth2AccessToken token = store.readAccessToken(accessTokenValue);

        if (token == null || token.getValue() == null) {
            throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
        }
        if (token.isExpired()) {
            // Drop the stale entry so it is not looked up again.
            store.removeAccessToken(token);
            throw new InvalidTokenException("Access token expired: " + accessTokenValue);
        }

        OAuth2Authentication authentication = store.readAuthentication(token);
        if (authentication == null) {
            throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
        }
        return setContext(authentication);
    }
    /**
     * Publishes an authentication to the per-thread contexts.
     *
     * <p>The authentication is stored in the {@link SecurityContextHolder} first; the user
     * derived from it via {@code getUser(authentication)} is then stored in
     * {@link LoginUserContextHolder}.
     *
     * @param authentication the authentication to expose
     * @return the user that was stored in {@link LoginUserContextHolder}
     */
    public static SysUser setContext(Authentication authentication) {
        SecurityContextHolder.getContext().setAuthentication(authentication);

        final SysUser currentUser = getUser(authentication);
        LoginUserContextHolder.setUser(currentUser);
        return currentUser;
    }

    /**
     * Reads the {@code clientId:clientSecret} pair from the request's {@code Authorization}
     * header.
     *
     * <p>The header must start with the {@code BASIC_} prefix (defined elsewhere in this class;
     * presumably the Basic scheme prefix — confirm).
     *
     * @param request the request whose {@code Authorization} header is inspected
     * @return the array produced by {@link #extractHeaderClient(String)}
     * @throws UnapprovedClientAuthenticationException if the header is missing or does not start
     *         with the expected prefix
     */
    public static String[] extractClient(HttpServletRequest request) {
        String authorization = request.getHeader("Authorization");
        boolean carriesClient = authorization != null && authorization.startsWith(BASIC_);
        if (!carriesClient) {
            throw new UnapprovedClientAuthenticationException("请求头中client信息为空");
        }
        return extractHeaderClient(authorization);
    }
   /**
     * Decodes the {@code clientId:clientSecret} pair from an {@code Authorization} header value.
     *
     * <p>Everything after the {@code BASIC_} prefix is Base64-decoded as UTF-8 and split on the
     * <em>first</em> colon only, so a secret that itself contains {@code ':'} (or is empty) is
     * kept intact instead of being rejected.
     *
     * @param header the full header value, including the {@code BASIC_} prefix
     * @return a two-element array: {@code [clientId, clientSecret]}
     * @throws IllegalArgumentException if the payload is not valid Base64
     * @throws RuntimeException if the decoded payload contains no colon
     */
    public static String[] extractHeaderClient(String header) {
        byte[] base64Client = header.substring(BASIC_.length()).getBytes(StandardCharsets.UTF_8);
        byte[] decoded = Base64.getDecoder().decode(base64Client);
        String clientStr = new String(decoded, StandardCharsets.UTF_8);
        // Limit of 2: only the first ':' separates id from secret; later colons belong to the
        // secret, and a trailing empty secret is preserved rather than dropped.
        String[] clientArr = clientStr.split(":", 2);
        if (clientArr.length != 2) {
            throw new RuntimeException("Invalid basic authentication token");
        }
        return clientArr;
    }

获取当前登录人:

    /**
     * Returns the username of the current user as held by {@link LoginUserContextHolder},
     * prefixed with {@code "auth2:"}.
     *
     * @return {@code "auth2:"} followed by the current user's username
     */
    @GetMapping("/test/auth2")
    public String auth() {
        SysUser currentUser = LoginUserContextHolder.getUser();
        return "auth2:" + currentUser.getUsername();
    }

websocket鉴权:

import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.websocket.server.ServerEndpointConfig;
public class WcAuthConfigurator extends ServerEndpointConfig.Configurator {
//checkOrigin:校验token
    @Override
    public boolean checkOrigin(String originHeaderValue) {
        ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        try {
            //检查token有效性
            AuthUtils.checkAccessToken(servletRequestAttributes.getRequest());
        } catch (Exception e) {
            log.error("WebSocket-auth-error", e);
            return false;
        }
        return super.checkOrigin(originHeaderValue);
    }
}

在AuthUtils.checkAccessToken方法内部,最终通过setContext调用了LoginUserContextHolder.setUser,也就是执行了ThreadLocal的set方法
使用WcAuthConfigurator :
//@ServerEndpoint:
主要是将目前的类定义成一个websocket服务器端, 注解的值将被用于监听用户连接的终端访问URL地址,客户端可以通过这个URL来连接到WebSocket服务器端,在这里配置configurator属性为刚刚写的配置类


/**
 * Demo WebSocket endpoint served at {@code /websocket/test}; the handshake is configured by
 * {@link WcAuthConfigurator}.
 */
@Slf4j
@Component
@ServerEndpoint(value = "/websocket/test", configurator = WcAuthConfigurator.class)
public class TestWebSocketController {
    /** Text sent to the client as soon as the session opens. */
    private static final String OPEN_ACK = "TestWebSocketController-ok";

    /**
     * Sends a fixed acknowledgement to the client when the session opens.
     *
     * @param session the newly opened session
     * @throws IOException if the message cannot be sent
     */
    @OnOpen
    public void onOpen(Session session) throws IOException {
        session.getBasicRemote().sendText(OPEN_ACK);
    }
}

你可能感兴趣的:(项目,websocket,java,spring)