spring WebSocketHandshakeTests 源码

  • 2022-08-08
  • 浏览 (550)

spring WebSocketHandshakeTests 代码

文件路径:/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java

/*
 * Copyright 2002-2022 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.junit.jupiter.api.TestInfo;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Client and server-side WebSocket integration tests covering sub-protocol
 * negotiation during the handshake and server-side handling of an unsolicited
 * pong frame.
 *
 * @author Rossen Stoyanchev
 * @author Juergen Hoeller
 * @author Sam Brannen
 */
class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {

	@Override
	protected Class<?>[] getAnnotatedConfigClasses() {
		return new Class<?>[] {TestConfig.class};
	}


	/**
	 * Request the "foo" sub-protocol and verify the server accepts it.
	 * {@link TestConfig} configures the handshake handler to support "foo", "bar" and "baz".
	 */
	@ParameterizedWebSocketTest
	@SuppressWarnings("deprecation")
	void subProtocolNegotiation(WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
		super.setup(server, webSocketClient, testInfo);

		WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
		headers.setSecWebSocketProtocol("foo");
		URI url = new URI(getWsBaseUrl() + "/ws");
		WebSocketSession session = this.webSocketClient.doHandshake(new TextWebSocketHandler(), headers, url).get();
		try {
			assertThat(session.getAcceptedProtocol()).isEqualTo("foo");
		}
		finally {
			// Close even if the assertion fails so the connection does not outlive the test.
			session.close();
		}
	}

	/**
	 * Send a pong frame with an empty payload that was not solicited by a ping,
	 * and verify the server handler receives exactly one {@link PongMessage}
	 * without reporting a transport error.
	 */
	@ParameterizedWebSocketTest  // SPR-12727
	@SuppressWarnings("deprecation")
	void unsolicitedPongWithEmptyPayload(WebSocketTestServer server, WebSocketClient webSocketClient, TestInfo testInfo) throws Exception {
		super.setup(server, webSocketClient, testInfo);

		String url = getWsBaseUrl() + "/ws";
		WebSocketSession session = this.webSocketClient.doHandshake(new AbstractWebSocketHandler() {}, url).get();
		try {
			TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
			serverHandler.setWaitMessageCount(1);

			session.sendMessage(new PongMessage());

			serverHandler.await();
			assertThat(serverHandler.getTransportError()).isNull();
			assertThat(serverHandler.getReceivedMessages()).hasSize(1);
			assertThat(serverHandler.getReceivedMessages().get(0)).isExactlyInstanceOf(PongMessage.class);
		}
		finally {
			// Previously this session was never closed, leaking the connection.
			session.close();
		}
	}


	/**
	 * Registers {@link TestWebSocketHandler} at "/ws" using the injected
	 * {@link DefaultHandshakeHandler}, configured with the supported
	 * sub-protocols "foo", "bar" and "baz".
	 */
	@Configuration
	@EnableWebSocket
	static class TestConfig implements WebSocketConfigurer {

		@Autowired
		private DefaultHandshakeHandler handshakeHandler;  // can't rely on classpath for server detection

		@Override
		public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
			this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz");
			registry.addHandler(handler(), "/ws").setHandshakeHandler(this.handshakeHandler);
		}

		@Bean
		public TestWebSocketHandler handler() {
			return new TestWebSocketHandler();
		}

	}

	/**
	 * Server-side handler that records every received message and counts down a
	 * one-shot latch once the expected number of messages has arrived, or when a
	 * transport error occurs.
	 */
	private static class TestWebSocketHandler extends AbstractWebSocketHandler {

		private final List<WebSocketMessage<?>> receivedMessages = new ArrayList<>();

		// Written by the test thread and read by the server's message-handling thread.
		private volatile int waitMessageCount;

		private final CountDownLatch latch = new CountDownLatch(1);

		// Written before latch.countDown(), so a reader returning from await() sees it.
		private volatile Throwable transportError;

		public void setWaitMessageCount(int waitMessageCount) {
			this.waitMessageCount = waitMessageCount;
		}

		public List<WebSocketMessage<?>> getReceivedMessages() {
			return this.receivedMessages;
		}

		public Throwable getTransportError() {
			return this.transportError;
		}

		@Override
		public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
			this.receivedMessages.add(message);
			if (this.receivedMessages.size() >= this.waitMessageCount) {
				this.latch.countDown();
			}
		}

		@Override
		public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
			this.transportError = exception;
			this.latch.countDown();
		}

		/**
		 * Wait up to 5 seconds for the expected messages or a transport error.
		 * Fails the test on timeout instead of silently proceeding.
		 */
		public void await() throws InterruptedException {
			assertThat(this.latch.await(5, TimeUnit.SECONDS))
					.as("Timed out waiting for server-side messages or a transport error")
					.isTrue();
		}
	}

}

相关信息

spring 源码目录

相关文章

spring AbstractHttpRequestTests 源码

spring AbstractWebSocketIntegrationTests 源码

spring ContextLoaderTestUtils 源码

spring JettyWebSocketTestServer 源码

spring TextMessageTests 源码

spring TomcatWebSocketTestServer 源码

spring UndertowTestServer 源码

spring WebSocketExtensionTests 源码

spring WebSocketTestServer 源码

0  赞