XmppWebSocketTransportModule.java

/**
 *
 * Copyright 2020 Aditya Borikar, 2020-2023 Florian Schmaus
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jivesoftware.smack.websocket;

import java.util.Collections;
import java.util.List;
import java.util.Queue;

import javax.net.ssl.SSLSession;

import org.jivesoftware.smack.AsyncButOrdered;
import org.jivesoftware.smack.SmackException;
import org.jivesoftware.smack.SmackException.NoResponseException;
import org.jivesoftware.smack.SmackException.NotConnectedException;
import org.jivesoftware.smack.SmackFuture;
import org.jivesoftware.smack.SmackFuture.InternalSmackFuture;
import org.jivesoftware.smack.XMPPException;
import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnection.ConnectedButUnauthenticatedStateDescriptor;
import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnection.LookupRemoteConnectionEndpointsStateDescriptor;
import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnectionConfiguration;
import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnectionModule;
import org.jivesoftware.smack.c2s.StreamOpenAndCloseFactory;
import org.jivesoftware.smack.c2s.XmppClientToServerTransport;
import org.jivesoftware.smack.c2s.internal.ModularXmppClientToServerConnectionInternal;
import org.jivesoftware.smack.c2s.internal.WalkStateGraphContext;
import org.jivesoftware.smack.fsm.State;
import org.jivesoftware.smack.fsm.StateDescriptor;
import org.jivesoftware.smack.fsm.StateTransitionResult;
import org.jivesoftware.smack.fsm.StateTransitionResult.AttemptResult;
import org.jivesoftware.smack.packet.AbstractStreamClose;
import org.jivesoftware.smack.packet.AbstractStreamOpen;
import org.jivesoftware.smack.packet.TopLevelStreamElement;
import org.jivesoftware.smack.util.StringUtils;
import org.jivesoftware.smack.util.rce.RemoteConnectionEndpointLookupFailure;
import org.jivesoftware.smack.websocket.XmppWebSocketTransportModule.XmppWebSocketTransport.DiscoveredWebSocketEndpoints;
import org.jivesoftware.smack.websocket.elements.WebSocketCloseElement;
import org.jivesoftware.smack.websocket.elements.WebSocketOpenElement;
import org.jivesoftware.smack.websocket.impl.AbstractWebSocket;
import org.jivesoftware.smack.websocket.impl.WebSocketFactory;
import org.jivesoftware.smack.websocket.impl.WebSocketFactoryService;
import org.jivesoftware.smack.websocket.rce.InsecureWebSocketRemoteConnectionEndpoint;
import org.jivesoftware.smack.websocket.rce.SecureWebSocketRemoteConnectionEndpoint;
import org.jivesoftware.smack.websocket.rce.WebSocketRemoteConnectionEndpoint;
import org.jivesoftware.smack.websocket.rce.WebSocketRemoteConnectionEndpointLookup;
import org.jivesoftware.smack.websocket.rce.WebSocketRemoteConnectionEndpointLookup.Result;

import org.jxmpp.jid.DomainBareJid;

/**
 * The websocket transport module that goes with Smack's modular architecture.
 */
public final class XmppWebSocketTransportModule
                extends ModularXmppClientToServerConnectionModule<XmppWebSocketTransportModuleDescriptor> {

    private static final int WEBSOCKET_NORMAL_CLOSURE = 1000;

    private static final int[] KNOWN_SECURE_PORTS = new int[] {5443, 5281, 7443};
    private static final int[] KNOWN_INSECURE_PORTS = new int[] {5443, 5280, 7070};

    private final XmppWebSocketTransport websocketTransport;

    private AbstractWebSocket websocket;

    XmppWebSocketTransportModule(XmppWebSocketTransportModuleDescriptor moduleDescriptor,
                    ModularXmppClientToServerConnectionInternal connectionInternal) {
        super(moduleDescriptor, connectionInternal);

        websocketTransport = new XmppWebSocketTransport(connectionInternal);
    }

    @Override
    protected XmppWebSocketTransport getTransport() {
        return websocketTransport;
    }

    static final class EstablishingWebSocketConnectionStateDescriptor extends StateDescriptor {
        private EstablishingWebSocketConnectionStateDescriptor() {
            super(XmppWebSocketTransportModule.EstablishingWebSocketConnectionState.class);
            addPredeccessor(LookupRemoteConnectionEndpointsStateDescriptor.class);
            addSuccessor(ConnectedButUnauthenticatedStateDescriptor.class);

            // This states preference to TCP transports over this WebSocket transport implementation.
            declareInferiorityTo("org.jivesoftware.smack.tcp.XmppTcpTransportModule$EstablishingTcpConnectionStateDescriptor");
        }

        @Override
        protected State constructState(ModularXmppClientToServerConnectionInternal connectionInternal) {
            XmppWebSocketTransportModule websocketTransportModule = connectionInternal.connection.getConnectionModuleFor(
                            XmppWebSocketTransportModuleDescriptor.class);
            return websocketTransportModule.constructEstablishingWebSocketConnectionState(this, connectionInternal);
        }
    }

    final class EstablishingWebSocketConnectionState extends State.AbstractTransport {
        EstablishingWebSocketConnectionState(StateDescriptor stateDescriptor,
                        ModularXmppClientToServerConnectionInternal connectionInternal) {
            super(websocketTransport, stateDescriptor, connectionInternal);
        }

        @Override
        public AttemptResult transitionInto(WalkStateGraphContext walkStateGraphContext) throws InterruptedException,
                        NoResponseException, NotConnectedException, SmackException, XMPPException {
            final WebSocketFactory webSocketFactory;
            if (moduleDescriptor.webSocketFactory != null) {
                webSocketFactory = moduleDescriptor.webSocketFactory;
            } else {
                webSocketFactory = WebSocketFactoryService::createWebSocket;
            }

            WebSocketConnectionAttemptState connectionAttemptState = new WebSocketConnectionAttemptState(
                            connectionInternal, discoveredWebSocketEndpoints, webSocketFactory);

            StateTransitionResult.Failure failure = connectionAttemptState.establishWebSocketConnection();
            if (failure != null) {
                return failure;
            }

            websocket = connectionAttemptState.getConnectedWebSocket();

            connectionInternal.setTransport(websocketTransport);

            // TODO: It appears this should be done in a generic way. I'd assume we always
            // have to wait for stream features after the connection was established. But I
            // am not yet 100% positive that this is the case for every transport. Hence keep it here for now(?).
            // See also similar comment in XmppTcpTransportModule.
            // Maybe move this into ConnectedButUnauthenticated state's transitionInto() method? That seems to be the
            // right place.
            connectionInternal.newStreamOpenWaitForFeaturesSequence("stream features after initial connection");

            // Construct a WebSocketConnectedResult using the connected endpoint.
            return new WebSocketConnectedResult(websocket.getEndpoint());
        }
    }

    public EstablishingWebSocketConnectionState constructEstablishingWebSocketConnectionState(
                    EstablishingWebSocketConnectionStateDescriptor establishingWebSocketConnectionStateDescriptor,
                    ModularXmppClientToServerConnectionInternal connectionInternal) {
        return new EstablishingWebSocketConnectionState(establishingWebSocketConnectionStateDescriptor,
                        connectionInternal);
    }

    public static final class WebSocketConnectedResult extends StateTransitionResult.Success {
        final WebSocketRemoteConnectionEndpoint connectedEndpoint;

        public WebSocketConnectedResult(WebSocketRemoteConnectionEndpoint connectedEndpoint) {
            super("WebSocket connection establised with endpoint: " + connectedEndpoint);
            this.connectedEndpoint = connectedEndpoint;
        }
    }

    private DiscoveredWebSocketEndpoints discoveredWebSocketEndpoints;

    /**
     * Transport class for {@link ModularXmppClientToServerConnectionModule}'s websocket implementation.
     */
    public final class XmppWebSocketTransport extends XmppClientToServerTransport {

        AsyncButOrdered<Queue<TopLevelStreamElement>> asyncButOrderedOutgoingElementsQueue;

        XmppWebSocketTransport(ModularXmppClientToServerConnectionInternal connectionInternal) {
            super(connectionInternal);
            asyncButOrderedOutgoingElementsQueue = new AsyncButOrdered<Queue<TopLevelStreamElement>>();
        }

        @Override
        protected void resetDiscoveredConnectionEndpoints() {
            discoveredWebSocketEndpoints = null;
        }

        @Override
        public boolean hasUseableConnectionEndpoints() {
            return discoveredWebSocketEndpoints != null;
        }

        @SuppressWarnings("incomplete-switch")
        @Override
        protected List<SmackFuture<LookupConnectionEndpointsResult, Exception>> lookupConnectionEndpoints() {
            // Assert that there are no stale discovered endpoints prior performing the lookup.
            assert discoveredWebSocketEndpoints == null;

            InternalSmackFuture<LookupConnectionEndpointsResult, Exception> websocketEndpointsLookupFuture = new InternalSmackFuture<>();

            connectionInternal.asyncGo(() -> {
                Result result = null;

                ModularXmppClientToServerConnectionConfiguration configuration = connectionInternal.connection.getConfiguration();
                DomainBareJid host = configuration.getXMPPServiceDomain();

                if (moduleDescriptor.isWebSocketEndpointDiscoveryEnabled()) {
                    // Fetch remote endpoints.
                    result = WebSocketRemoteConnectionEndpointLookup.lookup(host);
                }

                WebSocketRemoteConnectionEndpoint providedEndpoint = moduleDescriptor.getExplicitlyProvidedEndpoint();
                if (providedEndpoint != null) {
                    // If there was not automatic lookup that produced a result, then create a result now.
                    if (result == null) {
                        result = new Result();
                    }

                    // We insert the provided endpoint at the beginning of the list, so that it is used first.
                    final int INSERT_INDEX = 0;
                    if (providedEndpoint instanceof SecureWebSocketRemoteConnectionEndpoint) {
                        SecureWebSocketRemoteConnectionEndpoint secureEndpoint = (SecureWebSocketRemoteConnectionEndpoint) providedEndpoint;
                        result.discoveredSecureEndpoints.add(INSERT_INDEX, secureEndpoint);
                    } else if (providedEndpoint instanceof InsecureWebSocketRemoteConnectionEndpoint) {
                        InsecureWebSocketRemoteConnectionEndpoint insecureEndpoint = (InsecureWebSocketRemoteConnectionEndpoint) providedEndpoint;
                        result.discoveredInsecureEndpoints.add(INSERT_INDEX, insecureEndpoint);
                    } else {
                        throw new AssertionError();
                    }
                }

                if (moduleDescriptor.isImplicitWebSocketEndpointEnabled()) {
                    for (final int securePort : KNOWN_SECURE_PORTS) {
                        SecureWebSocketRemoteConnectionEndpoint implicitSecureEndpoint = SecureWebSocketRemoteConnectionEndpoint.from(
                            WebSocketRemoteConnectionEndpoint.SECURE_WEB_SOCKET_SCHEME + "://" + host + ":" + securePort + "/ws");
                        result.discoveredSecureEndpoints.add(implicitSecureEndpoint);
                    }
                    for (final int insecurePort : KNOWN_INSECURE_PORTS) {
                        InsecureWebSocketRemoteConnectionEndpoint implicitInsecureEndpoint = InsecureWebSocketRemoteConnectionEndpoint.from(
                            WebSocketRemoteConnectionEndpoint.INSECURE_WEB_SOCKET_SCHEME + "://" + host + ":" + insecurePort + "/ws");
                        result.discoveredInsecureEndpoints.add(implicitInsecureEndpoint);
                    }
                }

                final LookupConnectionEndpointsResult endpointsResult;
                if (result.isEmpty()) {
                    endpointsResult = new WebSocketEndpointsDiscoveryFailed(result.lookupFailures);
                } else {
                    endpointsResult = new DiscoveredWebSocketEndpoints(result);
                }

                websocketEndpointsLookupFuture.setResult(endpointsResult);
            });

            return Collections.singletonList(websocketEndpointsLookupFuture);
        }

        @Override
        protected void loadConnectionEndpoints(LookupConnectionEndpointsSuccess lookupConnectionEndpointsSuccess) {
            discoveredWebSocketEndpoints = (DiscoveredWebSocketEndpoints) lookupConnectionEndpointsSuccess;
        }

        @Override
        protected void afterFiltersClosed() {
        }

        @Override
        protected void disconnect() {
            websocket.disconnect(WEBSOCKET_NORMAL_CLOSURE, "WebSocket closed normally");
        }

        @Override
        protected void notifyAboutNewOutgoingElements() {
            final Queue<TopLevelStreamElement> outgoingElementsQueue = connectionInternal.outgoingElementsQueue;
            asyncButOrderedOutgoingElementsQueue.performAsyncButOrdered(outgoingElementsQueue, () -> {
                for (TopLevelStreamElement topLevelStreamElement; (topLevelStreamElement = outgoingElementsQueue.poll()) != null;) {
                    websocket.send(topLevelStreamElement);
                    connectionInternal.fireFirstLevelElementSendListeners(topLevelStreamElement);
                }
            });
        }

        @Override
        public SSLSession getSslSession() {
           return websocket.getSSLSession();
        }

        @Override
        public boolean isTransportSecured() {
            return websocket.isConnectionSecure();
        }

        @Override
        public Stats getStats() {
            return null;
        }

        @Override
        public StreamOpenAndCloseFactory getStreamOpenAndCloseFactory() {
            // TODO: Create extra class for this?
            return new StreamOpenAndCloseFactory() {
                @Override
                public AbstractStreamOpen createStreamOpen(DomainBareJid to, CharSequence from, String id, String lang) {
                    return new WebSocketOpenElement(to);
                }
                @Override
                public AbstractStreamClose createStreamClose() {
                    return new WebSocketCloseElement();
                }
            };
        }

        /**
         * Contains {@link Result} for successfully discovered endpoints.
         */
        public final class DiscoveredWebSocketEndpoints implements LookupConnectionEndpointsSuccess {
            final WebSocketRemoteConnectionEndpointLookup.Result result;

            DiscoveredWebSocketEndpoints(Result result) {
                assert result != null;
                this.result = result;
            }
        }

        /**
         * Contains list of {@link RemoteConnectionEndpointLookupFailure} when no endpoint
         * could be found during http lookup.
         */
        final class WebSocketEndpointsDiscoveryFailed implements LookupConnectionEndpointsFailed {
            final List<RemoteConnectionEndpointLookupFailure> lookupFailures;

            WebSocketEndpointsDiscoveryFailed(RemoteConnectionEndpointLookupFailure lookupFailure) {
                this(Collections.singletonList(lookupFailure));
            }

            WebSocketEndpointsDiscoveryFailed(List<RemoteConnectionEndpointLookupFailure> lookupFailures) {
                assert lookupFailures != null;
                this.lookupFailures = Collections.unmodifiableList(lookupFailures);
            }

            @Override
            public String toString() {
                StringBuilder str = new StringBuilder();
                StringUtils.appendTo(lookupFailures, str);
                return str.toString();
            }
        }
    }
}