Accessing Authenticated User Principal In SessionConnectedEvent With JWT Authentication

by ADMIN

When building real-time applications with Spring WebSockets and STOMP, you need to secure your WebSocket endpoints. JSON Web Tokens (JWT) are a popular choice for authentication because they are stateless and easy for clients to send. This article covers a common scenario: accessing the authenticated user principal in a SessionConnectedEvent when a ChannelInterceptor handles JWT authentication. We'll set up a custom ChannelInterceptor, extract and validate the JWT, and propagate the authentication so that it is available later in the WebSocket session's lifecycle.

Understanding the Problem

The primary challenge is that SessionConnectedEvent is not published while the STOMP CONNECT frame is being processed. It is published afterwards, when the server sends the CONNECTED frame back to the client, typically on a different thread and with a different message: the broker's connect acknowledgment rather than the original CONNECT. Your ChannelInterceptor can authenticate the user while handling CONNECT, but nothing it does is automatically visible to the event handler unless the authenticated principal is attached to the STOMP session in a way that Spring propagates. We need a mechanism that carries the principal from the ChannelInterceptor to the event handler.

Specifically, consider a scenario where you're using a custom WebSocketChannelInterceptor (implementing ChannelInterceptor) to handle authentication during the CONNECT STOMP command. The interceptor's role is to extract and validate a JWT token from the headers of the STOMP CONNECT message. Upon successful validation, the interceptor sets the Authentication object in the message headers. However, when the SessionConnectedEvent is triggered, you need to access this authenticated user principal to perform tasks such as associating the session with the user or logging connection events. The difficulty arises because the SessionConnectedEvent handler doesn't automatically have access to the Authentication object set in the interceptor.

To solve this, the authentication has to be stored somewhere both the interceptor and the event handler can reach. With Spring's STOMP support, that place is the user header of the CONNECT message: when the interceptor calls setUser on the message's own StompHeaderAccessor, Spring associates that user with the WebSocket session and exposes it on later session events, including SessionConnectedEvent. The rest of this article shows how to set this up and how to read the principal back in the event handler.

Setting Up the JWT Authentication with ChannelInterceptor

Creating a Custom ChannelInterceptor

The first step is to create a custom ChannelInterceptor that intercepts the STOMP CONNECT messages. This interceptor will be responsible for extracting the JWT token, validating it, and setting the Authentication object in the message headers. Here’s how you can define a custom ChannelInterceptor:

import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.Optional;

@Component
public class JwtChannelInterceptor implements ChannelInterceptor {

    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtTokenProvider jwtTokenProvider;

    public JwtChannelInterceptor(JwtTokenProvider jwtTokenProvider) {
        this.jwtTokenProvider = jwtTokenProvider;
    }

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        // Use the accessor attached to this message. StompHeaderAccessor.wrap(message)
        // would create a copy, and a user set on the copy never reaches the session.
        StompHeaderAccessor accessor =
                MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);

        if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
            Optional.ofNullable(accessor.getFirstNativeHeader("Authorization"))
                    // Strip only a leading "Bearer " prefix
                    .filter(header -> header.startsWith(BEARER_PREFIX))
                    .map(header -> header.substring(BEARER_PREFIX.length()))
                    .filter(jwtTokenProvider::validateToken)
                    .ifPresent(token -> {
                        String username = jwtTokenProvider.getUsernameFromToken(token);
                        Authentication authentication = new UsernamePasswordAuthenticationToken(
                                username,
                                null,
                                List.of(new SimpleGrantedAuthority("ROLE_USER"))
                        );
                        // Spring records this user for the STOMP session and exposes it
                        // on later messages and on SessionConnectedEvent#getUser().
                        accessor.setUser(authentication);
                    });
        }
        return message;
    }
}

In this interceptor, preSend runs for every message on the client inbound channel before it is dispatched further. For CONNECT frames, it reads the Authorization native header, strips the "Bearer " prefix, and validates the token with a JwtTokenProvider. If the token is valid, it creates an Authentication object and sets it as the user of the message. Two details matter here. First, the accessor is obtained with MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class) rather than StompHeaderAccessor.wrap(message): wrap creates a new accessor over a copy of the headers, so calling setUser on it does not change the message that continues down the channel, and Spring never learns about the user. Second, a missing or invalid token does not block the connection in this version; the session simply has no user. If unauthenticated clients should be rejected, throw an exception from preSend instead, and Spring will typically answer the CONNECT with a STOMP ERROR frame.
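The interceptor's prefix handling can be factored into a small helper that also tolerates a lower-case scheme name and extra whitespace (HTTP treats the authentication scheme name case-insensitively, and some clients mimic that). The sketch below is a hypothetical utility: BearerTokens is not a Spring or jjwt class, and it uses only the JDK. It returns an empty Optional for a missing header, a different scheme, or an empty token.

```java
import java.util.Optional;

/**
 * Hypothetical helper (not part of Spring or jjwt): extracts the token from an
 * Authorization value of the form "Bearer <token>".
 */
final class BearerTokens {

    private static final String SCHEME = "Bearer ";

    private BearerTokens() {
    }

    static Optional<String> extract(String headerValue) {
        if (headerValue == null) {
            return Optional.empty();
        }
        String trimmed = headerValue.trim();
        // Compare the scheme case-insensitively: "Bearer", "bearer" and "BEARER" all match.
        if (!trimmed.regionMatches(true, 0, SCHEME, 0, SCHEME.length())) {
            return Optional.empty();
        }
        // Everything after the scheme is the token; surrounding spaces are dropped.
        String token = trimmed.substring(SCHEME.length()).trim();
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }
}
```

Inside preSend, it could replace the header-parsing steps, for example BearerTokens.extract(accessor.getFirstNativeHeader("Authorization")).filter(jwtTokenProvider::validateToken).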

Configuring WebSocket Message Broker

Next, you need to configure the WebSocket message broker to use the custom ChannelInterceptor. This is typically done in your WebSocket configuration class. Here’s how you can configure it:

import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;

@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    private final JwtChannelInterceptor jwtChannelInterceptor;

    public WebSocketConfig(JwtChannelInterceptor jwtChannelInterceptor) {
        this.jwtChannelInterceptor = jwtChannelInterceptor;
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry config) {
        config.enableSimpleBroker("/topic");
        config.setApplicationDestinationPrefixes("/app");
    }

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry.addEndpoint("/ws").withSockJS();
    }

    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        registration.interceptors(jwtChannelInterceptor);
    }
}

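In the WebSocketConfig class, configureClientInboundChannel registers the JwtChannelInterceptor on the client inbound channel, so every message sent by a client passes through it. Only CONNECT frames are checked for a JWT, however: Spring associates the user set at CONNECT time with the session and applies it to later frames on that session, so clients do not need to resend the token with every SEND or SUBSCRIBE. If you also use Spring Security's message-level authorization, this interceptor must run before Spring Security's own; the Spring reference documentation suggests registering it in a separate WebSocketMessageBrokerConfigurer annotated with @Order(Ordered.HIGHEST_PRECEDENCE + 99).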

JWT Token Provider

The JwtTokenProvider is a crucial component responsible for validating the JWT token and extracting information from it. Here’s a basic implementation:

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;

// Uses the jjwt 0.11.x API (parserBuilder / parseClaimsJws)
@Component
public class JwtTokenProvider {

    @Value("${jwt.secret}")
    private String jwtSecret;

    public String getUsernameFromToken(String token) {
        return getClaimsFromToken(token).getSubject();
    }

    public boolean validateToken(String token) {
        try {
            // Verifies the signature and rejects tokens whose exp claim has passed
            getClaimsFromToken(token);
            return true;
        } catch (JwtException | IllegalArgumentException e) {
            return false;
        }
    }

    private Claims getClaimsFromToken(String token) {
        return Jwts.parserBuilder().setSigningKey(getSigningKey()).build().parseClaimsJws(token).getBody();
    }

    private SecretKey getSigningKey() {
        // hmacShaKeyFor throws WeakKeyException if the secret is shorter than 256 bits (32 bytes)
        byte[] keyBytes = jwtSecret.getBytes(StandardCharsets.UTF_8);
        return Keys.hmacShaKeyFor(keyBytes);
    }
}

The JwtTokenProvider class extracts the username (the token's subject claim) and validates the token. validateToken parses the token with jjwt, which verifies the HMAC signature and, if the token has an exp claim, rejects it once it has expired; any parsing failure makes the method return false. The code targets the jjwt 0.11.x API; in jjwt 0.12, parserBuilder(), setSigningKey() and parseClaimsJws() are deprecated in favor of Jwts.parser().verifyWith(key).build().parseSignedClaims(token). Also note that Keys.hmacShaKeyFor rejects keys shorter than 256 bits, so the jwt.secret property must be at least 32 bytes long. This validation is what ensures that only clients holding a correctly signed, unexpired token get a user attached to their WebSocket session.
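To make concrete what validateToken checks for an HS256 token, here is a JDK-only sketch of the signature step. It is illustrative only: the class and method names are hypothetical, it assumes HS256 without reading the token header, and it skips the claim parsing and exp/nbf checks that jjwt performs. Keep using the library in application code.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Base64;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

/**
 * Hypothetical, JDK-only illustration of the HS256 signature check a JWT library performs.
 * It assumes HS256 without reading the header and does not parse claims or check exp/nbf.
 */
final class Hs256SignatureCheck {

    private Hs256SignatureCheck() {
    }

    static boolean hasValidSignature(String token, byte[] secret) {
        // Mirror jjwt's Keys.hmacShaKeyFor, which rejects HMAC keys shorter than 256 bits
        if (secret.length < 32) {
            throw new IllegalArgumentException("HS256 key must be at least 32 bytes");
        }
        // A JWS in compact form is header.payload.signature, each part base64url-encoded
        String[] parts = token.split("\\.", -1);
        if (parts.length != 3) {
            return false;
        }
        byte[] actual;
        try {
            actual = Base64.getUrlDecoder().decode(parts[2]);
        } catch (IllegalArgumentException e) {
            return false; // signature part is not valid base64url
        }
        // The signature covers the ASCII bytes of "header.payload"
        byte[] expected = hmacSha256(secret, parts[0] + "." + parts[1]);
        // Constant-time comparison so timing does not reveal how many bytes matched
        return MessageDigest.isEqual(expected, actual);
    }

    static byte[] hmacSha256(byte[] secret, String signingInput) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(secret, "HmacSHA256"));
            return mac.doFinal(signingInput.getBytes(StandardCharsets.US_ASCII));
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("HmacSHA256 is not available", e);
        }
    }
}
```

Changing a single character of the payload (for example the sub claim) changes the expected HMAC, so a tampered token fails the comparison even though its signature segment is well-formed.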

Accessing Authenticated User in SessionConnectedEvent

Listening for SessionConnectedEvent

To access the authenticated user in a SessionConnectedEvent, you need to listen for this event using a Spring event listener. This allows you to execute custom logic when a new WebSocket session is established. Here’s how you can set up an event listener:

import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.messaging.SessionConnectedEvent;

import java.security.Principal;

@Component
public class WebSocketEventListener {

    @EventListener
    public void handleSessionConnectedEvent(SessionConnectedEvent event) {
        // The user Spring recorded for this STOMP session (null if none was authenticated)
        Principal user = event.getUser();
        System.out.println("User connected: " + (user != null ? user.getName() : "anonymous"));
    }
}

This WebSocketEventListener listens for SessionConnectedEvent, which Spring publishes when the server sends the CONNECTED frame for a new STOMP session. Calling event.getUser() returns the Principal that Spring recorded for the session; with the interceptor above, that is the Authentication created from the JWT (Spring Security's Authentication extends java.security.Principal). Because getUser() is typed as Principal, you call getName() on it directly, or cast to Authentication to reach details such as authorities. It returns null if neither the interceptor nor the HTTP handshake supplied a user, so the listener checks for that before calling getName().

Extracting User Information

Once you have the user principal, you can extract the necessary information, such as the username, and use it for your application logic. For example, you might want to store the session ID and the username in a database or a session registry. The following code demonstrates how to extract the username from the principal:

import org.springframework.context.event.EventListener;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.messaging.SessionConnectedEvent;

import java.security.Principal;

@Component
public class WebSocketEventListener {

    @EventListener
    public void handleSessionConnectedEvent(SessionConnectedEvent event) {
        Principal user = event.getUser();
        if (user == null) {
            System.out.println("Anonymous session connected");
            return;
        }
        String username = user.getName();
        if (user instanceof Authentication) {
            // Inspect the Authentication's own principal: a plain String (as set by
            // JwtChannelInterceptor) or a UserDetails object
            Object principal = ((Authentication) user).getPrincipal();
            if (principal instanceof UserDetails) {
                username = ((UserDetails) principal).getUsername();
            } else if (principal instanceof String) {
                username = (String) principal;
            }
        }
        System.out.println("User connected: " + username);
    }
}

In this enhanced version, the listener first handles the case where no user is attached. Otherwise it starts from Principal.getName() and, when the principal is a Spring Security Authentication, inspects the object returned by getPrincipal(): either a plain String, as created by JwtChannelInterceptor, or a UserDetails instance, for example if you later load users through a UserDetailsService. Handling both keeps the listener working if the way the interceptor builds its Authentication changes.
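If all you need is the name, a null-safe helper can skip the type checks, because Authentication extends java.security.Principal and AbstractAuthenticationToken.getName() already returns the username for a String or UserDetails principal. The sketch below is a hypothetical JDK-only helper; you would pass it event.getUser().

```java
import java.security.Principal;
import java.util.Optional;

/**
 * Hypothetical helper: converts the Principal from SessionConnectedEvent#getUser()
 * into an optional username. getUser() returns null when no user was attached.
 */
final class ConnectedUsers {

    private ConnectedUsers() {
    }

    static Optional<String> usernameOf(Principal user) {
        return Optional.ofNullable(user)
                .map(Principal::getName)                  // empty if getName() returns null
                .filter(name -> !name.trim().isEmpty());  // treat blank names as anonymous
    }
}
```

Returning an Optional makes the anonymous case explicit at the call site, e.g. ConnectedUsers.usernameOf(event.getUser()).ifPresent(name -> ...).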

Storing Session Information

In many applications, it's necessary to store session information, such as the session ID and the authenticated user, for later use. This can be achieved by maintaining a session registry. Here’s an example of how you can store session information in a simple registry:

import org.springframework.context.event.EventListener;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionDisconnectEvent;

import java.security.Principal;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Component
public class WebSocketEventListener {

    private final Map<String, String> sessionUserMap = new ConcurrentHashMap<>();

    @EventListener
    public void handleSessionConnectedEvent(SessionConnectedEvent event) {
        String sessionId = SimpMessageHeaderAccessor.wrap(event.getMessage()).getSessionId();
        String username = resolveUsername(event.getUser());
        if (sessionId != null && username != null) {
            sessionUserMap.put(sessionId, username);
            System.out.println("User connected: " + username + ", Session ID: " + sessionId);
        }
    }

    @EventListener
    public void handleSessionDisconnectEvent(SessionDisconnectEvent event) {
        // May be published more than once per session; remove() keeps this idempotent
        String username = sessionUserMap.remove(event.getSessionId());
        if (username != null) {
            System.out.println("User disconnected: " + username + ", Session ID: " + event.getSessionId());
        }
    }

    public String getUsernameForSession(String sessionId) {
        return sessionUserMap.get(sessionId);
    }

    private static String resolveUsername(Principal user) {
        if (user == null) {
            return null;
        }
        if (user instanceof Authentication) {
            Object principal = ((Authentication) user).getPrincipal();
            if (principal instanceof UserDetails) {
                return ((UserDetails) principal).getUsername();
            }
            if (principal instanceof String) {
                return (String) principal;
            }
        }
        return user.getName();
    }
}

In this example, a ConcurrentHashMap maps session IDs to usernames. handleSessionConnectedEvent reads the session ID from the event's message headers and the user from event.getUser(), then stores the mapping; handleSessionDisconnectEvent removes it when the session closes. Spring documents that SessionDisconnectEvent may be raised more than once for the same session, so the disconnect handler only logs when an entry was actually removed. The map lives in memory, so it only covers sessions on the current application instance. Within that limit, a registry like this lets you keep track of active WebSocket sessions and their users, which is useful for features such as sending targeted messages or monitoring user activity.
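A single user can hold several sessions (multiple tabs or devices), and targeted messaging usually needs the reverse lookup from user to sessions. The sketch below is a hypothetical, JDK-only registry (SessionRegistry and its methods are illustrative names, not a Spring API) that keeps both directions and treats a repeated disconnect as a no-op. Before building your own, note that Spring's STOMP support also provides a SimpUserRegistry, which you may be able to inject when the message broker is enabled.

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Hypothetical in-memory registry tracking session -> user and user -> sessions.
 * Like the map above, it only covers sessions on the current application instance.
 */
final class SessionRegistry {

    private final Map<String, String> userBySession = new ConcurrentHashMap<>();
    private final Map<String, Set<String>> sessionsByUser = new ConcurrentHashMap<>();

    void register(String sessionId, String username) {
        userBySession.put(sessionId, username);
        // compute() runs atomically per key, so adding cannot race with the removal below
        sessionsByUser.compute(username, (user, sessions) -> {
            Set<String> updated = sessions;
            if (updated == null) {
                updated = ConcurrentHashMap.newKeySet();
            }
            updated.add(sessionId);
            return updated;
        });
    }

    /** Returns the session's user, or null if the session was already unregistered. */
    String unregister(String sessionId) {
        String username = userBySession.remove(sessionId);
        if (username != null) {
            // Drop the user's entry once their last session is gone
            sessionsByUser.computeIfPresent(username, (user, sessions) -> {
                sessions.remove(sessionId);
                return sessions.isEmpty() ? null : sessions;
            });
        }
        return username;
    }

    String usernameFor(String sessionId) {
        return userBySession.get(sessionId);
    }

    /** Returns a snapshot of the user's current session IDs. */
    Set<String> sessionsFor(String username) {
        return new HashSet<>(sessionsByUser.getOrDefault(username, Collections.emptySet()));
    }
}
```

In the event listener, register would be called from the SessionConnectedEvent handler and unregister from the SessionDisconnectEvent handler; a second disconnect for the same session simply returns null.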

Complete Example

For clarity, here’s a complete example combining all the components:

// JwtTokenProvider.java
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtException;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;

// Uses the jjwt 0.11.x API (parserBuilder / parseClaimsJws)
@Component
public class JwtTokenProvider {

    @Value("${jwt.secret}")
    private String jwtSecret;

    public String getUsernameFromToken(String token) {
        return getClaimsFromToken(token).getSubject();
    }

    public boolean validateToken(String token) {
        try {
            // Verifies the signature and rejects tokens whose exp claim has passed
            getClaimsFromToken(token);
            return true;
        } catch (JwtException | IllegalArgumentException e) {
            return false;
        }
    }

    private Claims getClaimsFromToken(String token) {
        return Jwts.parserBuilder().setSigningKey(getSigningKey()).build().parseClaimsJws(token).getBody();
    }

    private SecretKey getSigningKey() {
        // hmacShaKeyFor throws WeakKeyException if the secret is shorter than 256 bits (32 bytes)
        byte[] keyBytes = jwtSecret.getBytes(StandardCharsets.UTF_8);
        return Keys.hmacShaKeyFor(keyBytes);
    }
}

// JwtChannelInterceptor.java
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.Optional;

@Component
public class JwtChannelInterceptor implements ChannelInterceptor {

    private static final String BEARER_PREFIX = "Bearer ";

    private final JwtTokenProvider jwtTokenProvider;

    public JwtChannelInterceptor(JwtTokenProvider jwtTokenProvider) {
        this.jwtTokenProvider = jwtTokenProvider;
    }

    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        // Use the accessor attached to this message; wrap() would only modify a copy
        StompHeaderAccessor accessor =
                MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);

        if (accessor != null && StompCommand.CONNECT.equals(accessor.getCommand())) {
            Optional.ofNullable(accessor.getFirstNativeHeader("Authorization"))
                    // Strip only a leading "Bearer " prefix
                    .filter(header -> header.startsWith(BEARER_PREFIX))
                    .map(header -> header.substring(BEARER_PREFIX.length()))
                    .filter(jwtTokenProvider::validateToken)
                    .ifPresent(token -> {
                        String username = jwtTokenProvider.getUsernameFromToken(token);
                        Authentication authentication = new UsernamePasswordAuthenticationToken(
                                username,
                                null,
                                List.of(new SimpleGrantedAuthority("ROLE_USER"))
                        );
                        // Spring records this user for the STOMP session
                        accessor.setUser(authentication);
                    });
        }
        return message;
    }
}

// WebSocketConfig.java
import org.springframework.context.annotation.Configuration;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;

@Configuration
@EnableWebSocketMessageBroker
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

    private final JwtChannelInterceptor jwtChannelInterceptor;

    public WebSocketConfig(JwtChannelInterceptor jwtChannelInterceptor) {
        this.jwtChannelInterceptor = jwtChannelInterceptor;
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry config) {
        config.enableSimpleBroker("/topic");
        config.setApplicationDestinationPrefixes("/app");
    }

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry.addEndpoint("/ws").withSockJS();
    }

    @Override
    public void configureClientInboundChannel(ChannelRegistration registration) {
        registration.interceptors(jwtChannelInterceptor);
    }
}

// WebSocketEventListener.java
import org.springframework.context.event.EventListener;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.messaging.SessionConnectedEvent;
import org.springframework.web.socket.messaging.SessionDisconnectEvent;

import java.security.Principal;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Component
public class WebSocketEventListener {

    private final Map<String, String> sessionUserMap = new ConcurrentHashMap<>();

    @EventListener
    public void handleSessionConnectedEvent(SessionConnectedEvent event) {
        String sessionId = SimpMessageHeaderAccessor.wrap(event.getMessage()).getSessionId();
        String username = resolveUsername(event.getUser());
        if (sessionId != null && username != null) {
            sessionUserMap.put(sessionId, username);
            System.out.println("User connected: " + username + ", Session ID: " + sessionId);
        }
    }

    @EventListener
    public void handleSessionDisconnectEvent(SessionDisconnectEvent event) {
        // May be published more than once per session; remove() keeps this idempotent
        String username = sessionUserMap.remove(event.getSessionId());
        if (username != null) {
            System.out.println("User disconnected: " + username + ", Session ID: " + event.getSessionId());
        }
    }

    public String getUsernameForSession(String sessionId) {
        return sessionUserMap.get(sessionId);
    }

    private static String resolveUsername(Principal user) {
        if (user == null) {
            return null;
        }
        if (user instanceof Authentication) {
            Object principal = ((Authentication) user).getPrincipal();
            if (principal instanceof UserDetails) {
                return ((UserDetails) principal).getUsername();
            }
            if (principal instanceof String) {
                return (String) principal;
            }
        }
        return user.getName();
    }
}

This complete example shows how to set up JWT authentication for Spring WebSockets, read the authenticated user principal in a SessionConnectedEvent, and keep track of which user owns each session.

Conclusion

Accessing the authenticated user principal in a SessionConnectedEvent when a ChannelInterceptor handles JWT authentication comes down to attaching the Authentication to the STOMP session correctly. Obtain the message's own StompHeaderAccessor with MessageHeaderAccessor.getAccessor, call setUser on the CONNECT frame, and Spring records that user for the session; the SessionConnectedEvent listener can then read it with event.getUser(). This article covered setting up the interceptor, validating tokens, configuring the WebSocket message broker, and implementing an event listener that reads the user principal and tracks sessions.

With these pieces in place, you can implement JWT authentication for STOMP over WebSockets and keep user information available throughout the session lifecycle. That user context is the foundation for features such as user-targeted messaging and session management.