Shiro整合JWT

5/10/2022 安全框架

# Shiro整合Jwt的实现方案

# 1、引入Jwt依赖🚨

```xml
<!-- jwt -->
<dependency>
  <groupId>com.auth0</groupId>
  <artifactId>java-jwt</artifactId>
  <version>3.4.1</version>
</dependency>
```

# 2、添加配置 ShiroConfig 🕸️

@Configuration
public class ShiroConfig 
{
    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    @Value("${spring.redis.host}")
    private String redisHost;

    @Value("${spring.redis.port}")
    private Integer redisPort;

    private static final Integer expireAt = 1800;

    private static final Integer timeout = 3000;

    @Value("${spring.redis.password}")
    private String redisPassword;

    @Bean
    public ShiroFilterFactoryBean shirFilter(SecurityManager securityManager) {
        ShiroFilterFactoryBean shiroFilterFactoryBean = new ShiroFilterFactoryBean();
        // 设置登录URL以及无权限URL 
        String prefix = "/system";
        shiroFilterFactoryBean.setLoginUrl(prefix + "/Login");
        shiroFilterFactoryBean.setUnauthorizedUrl(prefix + "/Unauthorized");
        // 设置过滤器
        Map<String, String> filterChainDefinitionMap = new LinkedHashMap<>();
        LinkedHashMap<String, Filter> filters = new LinkedHashMap<>();
        //设置自定义的JWT过滤器
        filters.put("jwt", new JwtFilter());
        shiroFilterFactoryBean.setFilters(filters);
        shiroFilterFactoryBean.setSecurityManager(securityManager);

        // 所有请求都要经过 jwt过滤器
        filterChainDefinitionMap.put("/**", "jwt");
        shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap);
        return shiroFilterFactoryBean;
    }

    @Bean
    public AuthorizationAttributeSourceAdvisor authorizationAttributeSourceAdvisor(SecurityManager securityManager) {
        AuthorizationAttributeSourceAdvisor authorizationAttributeSourceAdvisor = new AuthorizationAttributeSourceAdvisor();
        authorizationAttributeSourceAdvisor.setSecurityManager(securityManager);
        return authorizationAttributeSourceAdvisor;
    }

    /**
     * 注入 securityManager
     */
    @Bean
    public SecurityManager securityManager() {
        DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager();
        // 设置realm.
        securityManager.setRealm(customRealm());

        // 设置缓存
        securityManager.setCacheManager(cacheManager());

        // 设置会话
        securityManager.setSessionManager(sessionManager());
        return securityManager;
    }

    /**
     * 自定义身份认证 realm;
     */
    @Bean
    public CustomRealm customRealm() {
        return new CustomRealm();
    }


    /**
     * 加入redis缓存,避免重复从数据库获取数据
     */
    public RedisManager redisManager() {
        RedisManager redisManager = new RedisManager();
        redisManager.setHost(redisHost);
        redisManager.setPort(redisPort);
        redisManager.setPassword(redisPassword);
        redisManager.setExpire(expireAt);
        redisManager.setTimeout(timeout);
        return redisManager;
    }

    public RedisCacheManager cacheManager() {
        RedisCacheManager redisCacheManager = new RedisCacheManager();
        redisCacheManager.setRedisManager(redisManager());
        return redisCacheManager;
    }


    /**
     * session 会话管理
     */
    @Bean
    public RedisSessionDAO sessionDAO() {
        RedisSessionDAO redisSessionDAO = new RedisSessionDAO();
        redisSessionDAO.setRedisManager(redisManager());
        return redisSessionDAO;
    }

    @Bean
    public SimpleCookie sessionIdCookie(){
        SimpleCookie cookie = new SimpleCookie("X-Token");
        cookie.setMaxAge(-1);
        cookie.setPath("/");
        cookie.setHttpOnly(false);
        return cookie;
    }

    @Bean
    public SessionManager sessionManager() {
        DefaultWebSessionManager sessionManager = new DefaultWebSessionManager();
        sessionManager.setSessionIdCookie(sessionIdCookie());
        sessionManager.setSessionIdCookieEnabled(true);
        Collection<SessionListener> listeners = new ArrayList<SessionListener>();
        listeners.add(new ShiroSessionListener());
        sessionManager.setSessionListeners(listeners);
        sessionManager.setSessionDAO(sessionDAO());
        return sessionManager;
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123

# 3、实现JwtFilter过滤器🎒

/**
 * Shiro filter that authenticates requests from a JWT carried in the
 * {@code Authorization} request header and adds CORS response headers.
 */
public class JwtFilter extends BasicHttpAuthenticationFilter {

    private final Logger log = LoggerFactory.getLogger(this.getClass());

    /** Name of the request header that carries the token. */
    private static final String TOKEN = "Authorization";

    /** Reserved for the whitelist check sketched in {@link #isAccessAllowed}. */
    private final AntPathMatcher pathMatcher = new AntPathMatcher();

    /**
     * Allows the request only when it carries a token and that token passes
     * the realm's login check.
     *
     * @param request     the incoming request
     * @param response    the outgoing response
     * @param mappedValue filter-specific configuration (unused)
     * @return {@code true} if the token was accepted, {@code false} otherwise
     */
    @Override
    protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) throws UnauthorizedException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;

        // TODO: whitelist handling goes here, e.g. return true when the
        // request URI matches a public path.

        // Does the request carry a token header?
        if (isLoginAttempt(request, response)) {
            // If so, run the login to check whether the token is valid.
            return executeLogin(request, response);
        }

        log.error("未传token {}", httpServletRequest.getRequestURI());
        return false;
    }

    /**
     * Rejects the request with HTTP 401.
     *
     * <p>Overridden because the inherited implementation calls
     * {@link #isLoginAttempt} and {@link #executeLogin} again, so a request
     * with an invalid token was verified twice, and then answers with a
     * {@code Basic} challenge. The login outcome is already known when this
     * method runs, so it only writes the response status and a
     * {@code Bearer} challenge.
     *
     * @return always {@code false}: the filter chain must not continue
     */
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        HttpServletResponse httpServletResponse = (HttpServletResponse) response;
        httpServletResponse.setStatus(HttpStatus.UNAUTHORIZED.value());
        httpServletResponse.setHeader("WWW-Authenticate", "Bearer");
        return false;
    }

    /**
     * Treats the request as a login attempt when the token header is present.
     *
     * @return {@code true} if the {@code Authorization} header is set
     */
    @Override
    protected boolean isLoginAttempt(ServletRequest request, ServletResponse response) {
        HttpServletRequest req = (HttpServletRequest) request;
        return req.getHeader(TOKEN) != null;
    }

    /**
     * Wraps the header value in a {@link JwtToken} and submits it to the realm.
     *
     * @return {@code true} on successful login; {@code false} if the login
     *         threw, in which case the message is stored in the {@code fail}
     *         request attribute
     */
    @Override
    protected boolean executeLogin(ServletRequest request, ServletResponse response) {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        JwtToken jwtToken = new JwtToken(httpServletRequest.getHeader(TOKEN));
        try {
            // Hand the token to the realm; a rejected token surfaces as an exception.
            getSubject(request, response).login(jwtToken);
            return true;
        } catch (Exception e) {
            request.setAttribute("fail", e.getMessage());
            // Pass the exception as well so the stack trace is not lost.
            log.error("executeLogin {}", e.getMessage(), e);
            return false;
        }
    }

    /**
     * Adds CORS headers to every response and short-circuits OPTIONS
     * requests with status 200.
     *
     * <p>NOTE(review): the allowed origin is {@code *} and the allowed headers
     * are echoed from the request; review before production use.
     */
    @Override
    protected boolean preHandle(ServletRequest request, ServletResponse response) throws Exception {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        HttpServletResponse httpServletResponse = (HttpServletResponse) response;
        httpServletResponse.setHeader("Access-control-Allow-Origin", "*");
        httpServletResponse.setHeader("Access-Control-Allow-Methods", "GET,POST,OPTIONS");
        httpServletResponse.setHeader("Access-Control-Allow-Headers", httpServletRequest.getHeader("Access-Control-Request-Headers"));
        if (httpServletRequest.getMethod().equals(RequestMethod.OPTIONS.name())) {
            httpServletResponse.setStatus(HttpStatus.OK.value());
            return false;
        }
        return super.preHandle(request, response);
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

# 4、实现 JwtToken🎱

/**
 * Shiro authentication token that carries a raw JWT string; the same string
 * is exposed as both principal and credentials.
 */
public class JwtToken implements AuthenticationToken {

    private static final long serialVersionUID = 1282057025599826155L;

    /** The raw JWT string. */
    private String token;

    /** Optional expiry value supplied by the creator; {@code null} when not given. */
    private String expireAt;

    /**
     * Creates a token without an expiry value.
     *
     * @param token the raw JWT string
     */
    public JwtToken(String token) {
        this(token, null);
    }

    /**
     * Creates a token with an expiry value.
     *
     * @param token    the raw JWT string
     * @param expireAt the expiry value to carry along
     */
    public JwtToken(String token, String expireAt) {
        this.token = token;
        this.expireAt = expireAt;
    }

    /** @return the raw JWT string */
    public String getToken() {
        return this.token;
    }

    /** @param token the raw JWT string */
    public void setToken(String token) {
        this.token = token;
    }

    /** @return the expiry value, or {@code null} if none was set */
    public String getExpireAt() {
        return this.expireAt;
    }

    /** @param expireAt the expiry value to carry along */
    public void setExpireAt(String expireAt) {
        this.expireAt = expireAt;
    }

    /** @return the raw JWT string, used as the principal */
    @Override
    public Object getPrincipal() {
        return this.token;
    }

    /** @return the raw JWT string, used as the credentials */
    @Override
    public Object getCredentials() {
        return this.token;
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

# 5、封装jwt工具类🥸

/**
 * Static helpers for signing, verifying and reading JWTs that carry the
 * {@code userId} and {@code roleId} claims.
 */
public class JwtUtil {

    private static final Logger log = LoggerFactory.getLogger(JwtUtil.class);

    /** Token lifetime in milliseconds (30 minutes), added to the current time in {@link #sign}. */
    private static final long EXPIRE_TIME = 30 * 60 * 1000L;

    // SECURITY NOTE(review): the signing secret is hard-coded in source. It
    // should come from external configuration (e.g. the application YAML);
    // left unchanged here so tokens that were already issued stay valid.
    private static final String SECRET = "28ca017de15a57e206f0";

    /**
     * Checks that the token is validly signed and that its {@code userId} and
     * {@code roleId} claims match the given user.
     *
     * @param token the JWT string
     * @param user  the user whose id and role the token must carry
     * @return {@code true} if verification succeeds; {@code false} on any failure
     */
    public static boolean verify(String token, User user) {
        try {
            Algorithm algorithm = Algorithm.HMAC256(SECRET);
            JWTVerifier verifier = JWT.require(algorithm)
                    .withClaim("userId", user.getId())
                    .withClaim("roleId", user.getRole())
                    .build();
            verifier.verify(token);
            return true;
        } catch (Exception e) {
            log.error("token is invalid{}", e.getMessage());
            return false;
        }
    }

    /**
     * Reads the {@code userId} claim from the token without verifying the signature.
     *
     * @param token the JWT string; may be {@code null}
     * @return the user id, or {@code null} if the token is {@code null},
     *         cannot be decoded, or has no such string claim
     */
    public static String getUserId(String token) {
        if (token == null) {
            // Decoding a null token would fail with an exception other than
            // JWTDecodeException; report "no user id" instead.
            return null;
        }
        try {
            return JWT.decode(token).getClaim("userId").asString();
        } catch (JWTDecodeException e) {
            log.error("error:{}", e.getMessage());
            return null;
        }
    }

    /**
     * Reads the {@code roleId} claim from the token without verifying the signature.
     *
     * <p>The claim is read as a number first. If that yields nothing, the
     * string form is parsed, because {@link #sign} writes the claim from
     * {@code user.getRole()}, which appears to be a string elsewhere in this
     * codebase (NOTE(review): confirm the type of {@code User.getRole()}).
     *
     * @param token the JWT string; may be {@code null}
     * @return the role id, or {@code null} if the token is {@code null},
     *         cannot be decoded, or the claim is absent or not numeric
     */
    public static Integer getRoleId(String token) {
        if (token == null) {
            return null;
        }
        try {
            DecodedJWT jwt = JWT.decode(token);
            Integer roleId = jwt.getClaim("roleId").asInt();
            if (roleId != null) {
                return roleId;
            }
            String raw = jwt.getClaim("roleId").asString();
            return raw == null ? null : Integer.valueOf(raw.trim());
        } catch (JWTDecodeException | NumberFormatException e) {
            log.error("error:{}", e.getMessage());
            return null;
        }
    }

    /**
     * Creates a signed token bound to the user's id and role, expiring
     * {@code EXPIRE_TIME} milliseconds from now.
     *
     * @param user the user to issue the token for
     * @return the signed JWT string, or {@code null} if signing fails
     */
    public static String sign(User user) {
        try {
            Date expiresAt = new Date(System.currentTimeMillis() + EXPIRE_TIME);
            Algorithm algorithm = Algorithm.HMAC256(SECRET);

            // Build the token and bind the user information to it.
            return JWT.create()
                    .withClaim("userId", user.getId())
                    .withClaim("roleId", user.getRole())
                    .withExpiresAt(expiresAt)
                    .sign(algorithm);
        } catch (Exception e) {
            // Pass the message for the placeholder and the exception itself
            // for the stack trace; passing only the exception left "{}" unfilled.
            log.error("error:{}", e.getMessage(), e);
            return null;
        }
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

# 6、Realm实现🧐

/**
 * Shiro realm that authenticates {@link JwtToken}s and derives the user's
 * role from the user id stored in the token.
 */
public class CustomRealm extends AuthorizingRealm {

    /**
     * Restricts this realm to {@link JwtToken} instances.
     *
     * @param token the token offered for authentication
     * @return {@code true} only for {@link JwtToken}
     */
    @Override
    public boolean supports(AuthenticationToken token) {
        return token instanceof JwtToken;
    }

    /**
     * Authorization: resolves the role of the authenticated user and grants
     * it both as a role and as a string permission.
     *
     * @param token the principals of the authenticated subject; the primary
     *              principal is expected to be the JWT string stored by
     *              {@link #doGetAuthenticationInfo}
     * @return the authorization info; empty when no user id or role can be resolved
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection token) {
        SimpleAuthorizationInfo authorizationInfo = new SimpleAuthorizationInfo();

        // Read the primary principal explicitly rather than relying on the
        // collection's toString() representation.
        Object principal = token.getPrimaryPrincipal();
        if (!(principal instanceof String)) {
            return authorizationInfo;
        }

        String userId = JwtUtil.getUserId((String) principal);
        if (userId == null) {
            return authorizationInfo;
        }

        String userRole = UserMock.getRoleById(userId);
        if (userRole == null) {
            // Unknown role: grant nothing instead of storing a null entry.
            return authorizationInfo;
        }

        Set<String> roles = new HashSet<>();
        roles.add(userRole);
        authorizationInfo.setRoles(roles);
        // Separate copy so later changes to roles do not leak into permissions.
        authorizationInfo.setStringPermissions(new HashSet<>(roles));
        return authorizationInfo;
    }

    /**
     * Authentication: extracts the user id from the token, looks up the
     * user's current role and verifies the token against both.
     *
     * @param authenticationToken the token submitted for login
     * @return authentication info whose principal and credentials are the JWT string
     * @throws AuthenticationException if the credentials are missing, the
     *                                 token carries no user id, or verification fails
     */
    @Override
    protected SimpleAuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authenticationToken) throws AuthenticationException {
        Object credentials = authenticationToken.getCredentials();
        if (!(credentials instanceof String)) {
            // Null or unexpected credentials: fail as a normal authentication error.
            throw new AuthenticationException("验证失败");
        }
        String token = (String) credentials;

        String userId = JwtUtil.getUserId(token);
        if (StringUtils.isBlank(userId)) {
            throw new AuthenticationException("验证失败");
        }

        String userRole = UserMock.getRoleById(userId);
        // NOTE(review): the id is stored via setUserId(...) here while
        // JwtUtil.verify reads user.getId() - confirm both refer to the same field.
        User userBean = new User();
        userBean.setUserId(userId);
        userBean.setRole(userRole);
        if (!JwtUtil.verify(token, userBean)) {
            throw new AuthenticationException("token失效");
        }
        return new SimpleAuthenticationInfo(token, token, "shiroJwtRealm");
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53