Spring Security Oauth2 存储Token的方式有多种, 比如JWT、Jdbc(数据库)、Redis等,但是对于一个大型的分布式服务应用,Redis存储方式应该是最佳选择。
如果使用默认的Redis存储方式,Token相关的数据会采用JDK序列化策略写入Redis。这对程序的功能毫无影响,但对开发者来说很不直观,出现问题时也不容易排查。我们能不能把它们序列化成JSON格式呢?
Spring Security Oauth2 Redis序列化Token相关的数据是采用JdkSerializationStrategy,用这个序列化策略序列化出的结果正如上图所示那样,具体的代码如下:
//org.springframework.security.oauth2.provider.token.store.redis.RedisTokenStore
public class RedisTokenStore implements TokenStore {
// String constants, each ending in ':'. Their use is elided from this excerpt;
// presumably they are prefixes for the Redis keys written by this store (verify against the full class).
private static final String ACCESS = "access:";
private static final String AUTH_TO_ACCESS = "auth_to_access:";
private static final String AUTH = "auth:";
private static final String REFRESH_AUTH = "refresh_auth:";
private static final String ACCESS_TO_REFRESH = "access_to_refresh:";
private static final String REFRESH = "refresh:";
private static final String REFRESH_TO_ACCESS = "refresh_to_access:";
private static final String CLIENT_ID_TO_ACCESS = "client_id_to_access:";
private static final String UNAME_TO_ACCESS = "uname_to_access:";
// True when RedisStandaloneConfiguration can be found on this class's class loader.
private static final boolean springDataRedis_2_0 = ClassUtils.isPresent(
"org.springframework.data.redis.connection.RedisStandaloneConfiguration",
RedisTokenStore.class.getClassLoader());
private final RedisConnectionFactory connectionFactory;
private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();
// Default serialization strategy: JDK serialization.
private RedisTokenStoreSerializationStrategy serializationStrategy = new JdkSerializationStrategy();
......省略无关代码
笔者首先想到的是自定义一个RedisTokenStoreSerializationStrategy接口的实现,把序列化方式改成Object和JSON字符串之间的序列化、反序列化应该就好了。于是笔者使用Fastjson实现了一个序列化策略,并注入到Spring Bean容器中,如下:
/**
 * {@link RedisTokenStoreSerializationStrategy} that stores token data as JSON via Fastjson
 * instead of JDK serialization.
 *
 * <p>Objects are written with {@code SerializerFeature.WriteClassName} and read back through a
 * dedicated {@link ParserConfig} with autoType enabled. Plain strings are stored as UTF-8 bytes.
 */
public class FastjsonRedisTokenStoreSerializationStrategy implements RedisTokenStoreSerializationStrategy {

    /** Parser configuration passed to every {@code JSON.parseObject} call made by this class. */
    private static final ParserConfig defaultRedisConfig = new ParserConfig();

    static {
        defaultRedisConfig.setAutoTypeSupport(true);
        // Lets Fastjson convert JSON back into Java objects automatically.
        // NOTE(review): this switches autoType on for Fastjson's global ParserConfig, i.e. for every
        // Fastjson caller in the process, although reads in this class use defaultRedisConfig.
        // Enabling autoType globally is a known deserialization risk; confirm it is still required.
        ParserConfig.getGlobalInstance().setAutoTypeSupport(true);
    }

    /**
     * Parses UTF-8 encoded JSON into an instance of {@code clazz}.
     *
     * @param bytes JSON bytes; may be {@code null} or empty
     * @param clazz target type; must not be {@code null}
     * @return the parsed object, or {@code null} when {@code bytes} is {@code null} or empty
     * @throws SerializationException if Fastjson fails to parse the payload
     */
    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        Preconditions.checkArgument(clazz != null,
                "clazz can't be null");
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        try {
            return JSON.parseObject(new String(bytes, IOUtils.UTF8), clazz, defaultRedisConfig);
        } catch (Exception ex) {
            // The original message said "serialize" here, which pointed at the wrong direction.
            throw new SerializationException("Could not deserialize: " + ex.getMessage(), ex);
        }
    }

    /**
     * Decodes UTF-8 bytes into a string.
     *
     * @return the decoded string, or {@code null} when {@code bytes} is {@code null} or empty
     */
    @Override
    public String deserializeString(byte[] bytes) {
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        return new String(bytes, IOUtils.UTF8);
    }

    /**
     * Serializes an object to JSON bytes, including class-name metadata.
     *
     * @return the JSON bytes, or an empty array when {@code object} is {@code null}
     * @throws SerializationException if Fastjson fails to write the object
     */
    @Override
    public byte[] serialize(Object object) {
        if (object == null) {
            return new byte[0];
        }
        try {
            return JSON.toJSONBytes(object, SerializerFeature.WriteClassName,
                    SerializerFeature.DisableCircularReferenceDetect);
        } catch (Exception ex) {
            throw new SerializationException("Could not serialize: " + ex.getMessage(), ex);
        }
    }

    /**
     * Encodes a string as UTF-8 bytes.
     *
     * @return the encoded bytes, or an empty array when {@code data} is {@code null} or empty
     */
    @Override
    public byte[] serialize(String data) {
        if (data == null || data.isEmpty()) {
            return new byte[0];
        }
        // Same charset constant as the read path, so both directions cannot drift apart.
        return data.getBytes(IOUtils.UTF8);
    }
}
/**
 * Exposes a {@link RedisTokenStore} built on {@code redisConnectionFactory} that uses
 * {@link FastjsonRedisTokenStoreSerializationStrategy} as its serialization strategy.
 */
@Bean
public RedisTokenStore redisTokenStore() {
    final RedisTokenStore tokenStore = new RedisTokenStore(redisConnectionFactory);
    final RedisTokenStoreSerializationStrategy jsonStrategy = new FastjsonRedisTokenStoreSerializationStrategy();
    tokenStore.setSerializationStrategy(jsonStrategy);
    return tokenStore;
}
笔者刚开始以为这样处理就可以了,但是OAuth2Authentication这个类并没有默认构造方法,会导致Fastjson反序列化失败,因此针对OAuth2中这类无法直接序列化或者反序列化的类需要特殊处理。针对DefaultOAuth2RefreshToken、OAuth2Authentication等特殊的类,我们需要定制序列化以及反序列化策略。下面是笔者项目的配置以及序列化和反序列化策略,其中有些类是为满足项目需求定制的,不在Spring Security包中。这里只提供参考思路,不可照抄。配置和自定义策略如下:
public class FastjsonRedisTokenStoreSerializationStrategy implements RedisTokenStoreSerializationStrategy {
// Parser configuration populated by init(). It is declared before the static block below,
// so it is already assigned when init() runs during class initialization.
private static ParserConfig config = new ParserConfig();
// Applies the custom Fastjson registrations once, when this class is initialized.
static {
init();
}
/**
 * Configures {@code config} for reading OAuth2 types: enables autoType, installs the custom
 * deserializers, whitelists the OAuth2 provider packages and registers explicit type mappings.
 */
protected static void init() {
    final String providerPackage = "org.springframework.security.oauth2.provider.";

    config.setAutoTypeSupport(true);

    // Custom OAuth2 deserializers. Per the original author's note, DefaultOAuth2RefreshToken has
    // no setValue method, so the default JSON handling would produce null for it.
    config.putDeserializer(DefaultOAuth2RefreshToken.class, new DefaultOauth2RefreshTokenSerializer());
    config.putDeserializer(OAuth2Authentication.class, new OAuth2AuthenticationSerializer());

    // Accepted class-name prefixes.
    config.addAccept(providerPackage);
    config.addAccept(providerPackage + "client");

    // Explicit type-name to class mappings.
    TypeUtils.addMapping(providerPackage + "OAuth2Authentication", OAuth2Authentication.class);
    TypeUtils.addMapping(providerPackage + "client.BaseClientDetails", BaseClientDetails.class);
}
......
}
/**
 * Fastjson {@link ObjectDeserializer} that rebuilds an {@link OAuth2Authentication} from its JSON
 * form by reading the {@code oAuth2Request} and {@code userAuthentication} properties and calling
 * the two-argument constructor.
 *
 * <p>Any failure while reading is printed to standard error and reported as {@code null}.
 */
public class OAuth2AuthenticationSerializer implements ObjectDeserializer {

    /**
     * @return the rebuilt authentication; {@code null} when {@code type} is not
     *         {@code OAuth2Authentication.class}, when nothing was parsed, or when reading failed
     */
    @Override
    public <T> T deserialze(DefaultJSONParser parser, Type type, Object fieldName) {
        if (type != OAuth2Authentication.class) {
            return null;
        }
        try {
            Object parsed = parse(parser);
            if (parsed == null) {
                return null;
            }
            if (parsed instanceof OAuth2Authentication) {
                return (T) parsed;
            }
            JSONObject root = (JSONObject) parsed;
            OAuth2Request storedRequest = parseOAuth2Request(root);
            UsernamePasswordAuthenticationToken userAuthentication =
                    root.getObject("userAuthentication", UsernamePasswordAuthenticationToken.class);
            return (T) new OAuth2Authentication(storedRequest, userAuthentication);
        } catch (Exception e) {
            // Failures are reported on stderr only; the caller receives null.
            e.printStackTrace();
            return null;
        }
    }

    /** Builds an {@link OAuth2Request} from the {@code oAuth2Request} property of {@code jsonObject}. */
    private OAuth2Request parseOAuth2Request(JSONObject jsonObject) {
        JSONObject requestJson = jsonObject.getObject("oAuth2Request", JSONObject.class);

        Map<String, String> requestParameters = requestJson.getObject("requestParameters", Map.class);
        String clientId = requestJson.getString("clientId");
        String grantType = requestJson.getString("grantType");
        String redirectUri = requestJson.getString("redirectUri");
        Boolean approved = requestJson.getBoolean("approved");
        Set<String> responseTypes = readStringSet(requestJson, "responseTypes");
        Set<String> scope = readStringSet(requestJson, "scope");
        Set<String> authorityNames = readStringSet(requestJson, "authorities");
        Set<String> resourceIds = readStringSet(requestJson, "resourceIds");
        Map<String, Serializable> extensions = requestJson
                .getObject("extensions", new TypeReference<HashMap<String, Serializable>>() {
                });

        // Each authority name becomes a SimpleGrantedAuthority.
        Set<GrantedAuthority> grantedAuthorities = new HashSet<>(0);
        if (authorityNames != null) {
            for (String authorityName : authorityNames) {
                grantedAuthorities.add(new SimpleGrantedAuthority(authorityName));
            }
        }

        OAuth2Request rebuilt = new OAuth2Request(requestParameters, clientId,
                grantedAuthorities, approved, scope, resourceIds, redirectUri, responseTypes, extensions);
        // NOTE(review): the value returned by refresh(...) is discarded; confirm against the
        // OAuth2Request documentation whether the returned instance was meant to be used.
        rebuilt.refresh(new TokenRequest(requestParameters, clientId, scope, grantType));
        return rebuilt;
    }

    /** Reads the property {@code key} of {@code json} as a {@code HashSet<String>}. */
    private Set<String> readStringSet(JSONObject json, String key) {
        return json.getObject(key, new TypeReference<HashSet<String>>() {
        });
    }

    @Override
    public int getFastMatchToken() {
        return 0;
    }

    /**
     * Parses the current object from the parser.
     *
     * @return the parser's result when it is {@code null}, a {@link JSONObject} or an
     *         {@link OAuth2Authentication}; otherwise the result wrapped in a new {@link JSONObject}
     */
    private Object parse(DefaultJSONParser parse) {
        boolean ordered = parse.lexer.isEnabled(Feature.OrderedField);
        Object result = parse.parseObject((Map) new JSONObject(ordered));
        if (result == null || result instanceof JSONObject || result instanceof OAuth2Authentication) {
            return result;
        }
        return new JSONObject((Map) result);
    }
}
/**
 * Fastjson {@link ObjectDeserializer} that rebuilds a {@link DefaultOAuth2RefreshToken} by
 * passing the JSON {@code value} property to its constructor.
 */
public class DefaultOauth2RefreshTokenSerializer implements ObjectDeserializer {

    /**
     * @return a new refresh token built from the {@code value} property, or {@code null} when
     *         {@code type} is not {@code DefaultOAuth2RefreshToken.class}
     */
    @Override
    public <T> T deserialze(DefaultJSONParser parser, Type type, Object fieldName) {
        if (type != DefaultOAuth2RefreshToken.class) {
            return null;
        }
        String tokenValue = parser.parseObject().getString("value");
        return (T) new DefaultOAuth2RefreshToken(tokenValue);
    }

    @Override
    public int getFastMatchToken() {
        return 0;
    }
}
能用Fastjson,当然也可以使用Jackson,而且Spring Security也默认加了一些Jackson的相关注解。只不过Jackson与Fastjson虽然功能类似,却是实现完全不同的序列化工具。配置、序列化策略以及序列化结果如下:
/**
 * {@link RedisTokenStoreSerializationStrategy} that stores token data as JSON via Jackson.
 *
 * <p>A custom deserializer for {@link OAuth2Authentication} is registered on the shared
 * {@link ObjectMapper}. Plain strings are stored as UTF-8 bytes.
 */
public class JacksonRedisTokenStoreSerializationStrategy implements RedisTokenStoreSerializationStrategy {

    /** Mapper shared by all instances. */
    protected static ObjectMapper mapper = new ObjectMapper();

    // Static initializer, not an instance initializer: the mapper is static, so the module must be
    // registered exactly once. The original block ran again on every construction and left the
    // mapper unconfigured until the first instance was created.
    static {
        SimpleModule module = new SimpleModule();
        module.addDeserializer(OAuth2Authentication.class,
                new OAuth2AuthenticationJackson2Deserializer(OAuth2Authentication.class));
        mapper.registerModule(module);
    }

    /**
     * Parses JSON bytes into an instance of {@code clazz}.
     *
     * @param bytes JSON bytes; may be {@code null} or empty
     * @param clazz target type; must not be {@code null}
     * @return the parsed object, or {@code null} when {@code bytes} is {@code null} or empty
     * @throws SerializationException if Jackson fails to read the payload
     */
    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        Preconditions.checkArgument(clazz != null,
                "clazz can't be null");
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        try {
            // Jackson reads the bytes directly; no intermediate String copy is needed.
            return mapper.readValue(bytes, clazz);
        } catch (Exception ex) {
            // The original message said "serialize" here, which pointed at the wrong direction.
            throw new SerializationException("Could not deserialize: " + ex.getMessage(), ex);
        }
    }

    /**
     * Decodes UTF-8 bytes into a string.
     *
     * @return the decoded string, or {@code null} when {@code bytes} is {@code null} or empty
     */
    @Override
    public String deserializeString(byte[] bytes) {
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        return new String(bytes, IOUtils.UTF8);
    }

    /**
     * Serializes an object to JSON bytes.
     *
     * @return the JSON bytes, or an empty array when {@code object} is {@code null}
     * @throws SerializationException if Jackson fails to write the object
     */
    @Override
    public byte[] serialize(Object object) {
        if (object == null) {
            return new byte[0];
        }
        try {
            return mapper.writeValueAsBytes(object);
        } catch (Exception ex) {
            throw new SerializationException("Could not serialize: " + ex.getMessage(), ex);
        }
    }

    /**
     * Encodes a string as UTF-8 bytes.
     *
     * @return the encoded bytes, or an empty array when {@code data} is {@code null} or empty
     */
    @Override
    public byte[] serialize(String data) {
        if (data == null || data.isEmpty()) {
            return new byte[0];
        }
        // Same charset constant as deserializeString, so both directions cannot drift apart.
        return data.getBytes(IOUtils.UTF8);
    }
}
/**
 * Jackson deserializer that rebuilds an {@link OAuth2Authentication} from a JSON tree with the
 * properties {@code oauth2Request} and {@code userAuthentication}.
 *
 * <p>A JSON {@code null} or missing property is treated as absent: an absent
 * {@code userAuthentication} yields a {@code null} user authentication instead of a token with a
 * {@code null} principal, and an absent {@code oauth2Request} is reported as a mapping error.
 */
public class OAuth2AuthenticationJackson2Deserializer extends StdDeserializer<OAuth2Authentication> {

    private static final TypeReference<Set<String>> STRING_SET = new TypeReference<Set<String>>() {
    };
    private static final TypeReference<HashMap<String, String>> STRING_MAP =
            new TypeReference<HashMap<String, String>>() {
            };
    private static final TypeReference<Set<Map<String, String>>> AUTHORITY_ENTRIES =
            new TypeReference<Set<Map<String, String>>>() {
            };

    protected OAuth2AuthenticationJackson2Deserializer(Class<?> vc) {
        super(vc);
    }

    /** Returns true for Java {@code null}, a JSON {@code null} node and a missing node. */
    private static boolean isAbsent(JsonNode node) {
        return node == null || node.isNull() || node.isMissingNode();
    }

    /** Reads {@code jsonNode} as a string; {@code null} when the node is absent. */
    private static String readString(ObjectMapper mapper, JsonNode jsonNode) throws IOException {
        return readValue(mapper, jsonNode, String.class);
    }

    /**
     * Converts {@code jsonNode} to {@code clazz}.
     *
     * @return the converted value, or {@code null} when any argument is {@code null} or the node is absent
     * @throws IOException if Jackson cannot convert the node
     */
    private static <T> T readValue(ObjectMapper mapper, JsonNode jsonNode, Class<T> clazz) throws IOException {
        if (mapper == null || isAbsent(jsonNode) || clazz == null) {
            return null;
        }
        // traverse(mapper) attaches the codec, so nested deserializers that call getCodec() still work.
        return mapper.readValue(jsonNode.traverse(mapper), clazz);
    }

    /**
     * Converts {@code jsonNode} to the generic type described by {@code type}.
     *
     * @return the converted value, or {@code null} when any argument is {@code null} or the node is absent
     * @throws IOException if Jackson cannot convert the node
     */
    private static <T> T readValue(ObjectMapper mapper, JsonNode jsonNode, TypeReference<T> type)
            throws IOException {
        if (mapper == null || isAbsent(jsonNode) || type == null) {
            return null;
        }
        return mapper.readValue(jsonNode.traverse(mapper), type);
    }

    /**
     * Reads the whole JSON object and rebuilds the authentication.
     *
     * @throws IOException if the JSON cannot be read or the {@code oauth2Request} property is absent
     */
    @Override
    public OAuth2Authentication deserialize(JsonParser jp, DeserializationContext ctxt)
            throws IOException {
        ObjectMapper mapper = (ObjectMapper) jp.getCodec();
        JsonNode root = mapper.readTree(jp);

        JsonNode requestNode = readJsonNode(root, "oauth2Request");
        if (isAbsent(requestNode)) {
            // Fail with a clear mapping error instead of a NullPointerException further down.
            throw com.fasterxml.jackson.databind.JsonMappingException.from(jp,
                    "Cannot deserialize OAuth2Authentication: property 'oauth2Request' is missing or null");
        }

        Authentication userAuthentication =
                parseAuthentication(mapper, readJsonNode(root, "userAuthentication"));
        OAuth2Request request = parseOAuth2Request(mapper, requestNode);
        return new OAuth2Authentication(request, userAuthentication);
    }

    /**
     * Rebuilds the user authentication as a {@link UsernamePasswordAuthenticationToken}.
     *
     * @return the token, or {@code null} when the node is absent
     */
    private Authentication parseAuthentication(ObjectMapper mapper, JsonNode json) throws IOException {
        if (mapper == null || isAbsent(json)) {
            return null;
        }
        Oauth2User principal = parseOAuth2User(mapper, json.get("principal"));
        Object credentials = readValue(mapper, json.get("credentials"), Object.class);
        Set<SimpleGrantedAuthority> grantedAuthorities =
                parseSimpleGrantedAuthorities(mapper, json.get("authorities"));
        return new UsernamePasswordAuthenticationToken(principal, credentials, grantedAuthorities);
    }

    /**
     * Rebuilds the project's {@code Oauth2User} principal through its builder.
     *
     * @return the user, or {@code null} when the node is absent
     */
    private Oauth2User parseOAuth2User(ObjectMapper mapper, JsonNode json) throws IOException {
        if (mapper == null || isAbsent(json)) {
            return null;
        }
        String username = readString(mapper, json.get("username"));
        String password = readString(mapper, json.get("password"));
        String nickname = readString(mapper, json.get("nickname"));
        String phone = readString(mapper, json.get("phone"));
        String email = readString(mapper, json.get("email"));
        String avatar = readString(mapper, json.get("avatar"));
        List<Map<String, Object>> roleList = readValue(mapper, json.get("roleList"),
                new TypeReference<List<Map<String, Object>>>() {
                });
        List<Map<String, Object>> permissionList = readValue(mapper, json.get("permissionList"),
                new TypeReference<List<Map<String, Object>>>() {
                });
        Boolean accountNonExpired = readValue(mapper, json.get("accountNonExpired"), Boolean.class);
        Boolean accountNonLocked = readValue(mapper, json.get("accountNonLocked"), Boolean.class);
        Boolean credentialsNonExpired = readValue(mapper, json.get("credentialsNonExpired"), Boolean.class);
        Boolean enabled = readValue(mapper, json.get("enabled"), Boolean.class);
        Set<SimpleGrantedAuthority> grantedAuthorities =
                parseSimpleGrantedAuthorities(mapper, json.get("authorities"));
        return Oauth2User.builder()
                .username(username)
                .password(password)
                .nickname(nickname)
                .phone(phone)
                .email(email)
                .avatar(avatar)
                .accountNonExpired(accountNonExpired)
                .accountNonLocked(accountNonLocked)
                .credentialsNonExpired(credentialsNonExpired)
                .enabled(enabled)
                .authorities(grantedAuthorities)
                .roleList(roleList)
                .permissionList(permissionList).build();
    }

    /**
     * Rebuilds the stored {@link OAuth2Request}.
     *
     * @return the request, or {@code null} when the node is absent
     */
    private OAuth2Request parseOAuth2Request(ObjectMapper mapper, JsonNode json) throws IOException {
        if (mapper == null || isAbsent(json)) {
            return null;
        }
        Map<String, String> requestParameters = readValue(mapper, json.get("requestParameters"), STRING_MAP);
        String clientId = readString(mapper, json.get("clientId"));
        String grantType = readString(mapper, json.get("grantType"));
        String redirectUri = readString(mapper, json.get("redirectUri"));
        // A missing "approved" flag is read as false rather than failing on unboxing.
        boolean approved = Boolean.TRUE.equals(readValue(mapper, json.get("approved"), Boolean.class));
        Set<String> responseTypes = readValue(mapper, json.get("responseTypes"), STRING_SET);
        Set<String> scope = readValue(mapper, json.get("scope"), STRING_SET);
        Set<String> resourceIds = readValue(mapper, json.get("resourceIds"), STRING_SET);
        Map<String, Serializable> extensions = readValue(mapper, json.get("extensions"),
                new TypeReference<Map<String, Serializable>>() {
                });
        Set<SimpleGrantedAuthority> grantedAuthorities =
                parseSimpleGrantedAuthorities(mapper, json.get("authorities"));

        OAuth2Request request = new OAuth2Request(requestParameters, clientId,
                grantedAuthorities, approved, scope, resourceIds, redirectUri, responseTypes, extensions);
        TokenRequest tokenRequest = new TokenRequest(requestParameters, clientId, scope, grantType);
        // NOTE(review): the value returned by refresh(...) is discarded, as in the original code;
        // confirm against the OAuth2Request documentation whether the returned instance was meant
        // to be used.
        request.refresh(tokenRequest);
        return request;
    }

    /**
     * Converts a JSON array of objects carrying an {@code authority} property into granted
     * authorities. Null or empty entries are skipped.
     *
     * @return the authorities; empty when the node is absent
     */
    private Set<SimpleGrantedAuthority> parseSimpleGrantedAuthorities(ObjectMapper mapper, JsonNode json)
            throws IOException {
        Set<Map<String, String>> entries = readValue(mapper, json, AUTHORITY_ENTRIES);
        Set<SimpleGrantedAuthority> grantedAuthorities = new HashSet<>();
        if (entries == null) {
            return grantedAuthorities;
        }
        for (Map<String, String> entry : entries) {
            if (entry != null && !entry.isEmpty()) {
                grantedAuthorities.add(new SimpleGrantedAuthority(entry.get("authority")));
            }
        }
        return grantedAuthorities;
    }

    /** Returns the child node {@code field}, or a missing node when it does not exist. */
    private JsonNode readJsonNode(JsonNode jsonNode, String field) {
        return jsonNode.path(field);
    }
}