001/**
002 *
003 * Copyright 2019-2020 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.tcp;
018
019import java.io.IOException;
020import java.net.InetSocketAddress;
021import java.nio.Buffer;
022import java.nio.ByteBuffer;
023import java.nio.channels.ClosedChannelException;
024import java.nio.channels.SelectableChannel;
025import java.nio.channels.SelectionKey;
026import java.nio.channels.SocketChannel;
027import java.security.cert.CertificateException;
028import java.util.ArrayList;
029import java.util.Collection;
030import java.util.Collections;
031import java.util.IdentityHashMap;
032import java.util.Iterator;
033import java.util.List;
034import java.util.ListIterator;
035import java.util.Map;
036import java.util.Map.Entry;
037import java.util.concurrent.atomic.AtomicInteger;
038import java.util.concurrent.atomic.AtomicLong;
039import java.util.concurrent.locks.ReentrantLock;
040import java.util.logging.Level;
041import java.util.logging.Logger;
042
043import javax.net.ssl.SSLEngine;
044import javax.net.ssl.SSLEngineResult;
045import javax.net.ssl.SSLException;
046import javax.net.ssl.SSLSession;
047
048import org.jivesoftware.smack.ConnectionConfiguration.SecurityMode;
049import org.jivesoftware.smack.SmackException;
050import org.jivesoftware.smack.SmackException.SecurityRequiredByClientException;
051import org.jivesoftware.smack.SmackException.SecurityRequiredByServerException;
052import org.jivesoftware.smack.SmackException.SmackCertificateException;
053import org.jivesoftware.smack.SmackFuture;
054import org.jivesoftware.smack.SmackFuture.InternalSmackFuture;
055import org.jivesoftware.smack.SmackReactor.SelectionKeyAttachment;
056import org.jivesoftware.smack.XMPPException;
057import org.jivesoftware.smack.XmppInputOutputFilter;
058import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnection.ConnectedButUnauthenticatedStateDescriptor;
059import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnection.LookupRemoteConnectionEndpointsStateDescriptor;
060import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnectionModule;
061import org.jivesoftware.smack.c2s.XmppClientToServerTransport;
062import org.jivesoftware.smack.c2s.internal.ModularXmppClientToServerConnectionInternal;
063import org.jivesoftware.smack.c2s.internal.WalkStateGraphContext;
064import org.jivesoftware.smack.debugger.SmackDebugger;
065import org.jivesoftware.smack.fsm.State;
066import org.jivesoftware.smack.fsm.StateDescriptor;
067import org.jivesoftware.smack.fsm.StateTransitionResult;
068import org.jivesoftware.smack.internal.SmackTlsContext;
069import org.jivesoftware.smack.packet.Stanza;
070import org.jivesoftware.smack.packet.StartTls;
071import org.jivesoftware.smack.packet.StreamOpen;
072import org.jivesoftware.smack.packet.TlsFailure;
073import org.jivesoftware.smack.packet.TlsProceed;
074import org.jivesoftware.smack.packet.TopLevelStreamElement;
075import org.jivesoftware.smack.packet.XmlEnvironment;
076import org.jivesoftware.smack.tcp.XmppTcpTransportModule.XmppTcpNioTransport.DiscoveredTcpEndpoints;
077import org.jivesoftware.smack.tcp.rce.RemoteXmppTcpConnectionEndpoints;
078import org.jivesoftware.smack.tcp.rce.RemoteXmppTcpConnectionEndpoints.Result;
079import org.jivesoftware.smack.tcp.rce.Rfc6120TcpRemoteConnectionEndpoint;
080import org.jivesoftware.smack.util.CollectionUtil;
081import org.jivesoftware.smack.util.PacketParserUtils;
082import org.jivesoftware.smack.util.StringUtils;
083import org.jivesoftware.smack.util.UTF8;
084import org.jivesoftware.smack.util.XmlStringBuilder;
085import org.jivesoftware.smack.util.rce.RemoteConnectionEndpointLookupFailure;
086import org.jivesoftware.smack.xml.XmlPullParser;
087import org.jivesoftware.smack.xml.XmlPullParserException;
088
089import org.jxmpp.jid.Jid;
090import org.jxmpp.jid.util.JidUtil;
091import org.jxmpp.xml.splitter.Utf8ByteXmppXmlSplitter;
092import org.jxmpp.xml.splitter.XmlPrettyPrinter;
093import org.jxmpp.xml.splitter.XmlPrinter;
094import org.jxmpp.xml.splitter.XmppElementCallback;
095import org.jxmpp.xml.splitter.XmppXmlSplitter;
096
097public class XmppTcpTransportModule extends ModularXmppClientToServerConnectionModule<XmppTcpTransportModuleDescriptor> {
098
    private static final Logger LOGGER = Logger.getLogger(XmppTcpTransportModule.class.getName());

    // Upper bounds on the bytes read and written within a single onChannelSelected() invocation (10 MiB each). Once
    // exceeded, the callback leaves its read- or write-loop so that one connection cannot dominate the reactor thread.
    private static final int CALLBACK_MAX_BYTES_READ = 10 * 1024 * 1024;
    private static final int CALLBACK_MAX_BYTES_WRITEN = CALLBACK_MAX_BYTES_READ;

    // Maximum element size handed to the incoming XmppXmlSplitter (64 * 1024).
    // NOTE(review): the unit (bytes vs. chars) is whatever XmppXmlSplitter uses — confirm.
    private static final int MAX_ELEMENT_SIZE = 64 * 1024;

    // The XmppClientToServerTransport facade of this module, created in the constructor.
    private final XmppTcpNioTransport tcpNioTransport;

    // Channel state. selectionKey is the key whose interest ops are updated at the end of onChannelSelected();
    // selectionKeyAttachment is used to query and reset the "reactor thread racing" flag.
    private SelectionKey selectionKey;
    private SelectionKeyAttachment selectionKeyAttachment;
    private SocketChannel socketChannel;
    private InetSocketAddress remoteAddress;

    // TLS state. May be null, in which case getSslSession() returns null and isTransportSecured() returns false.
    private TlsState tlsState;

    // Remaining serialized parts of the element currently being written; null when no element is in progress.
    private Iterator<CharSequence> outgoingCharSequenceIterator;

    // Elements whose last part has been fed to the output filters but which are not yet associated with a filtered
    // network buffer.
    private final List<TopLevelStreamElement> currentlyOutgoingElements = new ArrayList<>();
    // Maps a filtered network buffer to the elements completed within it, so that the elements' send listeners are
    // fired once the buffer is pruned from networkOutgoingBuffers (presumably after it has been fully written).
    // Identity-based because ByteBuffer.equals() compares the buffers' remaining content, not the instances.
    private final Map<ByteBuffer, List<TopLevelStreamElement>> bufferToElementMap = new IdentityHashMap<>();

    // Outgoing data pipeline: the UTF-8 encoded element part before the output filters, the output of the filter
    // chain, and the filtered buffers queued for the next gathering write to the socket (plus their remaining bytes).
    private ByteBuffer outgoingBuffer;
    private ByteBuffer filteredOutgoingBuffer;
    private final List<ByteBuffer> networkOutgoingBuffers = new ArrayList<>();
    private long networkOutgoingBuffersBytes;

    // TODO: Make the size of the incomingBuffer configurable.
    // Direct buffer (2 * 4096 bytes) that is cleared and reused for every socket read in onChannelSelected().
    private final ByteBuffer incomingBuffer = ByteBuffer.allocateDirect(2 * 4096);

    // Ensures only one thread processes onChannelSelected() at a time. It is acquired with tryLock(), so concurrent
    // invocations are rejected (and counted) instead of blocking.
    private final ReentrantLock channelSelectedCallbackLock = new ReentrantLock();

    // Statistics. The byte and callback counters are updated in onChannelSelected() while channelSelectedCallbackLock
    // is held. NOTE(review): sslEngineDelegatedTasks/maxPendingSslEngineDelegatedTasks are presumably maintained by
    // the TLS code — confirm their synchronization there.
    private long totalBytesRead;
    private long totalBytesWritten;
    private long totalBytesReadAfterFilter;
    private long totalBytesWrittenBeforeFilter;
    private long handledChannelSelectedCallbacks;
    private long callbackPreemtBecauseBytesWritten;
    private long callbackPreemtBecauseBytesRead;
    private int sslEngineDelegatedTasks;
    private int maxPendingSslEngineDelegatedTasks;

    // Statistics that are updated outside of channelSelectedCallbackLock (when the lock could not be acquired, after
    // it has been released, or from afterOutgoingElementsQueueModified()), hence atomic.
    // TODO: Use LongAdder once Smack's minimum Android API level is 24 or higher.
    private final AtomicLong setWriteInterestAfterChannelSelectedCallback = new AtomicLong();
    private final AtomicLong reactorThreadAlreadyRacing = new AtomicLong();
    private final AtomicLong afterOutgoingElementsQueueModifiedSetInterestOps = new AtomicLong();
    private final AtomicLong rejectedChannelSelectedCallbacks = new AtomicLong();

    // Destination of the last outgoing stanza. Used to detect a change of the destination address, which is passed
    // on to the output filters.
    private Jid lastDestinationAddress;

    // Set by afterFiltersClosed() to force one more pass through the input and output filters even without new data.
    // pendingOutputFilterData is additionally updated from the output filters' pendingFilterData results.
    private boolean pendingInputFilterData;
    private boolean pendingOutputFilterData;

    // Set when the output filters reported pending data but produced no output. The next read that yields data then
    // adds OP_WRITE, so that a filter (for example the SSLEngine) gets a chance to write.
    private boolean pendingWriteInterestAfterRead;

    /**
     * Splits the (filtered) incoming byte stream into top-level XMPP elements, which are reported to
     * {@link #xmppElementCallback}.
     * <p>
     * Note that this field is effectively final, but due to https://stackoverflow.com/q/30360824/194894 we have to declare it non-final.
     * </p>
     */
    private Utf8ByteXmppXmlSplitter splitter;

    /**
     * Feeds the outgoing stream to the debugger's outgoing pretty printer. Only set when a {@link SmackDebugger} is
     * configured, otherwise {@code null}.
     * <p>
     * Note that this field is effectively final, but due to https://stackoverflow.com/q/30360824/194894 we have to declare it non-final.
     * </p>
     */
    private XmppXmlSplitter outputDebugSplitter;

    // Log level used for the "stream opened" and "stream closed" log messages.
    private static final Level STREAM_OPEN_CLOSE_DEBUG_LOG_LEVEL = Level.FINER;
164
165    XmppTcpTransportModule(XmppTcpTransportModuleDescriptor moduleDescriptor, ModularXmppClientToServerConnectionInternal connectionInternal) {
166        super(moduleDescriptor, connectionInternal);
167
168        tcpNioTransport = new XmppTcpNioTransport(connectionInternal);
169
170        XmlPrinter incomingDebugPrettyPrinter = null;
171        final SmackDebugger debugger = connectionInternal.smackDebugger;
172        if (debugger != null) {
173            // Incoming stream debugging.
174            incomingDebugPrettyPrinter = XmlPrettyPrinter.builder()
175                    .setPrettyWriter(sb -> debugger.incomingStreamSink(sb))
176                    .build();
177
178            // Outgoing stream debugging.
179            XmlPrinter outgoingDebugPrettyPrinter = XmlPrettyPrinter.builder()
180                    .setPrettyWriter(sb -> debugger.outgoingStreamSink(sb))
181                    .build();
182            outputDebugSplitter = new XmppXmlSplitter(outgoingDebugPrettyPrinter);
183        }
184
185        XmppXmlSplitter xmppXmlSplitter = new XmppXmlSplitter(MAX_ELEMENT_SIZE, xmppElementCallback,
186                incomingDebugPrettyPrinter);
187        splitter = new Utf8ByteXmppXmlSplitter(xmppXmlSplitter);
188    }
189
190    private final XmppElementCallback xmppElementCallback = new XmppElementCallback() {
191        private String streamOpen;
192        private String streamClose;
193
194        @Override
195        public void onCompleteElement(String completeElement) {
196            assert streamOpen != null;
197            assert streamClose != null;
198
199            connectionInternal.withSmackDebugger(debugger -> debugger.onIncomingElementCompleted());
200
201            String wrappedCompleteElement = streamOpen + completeElement + streamClose;
202            connectionInternal.parseAndProcessElement(wrappedCompleteElement);
203        }
204
205
206        @Override
207        public void streamOpened(String prefix, Map<String, String> attributes) {
208            if (LOGGER.isLoggable(STREAM_OPEN_CLOSE_DEBUG_LOG_LEVEL)) {
209                LOGGER.log(STREAM_OPEN_CLOSE_DEBUG_LOG_LEVEL,
210                                "Stream of " + this + " opened. prefix=" + prefix + " attributes=" + attributes);
211            }
212
213            final String prefixXmlns = "xmlns:" + prefix;
214            final StringBuilder streamClose = new StringBuilder(32);
215            final StringBuilder streamOpen = new StringBuilder(256);
216
217            streamOpen.append('<');
218            streamClose.append("</");
219            if (StringUtils.isNotEmpty(prefix)) {
220                streamOpen.append(prefix).append(':');
221                streamClose.append(prefix).append(':');
222            }
223            streamOpen.append("stream");
224            streamClose.append("stream>");
225            for (Entry<String, String> entry : attributes.entrySet()) {
226                String attributeName = entry.getKey();
227                String attributeValue = entry.getValue();
228                switch (attributeName) {
229                case "to":
230                case "from":
231                case "id":
232                case "version":
233                    break;
234                case "xml:lang":
235                    streamOpen.append(" xml:lang='").append(attributeValue).append('\'');
236                    break;
237                case "xmlns":
238                    streamOpen.append(" xmlns='").append(attributeValue).append('\'');
239                    break;
240                default:
241                    if (attributeName.equals(prefixXmlns)) {
242                        streamOpen.append(' ').append(prefixXmlns).append("='").append(attributeValue).append('\'');
243                        break;
244                    }
245                    LOGGER.info("Unknown <stream/> attribute: " + attributeName);
246                    break;
247                }
248            }
249            streamOpen.append('>');
250
251            this.streamOpen = streamOpen.toString();
252            this.streamClose = streamClose.toString();
253
254            XmlPullParser streamOpenParser;
255            try {
256                streamOpenParser = PacketParserUtils.getParserFor(this.streamOpen);
257            } catch (XmlPullParserException | IOException e) {
258                // Should never happen.
259                throw new AssertionError(e);
260            }
261            connectionInternal.onStreamOpen(streamOpenParser);
262        }
263
264        @Override
265        public void streamClosed() {
266            if (LOGGER.isLoggable(STREAM_OPEN_CLOSE_DEBUG_LOG_LEVEL)) {
267                LOGGER.log(STREAM_OPEN_CLOSE_DEBUG_LOG_LEVEL, "Stream of " + this + " closed");
268            }
269
270           connectionInternal.onStreamClosed();
271        }
272    };
273
    /**
     * Reactor callback for this transport's socket channel.
     * <p>
     * First drains as much outgoing data as possible. Elements are polled from the outgoing queue and serialized.
     * Their parts are UTF-8 encoded, run through the output filters, and written to the socket with a gathering write.
     * Then it reads incoming data, runs it through the input filters in reverse order, and feeds it to the XML
     * splitter. Finally it updates the selection key's interest ops (always including {@code OP_READ}).
     * </p>
     * <p>
     * Only one thread processes the callback at a time. An invocation that fails to acquire
     * {@link #channelSelectedCallbackLock} returns immediately and is only counted as rejected.
     * </p>
     *
     * @param selectedChannel the selected channel, which is cast to {@link SocketChannel}.
     * @param selectedSelectionKey the selection key of the selected channel.
     */
    private void onChannelSelected(SelectableChannel selectedChannel, SelectionKey selectedSelectionKey) {
        assert selectionKey == null || selectionKey == selectedSelectionKey;
        SocketChannel selectedSocketChannel = (SocketChannel) selectedChannel;
        // We are *always* interested in OP_READ.
        int newInterestedOps = SelectionKey.OP_READ;
        boolean newPendingOutputFilterData = false;

        // Another thread is already processing this callback; do not block the reactor thread.
        if (!channelSelectedCallbackLock.tryLock()) {
            rejectedChannelSelectedCallbacks.incrementAndGet();
            return;
        }

        handledChannelSelectedCallbacks++;

        long callbackBytesRead = 0;
        long callbackBytesWritten = 0;

        try {
            boolean destinationAddressChanged = false;
            boolean isLastPartOfElement = false;
            TopLevelStreamElement currentlyOutgonigTopLevelStreamElement = null;
            StringBuilder outgoingStreamForDebugger = null;

            // Each iteration advances the outgoing pipeline by one step, checking the stages from the socket side
            // backwards: network buffers -> filter chain -> next part of the current element -> next queued element.
            writeLoop: while (true) {
                // Note that isLastPartOfElement starts as false, so this is true on the first iteration.
                final boolean moreDataAvailable = !isLastPartOfElement || !connectionInternal.outgoingElementsQueue.isEmpty();

                if (filteredOutgoingBuffer != null || !networkOutgoingBuffers.isEmpty()) {
                    if (filteredOutgoingBuffer != null) {
                        networkOutgoingBuffers.add(filteredOutgoingBuffer);
                        networkOutgoingBuffersBytes += filteredOutgoingBuffer.remaining();

                        filteredOutgoingBuffer = null;
                        // NOTE(review): networkOutgoingBuffers is now non-empty, so this 'continue' re-enters this
                        // same branch on the next iteration and performs the write anyway. It therefore does not
                        // seem to batch further outgoing data. Also, 8096 might have been meant as 8192. Confirm
                        // the intended batching behavior.
                        if (moreDataAvailable && networkOutgoingBuffersBytes < 8096) {
                            continue;
                        }
                    }

                    ByteBuffer[] output = networkOutgoingBuffers.toArray(new ByteBuffer[networkOutgoingBuffers.size()]);
                    long bytesWritten;
                    try {
                        bytesWritten = selectedSocketChannel.write(output);
                    } catch (IOException e) {
                        // We have seen here so far
                        // - IOException "Broken pipe"
                        handleReadWriteIoException(e);
                        break;
                    }

                    // Nothing could be written; wait until the channel becomes writable again.
                    if (bytesWritten == 0) {
                        newInterestedOps |= SelectionKey.OP_WRITE;
                        break;
                    }

                    callbackBytesWritten += bytesWritten;

                    networkOutgoingBuffersBytes -= bytesWritten;

                    // Fire the send listeners of the elements that were completed in the buffers pruned here.
                    List<? extends Buffer> prunedBuffers = pruneBufferList(networkOutgoingBuffers);

                    for (Buffer prunedBuffer : prunedBuffers) {
                        List<TopLevelStreamElement> sendElements = bufferToElementMap.remove(prunedBuffer);
                        if (sendElements == null) {
                            continue;
                        }
                        for (TopLevelStreamElement elementJustSend : sendElements) {
                            connectionInternal.fireFirstLevelElementSendListeners(elementJustSend);
                        }
                    }

                    // Prevent one callback from dominating the reactor thread. Break out of the write-loop if we have
                    // written a certain amount.
                    if (callbackBytesWritten > CALLBACK_MAX_BYTES_WRITEN) {
                        newInterestedOps |= SelectionKey.OP_WRITE;
                        callbackPreemtBecauseBytesWritten++;
                        break;
                    }
                } else if (outgoingBuffer != null || pendingOutputFilterData) {
                    // Run the encoded data (or, if outgoingBuffer is null, just the pending filter data) through
                    // the output filter chain.
                    pendingOutputFilterData = false;

                    if (outgoingBuffer != null) {
                        totalBytesWrittenBeforeFilter += outgoingBuffer.remaining();
                        if (isLastPartOfElement) {
                            assert currentlyOutgonigTopLevelStreamElement != null;
                            currentlyOutgoingElements.add(currentlyOutgonigTopLevelStreamElement);
                        }
                    }

                    ByteBuffer outputFilterInputData = outgoingBuffer;
                    // We can now null the outgoingBuffer since the filter step will take care of it from now on.
                    outgoingBuffer = null;

                    for (ListIterator<XmppInputOutputFilter> it = connectionInternal.getXmppInputOutputFilterBeginIterator(); it.hasNext();) {
                        XmppInputOutputFilter inputOutputFilter = it.next();
                        XmppInputOutputFilter.OutputResult outputResult;
                        try {
                            outputResult = inputOutputFilter.output(outputFilterInputData, isLastPartOfElement,
                                    destinationAddressChanged, moreDataAvailable);
                        } catch (IOException e) {
                            connectionInternal.notifyConnectionError(e);
                            break writeLoop;
                        }
                        newPendingOutputFilterData |= outputResult.pendingFilterData;
                        outputFilterInputData = outputResult.filteredOutputData;
                        if (outputFilterInputData != null) {
                            outputFilterInputData.flip();
                        }
                    }

                    // It is ok if outputFilterInputData is 'null' here, this is expected behavior.
                    if (outputFilterInputData != null && outputFilterInputData.hasRemaining()) {
                        filteredOutgoingBuffer = outputFilterInputData;
                    } else {
                        filteredOutgoingBuffer = null;
                    }

                    // If the filters did eventually not produce any output data but if there is
                    // pending output data then we have a pending write request after read.
                    if (filteredOutgoingBuffer == null && newPendingOutputFilterData) {
                        pendingWriteInterestAfterRead = true;
                    }

                    // Remember which completed elements end up in this network buffer, so that their send listeners
                    // can be fired once it has been written.
                    if (filteredOutgoingBuffer != null && isLastPartOfElement) {
                        bufferToElementMap.put(filteredOutgoingBuffer, new ArrayList<>(currentlyOutgoingElements));
                        currentlyOutgoingElements.clear();
                    }

                    // Reset that the destination address has changed.
                    if (destinationAddressChanged) {
                        destinationAddressChanged = false;
                    }
                } else if (outgoingCharSequenceIterator != null) {
                    // Encode the next part of the element currently being serialized.
                    CharSequence nextCharSequence = outgoingCharSequenceIterator.next();
                    outgoingBuffer = UTF8.encode(nextCharSequence);
                    if (!outgoingCharSequenceIterator.hasNext()) {
                        outgoingCharSequenceIterator = null;
                        isLastPartOfElement = true;
                    } else {
                        isLastPartOfElement = false;
                    }

                    // Hand the complete element to the debugger once its last part has been encoded.
                    final SmackDebugger debugger = connectionInternal.smackDebugger;
                    if (debugger != null) {
                        if (outgoingStreamForDebugger == null) {
                            outgoingStreamForDebugger = new StringBuilder();
                        }
                        outgoingStreamForDebugger.append(nextCharSequence);

                        if (isLastPartOfElement) {
                            try {
                                outputDebugSplitter.append(outgoingStreamForDebugger);
                            } catch (IOException e) {
                                throw new AssertionError(e);
                            }
                            debugger.onOutgoingElementCompleted();
                            outgoingStreamForDebugger = null;
                        }
                    }
                } else if (!connectionInternal.outgoingElementsQueue.isEmpty()) {
                    // Start serializing the next queued element and track changes of the stanza destination address.
                    currentlyOutgonigTopLevelStreamElement = connectionInternal.outgoingElementsQueue.poll();
                    if (currentlyOutgonigTopLevelStreamElement instanceof Stanza) {
                        Stanza currentlyOutgoingStanza = (Stanza) currentlyOutgonigTopLevelStreamElement;
                        Jid currentDestinationAddress = currentlyOutgoingStanza.getTo();
                        destinationAddressChanged = !JidUtil.equals(lastDestinationAddress, currentDestinationAddress);
                        lastDestinationAddress = currentDestinationAddress;
                    }
                    CharSequence nextCharSequence = currentlyOutgonigTopLevelStreamElement.toXML(StreamOpen.CLIENT_NAMESPACE);
                    if (nextCharSequence instanceof XmlStringBuilder) {
                        XmlStringBuilder xmlStringBuilder = (XmlStringBuilder) nextCharSequence;
                        XmlEnvironment outgoingStreamXmlEnvironment = connectionInternal.getOutgoingStreamXmlEnvironment();
                        outgoingCharSequenceIterator = xmlStringBuilder.toList(outgoingStreamXmlEnvironment).iterator();
                    } else {
                        outgoingCharSequenceIterator = Collections.singletonList(nextCharSequence).iterator();
                    }
                    assert outgoingCharSequenceIterator != null;
                } else {
                    // There is nothing more to write.
                    break;
                }
            }

            pendingOutputFilterData = newPendingOutputFilterData;
            if (!pendingWriteInterestAfterRead && pendingOutputFilterData) {
                newInterestedOps |= SelectionKey.OP_WRITE;
            }

            readLoop: while (true) {
                // Prevent one callback from dominating the reactor thread. Break out of the read-loop if we have
                // read a certain amount.
                if (callbackBytesRead > CALLBACK_MAX_BYTES_READ) {
                    callbackPreemtBecauseBytesRead++;
                    break;
                }

                int bytesRead;
                incomingBuffer.clear();
                try {
                    bytesRead = selectedSocketChannel.read(incomingBuffer);
                } catch (IOException e) {
                    handleReadWriteIoException(e);
                    // NOTE(review): this return (like the other returns below) skips resetReactorThreadRacing()
                    // and setInterestOps() after the try/finally — confirm that this is intended on error paths.
                    return;
                }

                if (bytesRead < 0) {
                    LOGGER.finer("NIO read() returned " + bytesRead
                            + " for " + this + ". This probably means that the TCP connection was terminated.");
                    // According to the socket channel javadoc section about "asynchronous reads" a socket channel's
                    // read() may return -1 if the input side of a socket is shut down.
                    // Note that we do not call notifyConnectionError() here because the connection may be
                    // cleanly shutdown which would also cause read() to return '-1'. I assume that this socket
                    // will be selected again, on which read() would throw an IOException, which will be caught
                    // and invoke notifyConnectionError() (see a few lines above).
                    /*
                    IOException exception = new IOException("NIO read() returned " + bytesRead);
                    notifyConnectionError(exception);
                    */
                    return;
                }

                // Without pending input filter data, a read of zero bytes means we are done. With pending input
                // filter data (set by afterFiltersClosed()), run the input filters once even if nothing was read.
                if (!pendingInputFilterData) {
                    if (bytesRead == 0) {
                        // Nothing more to read.
                        break;
                    }
                } else {
                    pendingInputFilterData = false;
                }

                if (pendingWriteInterestAfterRead) {
                    // We have successfully read something and someone announced a write interest after a read. It is
                    // now possible that a filter is now also able to write additional data (for example SSLEngine).
                    pendingWriteInterestAfterRead = false;
                    newInterestedOps |= SelectionKey.OP_WRITE;
                }

                callbackBytesRead += bytesRead;

                // Run the data through the input filters, starting with the last filter (the one closest to the
                // network). A filter returning null means it has no output yet, which ends the read-loop.
                ByteBuffer filteredIncomingBuffer = incomingBuffer;
                for (ListIterator<XmppInputOutputFilter> it = connectionInternal.getXmppInputOutputFilterEndIterator(); it.hasPrevious();) {
                    filteredIncomingBuffer.flip();

                    ByteBuffer newFilteredIncomingBuffer;
                    try {
                        newFilteredIncomingBuffer = it.previous().input(filteredIncomingBuffer);
                    } catch (IOException e) {
                        connectionInternal.notifyConnectionError(e);
                        return;
                    }
                    if (newFilteredIncomingBuffer == null) {
                        break readLoop;
                    }
                    filteredIncomingBuffer = newFilteredIncomingBuffer;
                }

                final int bytesReadAfterFilter = filteredIncomingBuffer.flip().remaining();

                totalBytesReadAfterFilter += bytesReadAfterFilter;

                try {
                    splitter.write(filteredIncomingBuffer);
                } catch (IOException e) {
                    connectionInternal.notifyConnectionError(e);
                    return;
                }
            }
        } finally {
            totalBytesWritten += callbackBytesWritten;
            totalBytesRead += callbackBytesRead;

            channelSelectedCallbackLock.unlock();
        }

        // Indicate that there is no reactor thread racing towards handling this selection key.
        final SelectionKeyAttachment selectionKeyAttachment = this.selectionKeyAttachment;
        if (selectionKeyAttachment != null) {
            selectionKeyAttachment.resetReactorThreadRacing();
        }

        // Check the queue again to prevent lost wakeups caused by elements inserted before we
        // called resetReactorThreadRacing() a few lines above.
        if (!connectionInternal.outgoingElementsQueue.isEmpty()) {
            setWriteInterestAfterChannelSelectedCallback.incrementAndGet();
            newInterestedOps |= SelectionKey.OP_WRITE;
        }

        connectionInternal.setInterestOps(selectionKey, newInterestedOps);
    }
560
561    private void handleReadWriteIoException(IOException e) {
562        if (e instanceof ClosedChannelException && !tcpNioTransport.isConnected()) {
563            // The connection is already closed.
564            return;
565        }
566
567       connectionInternal.notifyConnectionError(e);
568    }
569
570    /**
571     * This is the interface between the "lookup remote connection endpoints" state and the "establish TCP connection"
572     * state. The field is indirectly populated by {@link XmppTcpNioTransport#lookupConnectionEndpoints()} and consumed
573     * by {@link ConnectionAttemptState}.
574     */
575    DiscoveredTcpEndpoints discoveredTcpEndpoints;
576
577    final class XmppTcpNioTransport extends XmppClientToServerTransport {
578
579        protected XmppTcpNioTransport(ModularXmppClientToServerConnectionInternal connectionInternal) {
580            super(connectionInternal);
581        }
582
583        @Override
584        protected void resetDiscoveredConnectionEndpoints() {
585            discoveredTcpEndpoints = null;
586        }
587
588        @Override
589        protected List<SmackFuture<LookupConnectionEndpointsResult, Exception>> lookupConnectionEndpoints() {
590            // Assert that there are no stale discovered endpoints prior performing the lookup.
591            assert discoveredTcpEndpoints == null;
592
593            List<SmackFuture<LookupConnectionEndpointsResult, Exception>> futures = new ArrayList<>(2);
594
595            InternalSmackFuture<LookupConnectionEndpointsResult, Exception> tcpEndpointsLookupFuture = new InternalSmackFuture<>();
596            connectionInternal.asyncGo(() -> {
597                Result<Rfc6120TcpRemoteConnectionEndpoint> result = RemoteXmppTcpConnectionEndpoints.lookup(
598                                connectionInternal.connection.getConfiguration());
599
600                LookupConnectionEndpointsResult endpointsResult;
601                if (result.discoveredRemoteConnectionEndpoints.isEmpty()) {
602                    endpointsResult = new TcpEndpointDiscoveryFailed(result);
603                } else {
604                    endpointsResult = new DiscoveredTcpEndpoints(result);
605                }
606                tcpEndpointsLookupFuture.setResult(endpointsResult);
607            });
608            futures.add(tcpEndpointsLookupFuture);
609
610            if (moduleDescriptor.isDirectTlsEnabled()) {
611                // TODO: Implement this.
612                throw new IllegalArgumentException("DirectTLS is not implemented yet");
613            }
614
615            return futures;
616        }
617
618        @Override
619        protected void loadConnectionEndpoints(LookupConnectionEndpointsSuccess lookupConnectionEndpointsSuccess) {
620            // The API contract stats that we will be given the instance we handed out with lookupConnectionEndpoints,
621            // which must be of type DiscoveredTcpEndpoints here. Hence if we can not cast it, then there is an internal
622            // Smack error.
623            discoveredTcpEndpoints = (DiscoveredTcpEndpoints) lookupConnectionEndpointsSuccess;
624        }
625
626        @Override
627        protected void afterFiltersClosed() {
628            pendingInputFilterData = pendingOutputFilterData = true;
629            afterOutgoingElementsQueueModified();
630        }
631
632        @Override
633        protected void disconnect() {
634            XmppTcpTransportModule.this.closeSocketAndCleanup();
635        }
636
637        @Override
638        protected void notifyAboutNewOutgoingElements() {
639            afterOutgoingElementsQueueModified();
640        }
641
642        @Override
643        public SSLSession getSslSession() {
644            TlsState tlsState = XmppTcpTransportModule.this.tlsState;
645            if (tlsState == null) {
646                return null;
647            }
648
649            return tlsState.engine.getSession();
650        }
651
652        @Override
653        public boolean isConnected() {
654            SocketChannel socketChannel = XmppTcpTransportModule.this.socketChannel;
655            if (socketChannel == null) {
656                return false;
657            }
658
659            return socketChannel.isConnected();
660        }
661
662        @Override
663        public boolean isTransportSecured() {
664            final TlsState tlsState = XmppTcpTransportModule.this.tlsState;
665            return tlsState != null && tlsState.handshakeStatus == TlsHandshakeStatus.successful;
666        }
667
668        @Override
669        public XmppTcpTransportModule.Stats getStats() {
670            return XmppTcpTransportModule.this.getStats();
671        }
672
673        final class DiscoveredTcpEndpoints implements LookupConnectionEndpointsSuccess {
674            final RemoteXmppTcpConnectionEndpoints.Result<Rfc6120TcpRemoteConnectionEndpoint> result;
675            DiscoveredTcpEndpoints(RemoteXmppTcpConnectionEndpoints.Result<Rfc6120TcpRemoteConnectionEndpoint> result) {
676                this.result = result;
677            }
678        }
679
680        final class TcpEndpointDiscoveryFailed implements LookupConnectionEndpointsFailed {
681            final List<RemoteConnectionEndpointLookupFailure> lookupFailures;
682            TcpEndpointDiscoveryFailed(RemoteXmppTcpConnectionEndpoints.Result<Rfc6120TcpRemoteConnectionEndpoint> result) {
683                lookupFailures = result.lookupFailures;
684            }
685        }
686    }
687
688    private void afterOutgoingElementsQueueModified() {
689        final SelectionKeyAttachment selectionKeyAttachment = this.selectionKeyAttachment;
690        if (selectionKeyAttachment != null && selectionKeyAttachment.isReactorThreadRacing()) {
691            // A reactor thread is already racing to the channel selected callback and will take care of this.
692            reactorThreadAlreadyRacing.incrementAndGet();
693            return;
694        }
695
696        afterOutgoingElementsQueueModifiedSetInterestOps.incrementAndGet();
697
698        // Add OP_WRITE to the interested Ops, since we have now new things to write. Note that this may cause
699        // multiple reactor threads to race to the channel selected callback in case we perform this right after
700        // a select() returned with this selection key in the selected-key set. Hence we use tryLock() in the
701        // channel selected callback to keep the invariant that only exactly one thread is performing the
702        // callback.
703        // Note that we need to perform setInterestedOps() *without* holding the channelSelectedCallbackLock, as
704        // otherwise the reactor thread racing to the channel selected callback may found the lock still locked, which
705        // would result in the outgoingElementsQueue not being handled.
706        connectionInternal.setInterestOps(selectionKey, SelectionKey.OP_WRITE | SelectionKey.OP_READ);
707    }
708
709    @Override
710    protected XmppTcpNioTransport getTransport() {
711        return tcpNioTransport;
712    }
713
714    static final class EstablishingTcpConnectionStateDescriptor extends StateDescriptor {
715        private EstablishingTcpConnectionStateDescriptor() {
716            super(XmppTcpTransportModule.EstablishingTcpConnectionState.class);
717            addPredeccessor(LookupRemoteConnectionEndpointsStateDescriptor.class);
718            addSuccessor(EstablishTlsStateDescriptor.class);
719            addSuccessor(ConnectedButUnauthenticatedStateDescriptor.class);
720        }
721
722        @Override
723        protected XmppTcpTransportModule.EstablishingTcpConnectionState constructState(ModularXmppClientToServerConnectionInternal connectionInternal) {
724            XmppTcpTransportModule tcpTransportModule = connectionInternal.connection.getConnectionModuleFor(XmppTcpTransportModuleDescriptor.class);
725            return tcpTransportModule.constructEstablishingTcpConnectionState(this, connectionInternal);
726        }
727    }
728
729    private EstablishingTcpConnectionState constructEstablishingTcpConnectionState(
730                    EstablishingTcpConnectionStateDescriptor stateDescriptor,
731                    ModularXmppClientToServerConnectionInternal connectionInternal) {
732        return new EstablishingTcpConnectionState(stateDescriptor, connectionInternal);
733    }
734
735    final class EstablishingTcpConnectionState extends State {
736        private EstablishingTcpConnectionState(EstablishingTcpConnectionStateDescriptor stateDescriptor,
737                        ModularXmppClientToServerConnectionInternal connectionInternal) {
738            super(stateDescriptor, connectionInternal);
739        }
740
741        @Override
742        public StateTransitionResult.AttemptResult transitionInto(WalkStateGraphContext walkStateGraphContext)
743                        throws InterruptedException, IOException, SmackException, XMPPException {
744            // The fields inetSocketAddress and failedAddresses are handed over from LookupHostAddresses to
745            // ConnectingToHost.
746            ConnectionAttemptState connectionAttemptState = new ConnectionAttemptState(connectionInternal, discoveredTcpEndpoints,
747                    this);
748            StateTransitionResult.Failure failure = connectionAttemptState.establishTcpConnection();
749            if (failure != null) {
750                return failure;
751            }
752
753            socketChannel = connectionAttemptState.socketChannel;
754            remoteAddress = (InetSocketAddress) socketChannel.socket().getRemoteSocketAddress();
755
756            selectionKey = connectionInternal.registerWithSelector(socketChannel, SelectionKey.OP_READ,
757                            XmppTcpTransportModule.this::onChannelSelected);
758            selectionKeyAttachment = (SelectionKeyAttachment) selectionKey.attachment();
759
760            connectionInternal.setTransport(tcpNioTransport);
761
762            connectionInternal.newStreamOpenWaitForFeaturesSequence("stream features after initial connection");
763
764            return new TcpSocketConnectedResult(remoteAddress);
765        }
766
767        @Override
768        public void resetState() {
769            closeSocketAndCleanup();
770        }
771    }
772
773    public static final class TcpSocketConnectedResult extends StateTransitionResult.Success {
774        private final InetSocketAddress remoteAddress;
775
776        private TcpSocketConnectedResult(InetSocketAddress remoteAddress) {
777            super("TCP connection established to " + remoteAddress);
778            this.remoteAddress = remoteAddress;
779        }
780
781        public InetSocketAddress getRemoteAddress() {
782            return remoteAddress;
783        }
784    }
785
786    public static final class TlsEstablishedResult extends StateTransitionResult.Success {
787
788        private TlsEstablishedResult(SSLEngine sslEngine) {
789            super("TLS established: " + sslEngine.getSession());
790        }
791    }
792
793    static final class EstablishTlsStateDescriptor extends StateDescriptor {
794        private EstablishTlsStateDescriptor() {
795            super(XmppTcpTransportModule.EstablishTlsState.class, "RFC 6120 ยง 5");
796            addSuccessor(ConnectedButUnauthenticatedStateDescriptor.class);
797            declarePrecedenceOver(ConnectedButUnauthenticatedStateDescriptor.class);
798        }
799
800        @Override
801        protected EstablishTlsState constructState(ModularXmppClientToServerConnectionInternal connectionInternal) {
802            XmppTcpTransportModule tcpTransportModule = connectionInternal.connection.getConnectionModuleFor(XmppTcpTransportModuleDescriptor.class);
803            return tcpTransportModule.constructEstablishingTlsState(this, connectionInternal);
804        }
805    }
806
807    private EstablishTlsState constructEstablishingTlsState(
808                    EstablishTlsStateDescriptor stateDescriptor,
809                    ModularXmppClientToServerConnectionInternal connectionInternal) {
810        return new EstablishTlsState(stateDescriptor, connectionInternal);
811    }
812
813    private final class EstablishTlsState extends State {
814        private EstablishTlsState(EstablishTlsStateDescriptor stateDescriptor,
815                        ModularXmppClientToServerConnectionInternal connectionInternal) {
816            super(stateDescriptor, connectionInternal);
817        }
818
819        @Override
820        public StateTransitionResult.TransitionImpossible isTransitionToPossible(WalkStateGraphContext walkStateGraphContext)
821                throws SecurityRequiredByClientException, SecurityRequiredByServerException {
822            StartTls startTlsFeature = connectionInternal.connection.getFeature(StartTls.class);
823            SecurityMode securityMode = connectionInternal.connection.getConfiguration().getSecurityMode();
824
825            switch (securityMode) {
826            case required:
827            case ifpossible:
828                if (startTlsFeature == null) {
829                    if (securityMode == SecurityMode.ifpossible) {
830                        return new StateTransitionResult.TransitionImpossibleReason("Server does not announce support for TLS and we do not required it");
831                    }
832                    throw new SecurityRequiredByClientException();
833                }
834                // Allows transition by returning null.
835                return null;
836            case disabled:
837                if (startTlsFeature != null && startTlsFeature.required()) {
838                    throw new SecurityRequiredByServerException();
839                }
840                return new StateTransitionResult.TransitionImpossibleReason("TLS disabled in client settings and server does not require it");
841            default:
842                throw new AssertionError("Unknown security mode: " + securityMode);
843            }
844        }
845
846        @Override
847        public StateTransitionResult.AttemptResult transitionInto(WalkStateGraphContext walkStateGraphContext)
848                        throws IOException, InterruptedException, SmackException, XMPPException {
849            connectionInternal.sendAndWaitForResponse(StartTls.INSTANCE, TlsProceed.class, TlsFailure.class);
850
851            SmackTlsContext smackTlsContext = connectionInternal.getSmackTlsContext();
852
853            tlsState = new TlsState(smackTlsContext);
854            connectionInternal.addXmppInputOutputFilter(tlsState);
855
856            channelSelectedCallbackLock.lock();
857            try {
858                pendingOutputFilterData = true;
859                // The beginHandshake() is possibly not really required here, but it does not hurt either.
860                tlsState.engine.beginHandshake();
861                tlsState.handshakeStatus = TlsHandshakeStatus.initiated;
862            } finally {
863                channelSelectedCallbackLock.unlock();
864            }
865            connectionInternal.setInterestOps(selectionKey, SelectionKey.OP_WRITE | SelectionKey.OP_READ);
866
867            try {
868                tlsState.waitForHandshakeFinished();
869            } catch (CertificateException e) {
870                throw new SmackCertificateException(e);
871            }
872
873            connectionInternal.newStreamOpenWaitForFeaturesSequence("stream features after TLS established");
874
875            return new TlsEstablishedResult(tlsState.engine);
876        }
877
878        @Override
879        public void resetState() {
880            tlsState = null;
881        }
882    }
883
884    private enum TlsHandshakeStatus {
885        initial,
886        initiated,
887        successful,
888        failed,
889    }
890
    // Log level used by debugLogSslEngineResult() to log the SSLEngineResult of wrap() and unwrap() calls.
    private static final Level SSL_ENGINE_DEBUG_LOG_LEVEL = Level.FINEST;
892
893    private static void debugLogSslEngineResult(String operation, SSLEngineResult result) {
894        if (!LOGGER.isLoggable(SSL_ENGINE_DEBUG_LOG_LEVEL)) {
895            return;
896        }
897
898        LOGGER.log(SSL_ENGINE_DEBUG_LOG_LEVEL, "SSLEngineResult of " + operation + "(): " + result);
899    }
900
901    private final class TlsState implements XmppInputOutputFilter {
902
903        private static final int MAX_PENDING_OUTPUT_BYTES = 8096;
904
905        private final SmackTlsContext smackTlsContext;
906        private final SSLEngine engine;
907
908        private TlsHandshakeStatus handshakeStatus = TlsHandshakeStatus.initial;
909        private SSLException handshakeException;
910
911        private ByteBuffer myNetData;
912        private ByteBuffer peerAppData;
913
914        private final List<ByteBuffer> pendingOutputData = new ArrayList<>();
915        private int pendingOutputBytes;
916        private ByteBuffer pendingInputData;
917
918        private final AtomicInteger pendingDelegatedTasks = new AtomicInteger();
919
920        private long wrapInBytes;
921        private long wrapOutBytes;
922
923        private long unwrapInBytes;
924        private long unwrapOutBytes;
925
926        private TlsState(SmackTlsContext smackTlsContext) throws IOException {
927            this.smackTlsContext = smackTlsContext;
928
929            // Call createSSLEngine()'s variant with two parameters as this allows for TLS session resumption.
930
931            // Note that it is not really clear what the value of peer host should be. It could be A) the XMPP service's
932            // domainpart or B) the DNS name of the host we are connecting to (usually the DNS SRV RR target name). While
933            // the javadoc of createSSLEngine(String, int) indicates with "Some cipher suites (such as Kerberos) require
934            // remote hostname information, in which case peerHost needs to be specified." that A should be used. TLS
935            // session resumption may would need or at least benefit from B. Variant A would also be required if the
936            // String is used for certificate verification. And it appears at least likely that TLS session resumption
937            // would not be hurt by using variant A. Therefore we currently use variant A.
938            // TODO: Should we use the ACE representation of the XMPP service domain? Compare with f60e4055ec529f0b8160acedf13275592ab10a4b
939            // If yes, then we should probably introduce getXmppServiceDomainAceEncodedIfPossible().
940            String peerHost = connectionInternal.connection.getConfiguration().getXMPPServiceDomain().toString();
941            engine = smackTlsContext.sslContext.createSSLEngine(peerHost, remoteAddress.getPort());
942            engine.setUseClientMode(true);
943
944            SSLSession session = engine.getSession();
945            int applicationBufferSize = session.getApplicationBufferSize();
946            int packetBufferSize = session.getPacketBufferSize();
947
948            myNetData = ByteBuffer.allocateDirect(packetBufferSize);
949            peerAppData = ByteBuffer.allocate(applicationBufferSize);
950        }
951
952        @Override
953        public OutputResult output(ByteBuffer outputData, boolean isFinalDataOfElement, boolean destinationAddressChanged,
954                boolean moreDataAvailable) throws SSLException {
955            if (outputData != null) {
956                pendingOutputData.add(outputData);
957                pendingOutputBytes += outputData.remaining();
958                if (moreDataAvailable && pendingOutputBytes < MAX_PENDING_OUTPUT_BYTES) {
959                    return OutputResult.NO_OUTPUT;
960                }
961            }
962
963            ByteBuffer[] outputDataArray = pendingOutputData.toArray(new ByteBuffer[pendingOutputData.size()]);
964
965            myNetData.clear();
966
967            while (true) {
968                SSLEngineResult result;
969                try {
970                    result = engine.wrap(outputDataArray, myNetData);
971                } catch (SSLException e) {
972                    handleSslException(e);
973                    throw e;
974                }
975
976                debugLogSslEngineResult("wrap", result);
977
978                SSLEngineResult.Status engineResultStatus = result.getStatus();
979
980                pendingOutputBytes -= result.bytesConsumed();
981
982                if (engineResultStatus == SSLEngineResult.Status.OK) {
983                    wrapInBytes += result.bytesConsumed();
984                    wrapOutBytes += result.bytesProduced();
985
986                    SSLEngineResult.HandshakeStatus handshakeStatus = handleHandshakeStatus(result);
987                    switch (handshakeStatus) {
988                        case NEED_UNWRAP:
989                            // NEED_UNWRAP means that we need to receive something in order to continue the handshake. The
990                            // standard channelSelectedCallback logic will take care of this, as there is eventually always
991                            // a interest to read from the socket.
992                            break;
993                        case NEED_WRAP:
994                            // Same as need task: Cycle the reactor.
995                        case NEED_TASK:
996                            // Note that we also set pendingOutputFilterData in the OutputResult in the NEED_TASK case, as
997                            // we also want to retry the wrap() operation above in this case.
998                            return new OutputResult(true, myNetData);
999                        default:
1000                            break;
1001                    }
1002                }
1003
1004                switch (engineResultStatus) {
1005                case OK:
1006                    // No need to outputData.compact() here, since we do not reuse the buffer.
1007                    // Clean up the pending output data.
1008                    pruneBufferList(pendingOutputData);
1009                    return new OutputResult(!pendingOutputData.isEmpty(), myNetData);
1010                case CLOSED:
1011                    pendingOutputData.clear();
1012                    return OutputResult.NO_OUTPUT;
1013                case BUFFER_OVERFLOW:
1014                    LOGGER.warning("SSLEngine status BUFFER_OVERFLOW, this is hopefully uncommon");
1015                    int outputDataRemaining = outputData != null ? outputData.remaining() : 0;
1016                    int newCapacity = (int) (1.3 * outputDataRemaining);
1017                    // If newCapacity would not increase myNetData, then double it.
1018                    if (newCapacity <= myNetData.capacity()) {
1019                        newCapacity = 2 * myNetData.capacity();
1020                    }
1021                    ByteBuffer newMyNetData = ByteBuffer.allocateDirect(newCapacity);
1022                    myNetData.flip();
1023                    newMyNetData.put(myNetData);
1024                    myNetData = newMyNetData;
1025                    continue;
1026                case BUFFER_UNDERFLOW:
1027                    throw new IllegalStateException(
1028                            "Buffer underflow as result of SSLEngine.wrap() should never happen");
1029                }
1030            }
1031        }
1032
1033        @SuppressWarnings("ReferenceEquality")
1034        @Override
1035        public ByteBuffer input(ByteBuffer inputData) throws SSLException {
1036            ByteBuffer accumulatedData;
1037            if (pendingInputData == null) {
1038                accumulatedData = inputData;
1039            } else {
1040                assert pendingInputData != inputData;
1041
1042                int accumulatedDataBytes = pendingInputData.remaining() + inputData.remaining();
1043                accumulatedData = ByteBuffer.allocate(accumulatedDataBytes);
1044                accumulatedData.put(pendingInputData)
1045                               .put(inputData)
1046                               .flip();
1047                pendingInputData = null;
1048            }
1049
1050            peerAppData.clear();
1051
1052            while (true) {
1053                SSLEngineResult result;
1054                try {
1055                    result = engine.unwrap(accumulatedData, peerAppData);
1056                } catch (SSLException e) {
1057                    handleSslException(e);
1058                    throw e;
1059                }
1060
1061                debugLogSslEngineResult("unwrap", result);
1062
1063                SSLEngineResult.Status engineResultStatus = result.getStatus();
1064
1065                if (engineResultStatus == SSLEngineResult.Status.OK) {
1066                    unwrapInBytes += result.bytesConsumed();
1067                    unwrapOutBytes += result.bytesProduced();
1068
1069                    SSLEngineResult.HandshakeStatus handshakeStatus = handleHandshakeStatus(result);
1070                    switch (handshakeStatus) {
1071                    case NEED_TASK:
1072                        // A delegated task is asynchronously running. Take care of the remaining accumulatedData.
1073                        addAsPendingInputData(accumulatedData);
1074                        // Return here, as the async task created by handleHandshakeStatus will continue calling the
1075                        // cannelSelectedCallback.
1076                        return null;
1077                    case NEED_UNWRAP:
1078                        continue;
1079                    case NEED_WRAP:
1080                        // NEED_WRAP means that the SSLEngine needs to send data, probably without consuming data.
1081                        // We exploit here the fact that the channelSelectedCallback is single threaded and that the
1082                        // input processing is after the output processing.
1083                        addAsPendingInputData(accumulatedData);
1084                        // Note that it is ok that we the provided argument for pending input filter data to channel
1085                        // selected callback is false, as setPendingInputFilterData() will have set the internal state
1086                        // boolean accordingly.
1087                        connectionInternal.asyncGo(() -> callChannelSelectedCallback(false, true));
1088                        // Do not break here, but instead return and let the asynchronously invoked
1089                        // callChannelSelectedCallback() do its work.
1090                        return null;
1091                    default:
1092                        break;
1093                    }
1094                }
1095
1096                switch (engineResultStatus) {
1097                case OK:
1098                    // SSLEngine's unwrap() may not consume all bytes from the source buffer. If this is the case, then
1099                    // simply perform another unwrap until accumlatedData has no remaining bytes.
1100                    if (accumulatedData.hasRemaining()) {
1101                        continue;
1102                    }
1103                    return peerAppData;
1104                case CLOSED:
1105                    return null;
1106                case BUFFER_UNDERFLOW:
1107                    // There were not enough source bytes available to make a complete packet. Let it in
1108                    // pendingInputData. Note that we do not resize SSLEngine's source buffer - inputData in our case -
1109                    // as it is not possible.
1110                    addAsPendingInputData(accumulatedData);
1111                    return null;
1112                case BUFFER_OVERFLOW:
1113                    int applicationBufferSize = engine.getSession().getApplicationBufferSize();
1114                    assert peerAppData.remaining() < applicationBufferSize;
1115                    peerAppData = ByteBuffer.allocate(applicationBufferSize);
1116                    continue;
1117                }
1118            }
1119        }
1120
1121        private void addAsPendingInputData(ByteBuffer byteBuffer) {
1122            // Note that we can not simply write
1123            // pendingInputData = byteBuffer;
1124            // we have to copy the provided byte buffer, because it is possible that this byteBuffer is re-used by some
1125            // higher layer. That is, here 'byteBuffer' is typically 'incomingBuffer', which is a direct buffer only
1126            // allocated once per connection for performance reasons and hence re-used for read() calls.
1127            pendingInputData = ByteBuffer.allocate(byteBuffer.remaining());
1128            pendingInputData.put(byteBuffer).flip();
1129
1130            pendingInputFilterData = pendingInputData.hasRemaining();
1131        }
1132
1133        private SSLEngineResult.HandshakeStatus handleHandshakeStatus(SSLEngineResult sslEngineResult) {
1134            SSLEngineResult.HandshakeStatus handshakeStatus = sslEngineResult.getHandshakeStatus();
1135            switch (handshakeStatus) {
1136            case NEED_TASK:
1137                while (true) {
1138                    final Runnable delegatedTask = engine.getDelegatedTask();
1139                    if (delegatedTask == null) {
1140                        break;
1141                    }
1142                    sslEngineDelegatedTasks++;
1143                    int currentPendingDelegatedTasks = pendingDelegatedTasks.incrementAndGet();
1144                    if (currentPendingDelegatedTasks > maxPendingSslEngineDelegatedTasks) {
1145                        maxPendingSslEngineDelegatedTasks = currentPendingDelegatedTasks;
1146                    }
1147
1148                    Runnable wrappedDelegatedTask = () -> {
1149                        delegatedTask.run();
1150                        int wrappedCurrentPendingDelegatedTasks = pendingDelegatedTasks.decrementAndGet();
1151                        if (wrappedCurrentPendingDelegatedTasks == 0) {
1152                            callChannelSelectedCallback(true, true);
1153                        }
1154                    };
1155                    connectionInternal.asyncGo(wrappedDelegatedTask);
1156                }
1157                break;
1158            case FINISHED:
1159                onHandshakeFinished();
1160                break;
1161            default:
1162                break;
1163            }
1164
1165            SSLEngineResult.HandshakeStatus afterHandshakeStatus = engine.getHandshakeStatus();
1166            return afterHandshakeStatus;
1167        }
1168
1169        private void handleSslException(SSLException e) {
1170            handshakeException = e;
1171            handshakeStatus = TlsHandshakeStatus.failed;
1172            connectionInternal.notifyWaitingThreads();
1173        }
1174
1175        private void onHandshakeFinished() {
1176            handshakeStatus = TlsHandshakeStatus.successful;
1177            connectionInternal.notifyWaitingThreads();
1178        }
1179
1180        private boolean isHandshakeFinished() {
1181            return handshakeStatus == TlsHandshakeStatus.successful || handshakeStatus == TlsHandshakeStatus.failed;
1182        }
1183
1184        private void waitForHandshakeFinished() throws InterruptedException, CertificateException, SSLException, SmackException, XMPPException {
1185            connectionInternal.waitForConditionOrThrowConnectionException(() -> isHandshakeFinished(), "TLS handshake to finish");
1186
1187            if (handshakeStatus == TlsHandshakeStatus.failed) {
1188                throw handshakeException;
1189            }
1190
1191            assert handshakeStatus == TlsHandshakeStatus.successful;
1192
1193            if (smackTlsContext.daneVerifier != null) {
1194                smackTlsContext.daneVerifier.finish(engine.getSession());
1195            }
1196        }
1197
1198        @Override
1199        public Object getStats() {
1200            return new TlsStateStats(this);
1201        }
1202
1203        @Override
1204        public void closeInputOutput() {
1205            engine.closeOutbound();
1206            try {
1207                engine.closeInbound();
1208            } catch (SSLException e) {
1209                LOGGER.log(Level.FINEST,
1210                        "SSLException when closing inbound TLS session. This can likely be ignored if a possible truncation attack is suggested."
1211                        + " You may want to ask your XMPP server vendor to implement a clean TLS session shutdown sending close_notify after </stream>",
1212                        e);
1213            }
1214        }
1215
1216        @Override
1217        public void waitUntilInputOutputClosed() throws IOException, CertificateException, InterruptedException,
1218                SmackException, XMPPException {
1219            waitForHandshakeFinished();
1220        }
1221
1222        @Override
1223        public String getFilterName() {
1224            return "TLS (" + engine + ')';
1225        }
1226    }
1227
1228    public static final class TlsStateStats {
1229        public final long wrapInBytes;
1230        public final long wrapOutBytes;
1231        public final double wrapRatio;
1232
1233        public final long unwrapInBytes;
1234        public final long unwrapOutBytes;
1235        public final double unwrapRatio;
1236
1237        private TlsStateStats(TlsState tlsState) {
1238            wrapOutBytes = tlsState.wrapOutBytes;
1239            wrapInBytes = tlsState.wrapInBytes;
1240            wrapRatio = (double) wrapOutBytes / wrapInBytes;
1241
1242            unwrapOutBytes = tlsState.unwrapOutBytes;
1243            unwrapInBytes = tlsState.unwrapInBytes;
1244            unwrapRatio = (double) unwrapInBytes / unwrapOutBytes;
1245        }
1246
1247        private transient String toStringCache;
1248
1249        @Override
1250        public String toString() {
1251            if (toStringCache != null) {
1252                return toStringCache;
1253            }
1254
1255            toStringCache =
1256                      "wrap-in-bytes: " + wrapInBytes + '\n'
1257                    + "wrap-out-bytes: " + wrapOutBytes + '\n'
1258                    + "wrap-ratio: " + wrapRatio + '\n'
1259                    + "unwrap-in-bytes: " + unwrapInBytes + '\n'
1260                    + "unwrap-out-bytes: " + unwrapOutBytes + '\n'
1261                    + "unwrap-ratio: " + unwrapRatio + '\n'
1262                    ;
1263
1264            return toStringCache;
1265        }
1266    }
1267
1268    private void callChannelSelectedCallback(boolean setPendingInputFilterData, boolean setPendingOutputFilterData) {
1269        final SocketChannel channel = socketChannel;
1270        final SelectionKey key = selectionKey;
1271        if (channel == null || key == null) {
1272            LOGGER.info("Not calling channel selected callback because the connection was eventually disconnected");
1273            return;
1274        }
1275
1276        channelSelectedCallbackLock.lock();
1277        try {
1278            // Note that it is important that we send the pending(Input|Output)FilterData flags while holding the lock.
1279            if (setPendingInputFilterData) {
1280                pendingInputFilterData = true;
1281            }
1282            if (setPendingOutputFilterData) {
1283                pendingOutputFilterData = true;
1284            }
1285
1286            onChannelSelected(channel, key);
1287        } finally {
1288            channelSelectedCallbackLock.unlock();
1289        }
1290    }
1291
1292    private void closeSocketAndCleanup() {
1293        final SelectionKey selectionKey = this.selectionKey;
1294        if (selectionKey != null) {
1295            selectionKey.cancel();
1296        }
1297        final SocketChannel socketChannel = this.socketChannel;
1298        if (socketChannel != null) {
1299            try {
1300                socketChannel.close();
1301            } catch (IOException e) {
1302
1303            }
1304        }
1305
1306        this.selectionKey = null;
1307        this.socketChannel = null;
1308
1309        selectionKeyAttachment = null;
1310        remoteAddress = null;
1311    }
1312
1313    private static List<? extends Buffer> pruneBufferList(Collection<? extends Buffer> buffers) {
1314        return CollectionUtil.removeUntil(buffers, b -> b.hasRemaining());
1315    }
1316
1317    public XmppTcpTransportModule.Stats getStats() {
1318        return new Stats(this);
1319    }
1320
1321    public static final class Stats extends XmppClientToServerTransport.Stats {
1322        public final long totalBytesWritten;
1323        public final long totalBytesWrittenBeforeFilter;
1324        public final double writeRatio;
1325
1326        public final long totalBytesRead;
1327        public final long totalBytesReadAfterFilter;
1328        public final double readRatio;
1329
1330        public final long handledChannelSelectedCallbacks;
1331        public final long setWriteInterestAfterChannelSelectedCallback;
1332        public final long reactorThreadAlreadyRacing;
1333        public final long afterOutgoingElementsQueueModifiedSetInterestOps;
1334        public final long rejectedChannelSelectedCallbacks;
1335        public final long totalCallbackRequests;
1336        public final long callbackPreemtBecauseBytesWritten;
1337        public final long callbackPreemtBecauseBytesRead;
1338        public final int sslEngineDelegatedTasks;
1339        public final int maxPendingSslEngineDelegatedTasks;
1340
1341        private Stats(XmppTcpTransportModule connection) {
1342            totalBytesWritten = connection.totalBytesWritten;
1343            totalBytesWrittenBeforeFilter = connection.totalBytesWrittenBeforeFilter;
1344            writeRatio = (double) totalBytesWritten / totalBytesWrittenBeforeFilter;
1345
1346            totalBytesReadAfterFilter = connection.totalBytesReadAfterFilter;
1347            totalBytesRead = connection.totalBytesRead;
1348            readRatio = (double) totalBytesRead / totalBytesReadAfterFilter;
1349
1350            handledChannelSelectedCallbacks = connection.handledChannelSelectedCallbacks;
1351            setWriteInterestAfterChannelSelectedCallback = connection.setWriteInterestAfterChannelSelectedCallback.get();
1352            reactorThreadAlreadyRacing = connection.reactorThreadAlreadyRacing.get();
1353            afterOutgoingElementsQueueModifiedSetInterestOps = connection.afterOutgoingElementsQueueModifiedSetInterestOps
1354                    .get();
1355            rejectedChannelSelectedCallbacks = connection.rejectedChannelSelectedCallbacks.get();
1356
1357            totalCallbackRequests = handledChannelSelectedCallbacks + rejectedChannelSelectedCallbacks;
1358
1359            callbackPreemtBecauseBytesRead = connection.callbackPreemtBecauseBytesRead;
1360            callbackPreemtBecauseBytesWritten = connection.callbackPreemtBecauseBytesWritten;
1361
1362            sslEngineDelegatedTasks = connection.sslEngineDelegatedTasks;
1363            maxPendingSslEngineDelegatedTasks = connection.maxPendingSslEngineDelegatedTasks;
1364        }
1365
1366        private transient String toStringCache;
1367
1368        @Override
1369        public String toString() {
1370            if (toStringCache != null) {
1371                return toStringCache;
1372            }
1373
1374            toStringCache =
1375              "Total bytes\n"
1376            + "recv: " + totalBytesRead + '\n'
1377            + "send: " + totalBytesWritten + '\n'
1378            + "recv-aft-filter: " + totalBytesReadAfterFilter + '\n'
1379            + "send-bef-filter: " + totalBytesWrittenBeforeFilter + '\n'
1380            + "read-ratio: " + readRatio + '\n'
1381            + "write-ratio: " + writeRatio + '\n'
1382            + "Events\n"
1383            + "total-callback-requests: " + totalCallbackRequests + '\n'
1384            + "handled-channel-selected-callbacks: " + handledChannelSelectedCallbacks + '\n'
1385            + "rejected-channel-selected-callbacks: " + rejectedChannelSelectedCallbacks + '\n'
1386            + "set-write-interest-after-callback: " + setWriteInterestAfterChannelSelectedCallback + '\n'
1387            + "reactor-thread-already-racing: " + reactorThreadAlreadyRacing + '\n'
1388            + "after-queue-modified-set-interest-ops: " + afterOutgoingElementsQueueModifiedSetInterestOps + '\n'
1389            + "callback-preemt-because-bytes-read: " + callbackPreemtBecauseBytesRead + '\n'
1390            + "callback-preemt-because-bytes-written: " + callbackPreemtBecauseBytesWritten + '\n'
1391            + "ssl-engine-delegated-tasks: " + sslEngineDelegatedTasks + '\n'
1392            + "max-pending-ssl-engine-delegated-tasks: " + maxPendingSslEngineDelegatedTasks + '\n'
1393            ;
1394
1395            return toStringCache;
1396        }
1397    }
1398}