spring AbstractTyrusRequestUpgradeStrategy 源码

  • 2022-08-08
  • 浏览 (420)

spring AbstractTyrusRequestUpgradeStrategy 代码

文件路径:/spring-websocket/src/main/java/org/springframework/web/socket/server/standard/AbstractTyrusRequestUpgradeStrategy.java

/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.server.standard;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.websocket.DeploymentException;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.Extension;
import jakarta.websocket.WebSocketContainer;
import org.glassfish.tyrus.core.ComponentProviderService;
import org.glassfish.tyrus.core.RequestContext;
import org.glassfish.tyrus.core.TyrusEndpointWrapper;
import org.glassfish.tyrus.core.TyrusUpgradeResponse;
import org.glassfish.tyrus.core.TyrusWebSocketEngine;
import org.glassfish.tyrus.core.Version;
import org.glassfish.tyrus.server.TyrusServerContainer;
import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo;

import org.springframework.beans.DirectFieldAccessor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.server.HandshakeFailureException;

import static org.glassfish.tyrus.spi.WebSocketEngine.UpgradeStatus.SUCCESS;

/**
 * A base class for {@code RequestUpgradeStrategy} implementations on top of
 * JSR-356 based servers which include Tyrus as their WebSocket engine.
 *
 * <p>Works with Tyrus 1.11 (WebLogic 12.2.1) and Tyrus 1.12 (GlassFish 4.1.1).
 *
 * <p>For every upgrade request a temporary {@code TyrusEndpointWrapper} is created
 * reflectively, registered with the {@code TyrusWebSocketEngine} under a random
 * path, used to perform the upgrade and finally unregistered again. Subclasses
 * complete a successful handshake through {@link #handleSuccess}.
 *
 * @author Rossen Stoyanchev
 * @author Brian Clozel
 * @author Juergen Hoeller
 * @since 4.1
 * @see <a href="https://tyrus.java.net/">Project Tyrus</a>
 */
public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {

	/** Source of the random per-request endpoint paths. */
	private static final Random random = new Random();

	/** The {@code TyrusEndpointWrapper} constructor taking (Endpoint, EndpointConfig, ...). */
	private static final Constructor<?> constructor;

	/** {@code true} if {@link #constructor} has 10 parameters, {@code false} if it has 9. */
	private static final boolean constructorWithBooleanArgument;

	/** Reflective handle on {@code TyrusWebSocketEngine.register(TyrusEndpointWrapper)}. */
	private static final Method registerMethod;

	/** Reflective handle on {@code TyrusWebSocketEngine.unregister(TyrusEndpointWrapper)}. */
	private static final Method unRegisterMethod;

	static {
		try {
			constructor = getEndpointConstructor();
			int parameterCount = constructor.getParameterCount();
			constructorWithBooleanArgument = (parameterCount == 10);
			if (!constructorWithBooleanArgument && parameterCount != 9) {
				throw new IllegalStateException("Expected TyrusEndpointWrapper constructor with 9 or 10 arguments");
			}
			registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class);
			unRegisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class);
			ReflectionUtils.makeAccessible(registerMethod);
			// Both methods are looked up via getDeclaredMethod, so treat them alike: a
			// non-accessible unregister would otherwise only fail later, inside the
			// best-effort cleanup, and leave endpoints registered.
			ReflectionUtils.makeAccessible(unRegisterMethod);
		}
		catch (Exception ex) {
			throw new IllegalStateException("No compatible Tyrus version found", ex);
		}
	}

	/**
	 * Find the public {@code TyrusEndpointWrapper} constructor whose first two
	 * parameters are {@link Endpoint} and {@link EndpointConfig}.
	 * @return the matching constructor
	 * @throws IllegalStateException if no such constructor exists
	 */
	private static Constructor<?> getEndpointConstructor() {
		for (Constructor<?> current : TyrusEndpointWrapper.class.getConstructors()) {
			Class<?>[] types = current.getParameterTypes();
			// Skip constructors with fewer than two parameters instead of failing on them
			if (types.length >= 2 && Endpoint.class == types[0] && EndpointConfig.class == types[1]) {
				return current;
			}
		}
		throw new IllegalStateException("No compatible Tyrus version found");
	}


	private final ComponentProviderService componentProvider = ComponentProviderService.create();


	/**
	 * Return the wire protocol versions reported by Tyrus, split on commas.
	 */
	@Override
	public String[] getSupportedVersions() {
		return StringUtils.tokenizeToStringArray(Version.getSupportedWireProtocolVersions(), ",");
	}

	/**
	 * Delegate to the superclass, falling back to an empty list if that lookup
	 * throws {@link UnsupportedOperationException}.
	 */
	@Override
	protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
		try {
			return super.getInstalledExtensions(container);
		}
		catch (UnsupportedOperationException ex) {
			return new ArrayList<>(0);
		}
	}

	/**
	 * Perform the handshake by registering a temporary Tyrus endpoint for the given
	 * {@code endpoint}, asking the engine to upgrade the request against it, and
	 * handing a successful result to {@link #handleSuccess}. The temporary endpoint
	 * is unregistered on every exit path.
	 * @throws HandshakeFailureException if any step fails with an exception, or if
	 * the engine reports a status other than {@code SUCCESS}
	 */
	@Override
	public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
			@Nullable String selectedProtocol, List<Extension> extensions, Endpoint endpoint)
			throws HandshakeFailureException {

		HttpServletRequest servletRequest = getHttpServletRequest(request);
		HttpServletResponse servletResponse = getHttpServletResponse(response);

		TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest);
		TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine();
		Object tyrusEndpoint = null;
		boolean success;

		try {
			// Shouldn't matter for processing but must be unique
			String path = "/" + random.nextLong();
			tyrusEndpoint = createTyrusEndpoint(endpoint, path, selectedProtocol, extensions, serverContainer, engine);
			register(engine, tyrusEndpoint);

			HttpHeaders headers = request.getHeaders();
			RequestContext requestContext = createRequestContext(servletRequest, path, headers);
			TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();
			UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);
			success = SUCCESS.equals(upgradeInfo.getStatus());
			if (success) {
				if (logger.isTraceEnabled()) {
					logger.trace("Successful request upgrade: " + upgradeResponse.getHeaders());
				}
				handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse);
			}
		}
		catch (Exception ex) {
			throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex);
		}
		finally {
			// Runs for success, failure and Errors alike, so the temporary endpoint
			// never stays registered with the engine.
			unregisterTyrusEndpoint(engine, tyrusEndpoint);
		}

		if (!success) {
			throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI());
		}
	}

	/**
	 * Build a {@link ServerEndpointRegistration} for the given path and endpoint,
	 * apply the selected sub-protocol and extensions, and wrap it in a Tyrus endpoint.
	 */
	private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, @Nullable String protocol,
			List<Extension> extensions, WebSocketContainer container, TyrusWebSocketEngine engine)
			throws DeploymentException {

		ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint);
		// NOTE(review): a null protocol yields a single-element list containing null - confirm
		// that ServerEndpointRegistration and Tyrus treat this as "no sub-protocol".
		endpointConfig.setSubprotocols(Collections.singletonList(protocol));
		endpointConfig.setExtensions(extensions);
		return createEndpoint(endpointConfig, this.componentProvider, container, engine);
	}

	/**
	 * Create the Tyrus {@link RequestContext} for the upgrade: the request URI is the
	 * temporary endpoint path, principal / secure flag / remote address come from the
	 * servlet request, and all given headers are copied over.
	 */
	private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) {
		RequestContext context =
				RequestContext.Builder.create()
						.requestURI(URI.create(endpointPath))
						.userPrincipal(request.getUserPrincipal())
						.secure(request.isSecure())
						.remoteAddr(request.getRemoteAddr())
						.build();
		headers.forEach((header, value) -> context.getHeaders().put(header, value));
		return context;
	}

	/**
	 * Best-effort removal of the temporary endpoint from the engine. Never throws;
	 * a failure is logged at debug level. Does nothing if {@code tyrusEndpoint} is
	 * {@code null}, i.e. if endpoint creation itself failed.
	 */
	private void unregisterTyrusEndpoint(TyrusWebSocketEngine engine, @Nullable Object tyrusEndpoint) {
		if (tyrusEndpoint != null) {
			try {
				unregister(engine, tyrusEndpoint);
			}
			catch (Throwable ex) {
				// Cleanup must not mask the handshake outcome, but should leave a trace
				if (logger.isDebugEnabled()) {
					logger.debug("Failed to unregister Tyrus endpoint", ex);
				}
			}
		}
	}

	/**
	 * Instantiate a {@code TyrusEndpointWrapper} through the constructor resolved at
	 * class initialization, reading {@code sessionListener} and {@code clusterContext}
	 * from the engine's fields.
	 * @throws HandshakeFailureException if the reflective instantiation fails
	 */
	private Object createEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
			WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException {

		DirectFieldAccessor accessor = new DirectFieldAccessor(engine);
		Object sessionListener = accessor.getPropertyValue("sessionListener");
		Object clusterContext = accessor.getPropertyValue("clusterContext");
		try {
			if (constructorWithBooleanArgument) {
				// Tyrus 1.11+
				return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
						"/", registration.getConfigurator(), sessionListener, clusterContext, null, Boolean.TRUE);
			}
			else {
				return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
						"/", registration.getConfigurator(), sessionListener, clusterContext, null);
			}
		}
		catch (Exception ex) {
			throw new HandshakeFailureException("Failed to create Tyrus endpoint for " + registration, ex);
		}
	}

	/**
	 * Reflectively invoke {@code register} on the engine for the given endpoint.
	 * @throws HandshakeFailureException if the invocation fails
	 */
	private void register(TyrusWebSocketEngine engine, Object endpoint) {
		try {
			registerMethod.invoke(engine, endpoint);
		}
		catch (Exception ex) {
			throw new HandshakeFailureException("Failed to register " + endpoint, ex);
		}
	}

	/**
	 * Reflectively invoke {@code unregister} on the engine for the given endpoint.
	 * @throws HandshakeFailureException if the invocation fails
	 */
	private void unregister(TyrusWebSocketEngine engine, Object endpoint) {
		try {
			unRegisterMethod.invoke(engine, endpoint);
		}
		catch (Exception ex) {
			throw new HandshakeFailureException("Failed to unregister " + endpoint, ex);
		}
	}


	/**
	 * Complete a handshake that the Tyrus engine reported as successful.
	 * @param request the current servlet request
	 * @param response the current servlet response
	 * @param upgradeInfo the result returned by the engine's {@code upgrade} call
	 * @param upgradeResponse the Tyrus response populated during the upgrade
	 * @throws IOException in case of I/O failure
	 * @throws ServletException in case of a Servlet container failure
	 */
	protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response,
			UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException;

}

相关信息

spring 源码目录

相关文章

spring AbstractStandardUpgradeStrategy 源码

spring GlassFishRequestUpgradeStrategy 源码

spring ServerEndpointExporter 源码

spring ServerEndpointRegistration 源码

spring ServletServerContainerFactoryBean 源码

spring SpringConfigurator 源码

spring TomcatRequestUpgradeStrategy 源码

spring UndertowRequestUpgradeStrategy 源码

spring WebLogicRequestUpgradeStrategy 源码

spring WebSphereRequestUpgradeStrategy 源码

0  赞