Skip to content

Commit

Permalink
Propagate SecurityContext in ChannelInterceptor
Browse files Browse the repository at this point in the history
Add `SecurityContextPropagationChannelInterceptor` that
propagates the current security context through the
Spring Messaging API.

Namely, it adds the current security context into any
message before it is sent and then populates the security
context when that message is received, typically in a
separate thread.
  • Loading branch information
artembilan authored and jzheaux committed Oct 16, 2023
1 parent 817e9d6 commit 60a00bb
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.messaging.context;

import java.util.Stack;

import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.ExecutorChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.util.Assert;

/**
* An {@link ExecutorChannelInterceptor} that takes an {@link Authentication} from the
* current {@link SecurityContext} (if any) in the
* {@link #preSend(Message, MessageChannel)} callback and stores it into an
* {@link #authenticationHeaderName} message header. Then sets the context from this
* header in the {@link #beforeHandle(Message, MessageChannel, MessageHandler)} and
* {@link #postReceive(Message, MessageChannel)} both of which typically happen on a
* different thread.
* <p>
* Note: cannot be used in combination with a {@link SecurityContextChannelInterceptor} on
* the same channel since both these interceptors modify a security context on a handling
* and receiving operations.
*
* @author Artem Bilan
* @since 6.2
* @see SecurityContextChannelInterceptor
*/
public final class SecurityContextPropagationChannelInterceptor implements ExecutorChannelInterceptor {

private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();

private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();

private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext();

private final String authenticationHeaderName;

private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous",
AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));

/**
* Create a new instance using the header of the name
* {@link SimpMessageHeaderAccessor#USER_HEADER}.
*/
public SecurityContextPropagationChannelInterceptor() {
this(SimpMessageHeaderAccessor.USER_HEADER);
}

/**
* Create a new instance that uses the specified header to populate the
* {@link Authentication}.
* @param authenticationHeaderName the header name to populate the
* {@link Authentication}. Cannot be null.
*/
public SecurityContextPropagationChannelInterceptor(String authenticationHeaderName) {
Assert.notNull(authenticationHeaderName, "authenticationHeaderName cannot be null");
this.authenticationHeaderName = authenticationHeaderName;
}

public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
this.securityContextHolderStrategy = strategy;
this.empty = this.securityContextHolderStrategy.createEmptyContext();
}

/**
* Configure an Authentication used for anonymous authentication. Default is: <pre>
* new AnonymousAuthenticationToken(&quot;key&quot;, &quot;anonymous&quot;,
* AuthorityUtils.createAuthorityList(&quot;ROLE_ANONYMOUS&quot;));
* </pre>
* @param authentication the Authentication used for anonymous authentication. Cannot
* be null.
*/
public void setAnonymousAuthentication(Authentication authentication) {
Assert.notNull(authentication, "authentication cannot be null");
this.anonymous = authentication;
}

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
if (authentication == null) {
authentication = this.anonymous;
}
return MessageBuilder.fromMessage(message).setHeader(this.authenticationHeaderName, authentication).build();
}

@Override
public Message<?> beforeHandle(Message<?> message, MessageChannel channel, MessageHandler handler) {
return postReceive(message, channel);
}

@Override
public Message<?> postReceive(Message<?> message, MessageChannel channel) {
setup(message);
return message;
}

@Override
public void afterMessageHandled(Message<?> message, MessageChannel channel, MessageHandler handler, Exception ex) {
cleanup();
}

private void setup(Message<?> message) {
Authentication authentication = message.getHeaders().get(this.authenticationHeaderName, Authentication.class);
SecurityContext currentContext = this.securityContextHolderStrategy.getContext();
Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null) {
contextStack = new Stack<>();
originalContext.set(contextStack);
}
contextStack.push(currentContext);
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
context.setAuthentication(authentication);
this.securityContextHolderStrategy.setContext(context);
}

private void cleanup() {
Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null || contextStack.isEmpty()) {
this.securityContextHolderStrategy.clearContext();
originalContext.remove();
return;
}
SecurityContext context = contextStack.pop();
try {
if (this.empty.equals(context)) {
this.securityContextHolderStrategy.clearContext();
originalContext.remove();
}
else {
this.securityContextHolderStrategy.setContext(context);
}
}
catch (Throwable ex) {
this.securityContextHolderStrategy.clearContext();
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.messaging.context;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

@ExtendWith(MockitoExtension.class)
public class SecurityContextPropagationChannelInterceptorTests {

@Mock
MessageChannel channel;

@Mock
MessageHandler handler;

MessageBuilder<String> messageBuilder;

Authentication authentication;

SecurityContextPropagationChannelInterceptor interceptor;

@BeforeEach
public void setup() {
this.authentication = new TestingAuthenticationToken("user", "pass", "ROLE_USER");
this.messageBuilder = MessageBuilder.withPayload("payload");
this.interceptor = new SecurityContextPropagationChannelInterceptor();
}

@AfterEach
public void cleanup() {
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null);
SecurityContextHolder.clearContext();
}

@Test
public void preSendDefaultHeader() {
SecurityContextHolder.getContext().setAuthentication(this.authentication);
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel);
assertThat(message.getHeaders()).containsEntry(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
}

@Test
public void preSendCustomHeader() {
SecurityContextHolder.getContext().setAuthentication(this.authentication);
String headerName = "header";
this.interceptor = new SecurityContextPropagationChannelInterceptor(headerName);
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel);
assertThat(message.getHeaders()).containsEntry(headerName, this.authentication);
}

@Test
public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() {
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
strategy.setContext(new SecurityContextImpl(this.authentication));
this.interceptor.setSecurityContextHolderStrategy(strategy);
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel);
this.interceptor.beforeHandle(message, this.channel, this.handler);
verify(strategy, times(2)).getContext();
assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication);
}

@Test
public void preSendUserNoContext() {
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel);
assertThat(message.getHeaders()).containsKey(SimpMessageHeaderAccessor.USER_HEADER);
assertThat(message.getHeaders().get(SimpMessageHeaderAccessor.USER_HEADER))
.isInstanceOf(AnonymousAuthenticationToken.class);
}

@Test
public void beforeHandleUserSet() {
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
}

@Test
public void postReceiveUserSet() {
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.postReceive(this.messageBuilder.build(), this.channel);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
}

@Test
public void authenticationIsPropagatedFromPreSendToPostReceive() {
SecurityContextHolder.getContext().setAuthentication(this.authentication);
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel);
assertThat(message.getHeaders().get(SimpMessageHeaderAccessor.USER_HEADER)).isSameAs(this.authentication);
this.interceptor.postReceive(message, this.channel);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
}

@Test
public void beforeHandleUserNotSet() {
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
}

@Test
public void afterMessageHandledUserNotSet() {
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
}

@Test
public void afterMessageHandled() {
SecurityContextHolder.getContext().setAuthentication(this.authentication);
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
}

@Test
public void restoresOriginalContext() {
TestingAuthenticationToken original = new TestingAuthenticationToken("original", "original", "ROLE_USER");
SecurityContextHolder.getContext().setAuthentication(original);
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(original);
}

}

0 comments on commit 60a00bb

Please sign in to comment.