001/**
002 *
003 * Copyright 2018-2021 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.fsm;
018
019import java.io.PrintWriter;
020import java.lang.reflect.Constructor;
021import java.lang.reflect.InvocationTargetException;
022import java.util.ArrayList;
023import java.util.Collection;
024import java.util.Collections;
025import java.util.HashMap;
026import java.util.HashSet;
027import java.util.Iterator;
028import java.util.List;
029import java.util.Map;
030import java.util.Set;
031
032import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnection.DisconnectedStateDescriptor;
033import org.jivesoftware.smack.c2s.internal.ModularXmppClientToServerConnectionInternal;
034import org.jivesoftware.smack.util.Consumer;
035import org.jivesoftware.smack.util.MultiMap;
036
037/**
038 * Smack's utility API for Finite State Machines (FSM).
039 *
040 * <p>
041 * Thanks to Andreas Fried for the fun and successful bug hunting session.
042 * </p>
043 *
044 * @author Florian Schmaus
045 *
046 */
047public class StateDescriptorGraph {
048
049    private static GraphVertex<StateDescriptor> addNewStateDescriptorGraphVertex(
050                    Class<? extends StateDescriptor> stateDescriptorClass,
051                    Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes)
052                    throws InstantiationException, IllegalAccessException, IllegalArgumentException,
053                    InvocationTargetException, NoSuchMethodException, SecurityException {
054        Constructor<? extends StateDescriptor> stateDescriptorConstructor = stateDescriptorClass.getDeclaredConstructor();
055        stateDescriptorConstructor.setAccessible(true);
056        StateDescriptor stateDescriptor = stateDescriptorConstructor.newInstance();
057        GraphVertex<StateDescriptor> graphVertexStateDescriptor = new GraphVertex<>(stateDescriptor);
058
059        GraphVertex<StateDescriptor> previous = graphVertexes.put(stateDescriptorClass, graphVertexStateDescriptor);
060        assert previous == null;
061
062        return graphVertexStateDescriptor;
063    }
064
065    private static final class HandleStateDescriptorGraphVertexContext {
066        private final Set<Class<? extends StateDescriptor>> handledStateDescriptors = new HashSet<>();
067        Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes;
068        MultiMap<Class<? extends StateDescriptor>, Class<? extends StateDescriptor>> inferredForwardEdges;
069
070        private HandleStateDescriptorGraphVertexContext(
071                        Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes,
072                        MultiMap<Class<? extends StateDescriptor>, Class<? extends StateDescriptor>> inferredForwardEdges) {
073            this.graphVertexes = graphVertexes;
074            this.inferredForwardEdges = inferredForwardEdges;
075        }
076
077        private boolean recurseInto(Class<? extends StateDescriptor> stateDescriptorClass) {
078            boolean wasAdded = handledStateDescriptors.add(stateDescriptorClass);
079            boolean alreadyHandled = !wasAdded;
080            return alreadyHandled;
081        }
082
083        private GraphVertex<StateDescriptor> getOrConstruct(Class<? extends StateDescriptor> stateDescriptorClass)
084                        throws InstantiationException, IllegalAccessException, IllegalArgumentException,
085                        InvocationTargetException, NoSuchMethodException, SecurityException {
086            GraphVertex<StateDescriptor> graphVertexStateDescriptor = graphVertexes.get(stateDescriptorClass);
087
088            if (graphVertexStateDescriptor == null) {
089                graphVertexStateDescriptor = addNewStateDescriptorGraphVertex(stateDescriptorClass, graphVertexes);
090
091                for (Class<? extends StateDescriptor> inferredSuccessor : inferredForwardEdges.getAll(
092                                stateDescriptorClass)) {
093                    graphVertexStateDescriptor.getElement().addSuccessor(inferredSuccessor);
094                }
095            }
096
097            return graphVertexStateDescriptor;
098        }
099    }
100
101    private static void handleStateDescriptorGraphVertex(GraphVertex<StateDescriptor> node,
102                    HandleStateDescriptorGraphVertexContext context,
103                    boolean failOnUnknownStates)
104                    throws InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException {
105        Class<? extends StateDescriptor> stateDescriptorClass = node.element.getClass();
106        boolean alreadyHandled = context.recurseInto(stateDescriptorClass);
107        if (alreadyHandled) {
108            return;
109        }
110
111        Set<Class<? extends StateDescriptor>> successorClasses = node.element.getSuccessors();
112        int numSuccessors = successorClasses.size();
113
114        Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> successorStateDescriptors = new HashMap<>(
115                        numSuccessors);
116        for (Class<? extends StateDescriptor> successorClass : successorClasses) {
117            GraphVertex<StateDescriptor> successorGraphNode = context.getOrConstruct(successorClass);
118            successorStateDescriptors.put(successorClass, successorGraphNode);
119        }
120
121        switch (numSuccessors) {
122        case 0:
123            throw new IllegalStateException("State " + stateDescriptorClass + " has no successor");
124        case 1:
125            GraphVertex<StateDescriptor> soleSuccessorNode = successorStateDescriptors.values().iterator().next();
126            node.addOutgoingEdge(soleSuccessorNode);
127            handleStateDescriptorGraphVertex(soleSuccessorNode, context, failOnUnknownStates);
128            return;
129        }
130
131        // We hit a state with multiple successors, perform a topological sort on the successors first.
132        // Process the information regarding subordinates and superiors states.
133
134        // The preference graph is the graph where the precedence information of all successors is stored, which we will
135        // topologically sort to find out which successor we should try first. It is a further new graph we use solely in
136        // this step for every node. The graph is represented as map. There is no special marker for the initial node
137        // as it is not required for the topological sort performed later.
138        Map<Class<? extends StateDescriptor>, GraphVertex<Class<? extends StateDescriptor>>> preferenceGraph = new HashMap<>(numSuccessors);
139
140        // Iterate over all successor states of the current state.
141        for (GraphVertex<StateDescriptor> successorStateDescriptorGraphNode : successorStateDescriptors.values()) {
142            StateDescriptor successorStateDescriptor = successorStateDescriptorGraphNode.element;
143            Class<? extends StateDescriptor> successorStateDescriptorClass = successorStateDescriptor.getClass();
144            for (Class<? extends StateDescriptor> subordinateClass : successorStateDescriptor.getSubordinates()) {
145                if (failOnUnknownStates && !successorClasses.contains(subordinateClass)) {
146                    throw new IllegalStateException(successorStateDescriptor + " points to a subordinate '" + subordinateClass + "' which is not part of the successor set");
147                }
148
149                GraphVertex<Class<? extends StateDescriptor>> superiorClassNode = lookupAndCreateIfRequired(
150                                preferenceGraph, successorStateDescriptorClass);
151                GraphVertex<Class<? extends StateDescriptor>> subordinateClassNode = lookupAndCreateIfRequired(
152                                preferenceGraph, subordinateClass);
153
154                superiorClassNode.addOutgoingEdge(subordinateClassNode);
155            }
156            for (Class<? extends StateDescriptor> superiorClass : successorStateDescriptor.getSuperiors()) {
157                if (failOnUnknownStates && !successorClasses.contains(superiorClass)) {
158                    throw new IllegalStateException(successorStateDescriptor + " points to a superior '" + superiorClass
159                                    + "' which is not part of the successor set");
160                }
161
162                GraphVertex<Class<? extends StateDescriptor>> subordinateClassNode = lookupAndCreateIfRequired(
163                                preferenceGraph, successorStateDescriptorClass);
164                GraphVertex<Class<? extends StateDescriptor>> superiorClassNode = lookupAndCreateIfRequired(
165                                preferenceGraph, superiorClass);
166
167                superiorClassNode.addOutgoingEdge(subordinateClassNode);
168            }
169        }
170
171        // Perform a topological sort which returns the state descriptor classes sorted by their priority. Highest
172        // priority state descriptors first.
173        List<GraphVertex<Class<? extends StateDescriptor>>> sortedSuccessors = topologicalSort(preferenceGraph.values());
174
175        // Handle the successor nodes which have not preference information available. Simply append them to the end of
176        // the sorted successor list.
177        outerloop: for (Class<? extends StateDescriptor> successorStateDescriptor : successorClasses) {
178            for (GraphVertex<Class<? extends StateDescriptor>> sortedSuccessor : sortedSuccessors) {
179                if (sortedSuccessor.getElement() == successorStateDescriptor) {
180                    continue outerloop;
181                }
182            }
183
184            sortedSuccessors.add(new GraphVertex<>(successorStateDescriptor));
185        }
186
187        for (GraphVertex<Class<? extends StateDescriptor>> successor : sortedSuccessors) {
188            GraphVertex<StateDescriptor> successorVertex = successorStateDescriptors.get(successor.element);
189            if (successorVertex == null) {
190                // The successor does not exist, probably because its module was not enabled.
191                continue;
192            }
193            node.addOutgoingEdge(successorVertex);
194
195            // Recurse further.
196            handleStateDescriptorGraphVertex(successorVertex, context, failOnUnknownStates);
197        }
198    }
199
200    public static GraphVertex<StateDescriptor> constructStateDescriptorGraph(
201                    Set<Class<? extends StateDescriptor>> backwardEdgeStateDescriptors,
202                    boolean failOnUnknownStates)
203                    throws InstantiationException, IllegalAccessException, IllegalArgumentException,
204                    InvocationTargetException, NoSuchMethodException, SecurityException {
205        Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes = new HashMap<>();
206
207        final Class<? extends StateDescriptor> initialStatedescriptorClass = DisconnectedStateDescriptor.class;
208        GraphVertex<StateDescriptor> initialNode = addNewStateDescriptorGraphVertex(initialStatedescriptorClass, graphVertexes);
209
210        MultiMap<Class<? extends StateDescriptor>, Class<? extends StateDescriptor>> inferredForwardEdges = new MultiMap<>();
211        for (Class<? extends StateDescriptor> backwardsEdge : backwardEdgeStateDescriptors) {
212            GraphVertex<StateDescriptor> graphVertexStateDescriptor = addNewStateDescriptorGraphVertex(backwardsEdge, graphVertexes);
213
214            for (Class<? extends StateDescriptor> predecessor : graphVertexStateDescriptor.getElement().getPredeccessors()) {
215                inferredForwardEdges.put(predecessor, backwardsEdge);
216            }
217        }
218        // Ensure that the initial node has their successors inferred.
219        for (Class<? extends StateDescriptor> inferredSuccessorOfInitialStateDescriptor : inferredForwardEdges.getAll(initialStatedescriptorClass)) {
220            initialNode.getElement().addSuccessor(inferredSuccessorOfInitialStateDescriptor);
221        }
222
223        HandleStateDescriptorGraphVertexContext context = new HandleStateDescriptorGraphVertexContext(graphVertexes, inferredForwardEdges);
224        handleStateDescriptorGraphVertex(initialNode, context, failOnUnknownStates);
225
226        return initialNode;
227    }
228
229    private static GraphVertex<State> convertToStateGraph(GraphVertex<StateDescriptor> stateDescriptorVertex,
230                    ModularXmppClientToServerConnectionInternal connectionInternal, Map<StateDescriptor, GraphVertex<State>> handledStateDescriptors) {
231        StateDescriptor stateDescriptor = stateDescriptorVertex.getElement();
232        GraphVertex<State> stateVertex = handledStateDescriptors.get(stateDescriptor);
233        if (stateVertex != null) {
234            return stateVertex;
235        }
236
237        State state = stateDescriptor.constructState(connectionInternal);
238        stateVertex = new GraphVertex<>(state);
239        handledStateDescriptors.put(stateDescriptor, stateVertex);
240        for (GraphVertex<StateDescriptor> successorStateDescriptorVertex : stateDescriptorVertex.getOutgoingEdges()) {
241            GraphVertex<State> successorStateVertex = convertToStateGraph(successorStateDescriptorVertex, connectionInternal, handledStateDescriptors);
242            // It is important that we keep the order of the edges. This should do it.
243            stateVertex.addOutgoingEdge(successorStateVertex);
244        }
245
246        return stateVertex;
247    }
248
249    public static GraphVertex<State> convertToStateGraph(GraphVertex<StateDescriptor> initialStateDescriptor,
250                    ModularXmppClientToServerConnectionInternal connectionInternal) {
251        Map<StateDescriptor, GraphVertex<State>> handledStateDescriptors = new HashMap<>();
252        GraphVertex<State> initialState = convertToStateGraph(initialStateDescriptor, connectionInternal,
253                        handledStateDescriptors);
254        return initialState;
255    }
256
257    // Graph API after here.
    // This API could possibly be factored out into an extra package/class, but then we will probably need a builder
    // for the graph vertex in order to keep it immutable.
260    public static final class GraphVertex<E> {
261        private final E element;
262        private final List<GraphVertex<E>> outgoingEdges = new ArrayList<>();
263
264        private VertexColor color = VertexColor.white;
265
266        private GraphVertex(E element) {
267            this.element = element;
268        }
269
270        private void addOutgoingEdge(GraphVertex<E> vertex) {
271            assert vertex != null;
272            if (outgoingEdges.contains(vertex)) {
273                throw new IllegalArgumentException("This " + this + " already has an outgoing edge to " + vertex);
274            }
275            outgoingEdges.add(vertex);
276        }
277
278        public E getElement() {
279            return element;
280        }
281
282        public List<GraphVertex<E>> getOutgoingEdges() {
283            return Collections.unmodifiableList(outgoingEdges);
284        }
285
286        private enum VertexColor {
287            white,
288            grey,
289            black,
290        }
291
292        @Override
293        public String toString() {
294            return toString(true);
295        }
296
297        public String toString(boolean includeOutgoingEdges) {
298            StringBuilder sb = new StringBuilder();
299            sb.append("GraphVertex " + element + " [color=" + color
300                            + ", identityHashCode=" + System.identityHashCode(this)
301                            + ", outgoingEdgeCount=" + outgoingEdges.size());
302
303            if (includeOutgoingEdges) {
304                sb.append(", outgoingEdges={");
305
306                for (Iterator<GraphVertex<E>> it = outgoingEdges.iterator(); it.hasNext();) {
307                    GraphVertex<E> outgoingEdgeVertex = it.next();
308                    sb.append(outgoingEdgeVertex.toString(false));
309                    if (it.hasNext()) {
310                        sb.append(", ");
311                    }
312                }
313                sb.append('}');
314            }
315
316            sb.append(']');
317            return sb.toString();
318        }
319    }
320
321    private static GraphVertex<Class<? extends StateDescriptor>> lookupAndCreateIfRequired(
322                    Map<Class<? extends StateDescriptor>, GraphVertex<Class<? extends StateDescriptor>>> map,
323                    Class<? extends StateDescriptor> clazz) {
324        GraphVertex<Class<? extends StateDescriptor>> vertex = map.get(clazz);
325        if (vertex == null) {
326            vertex = new GraphVertex<>(clazz);
327            map.put(clazz, vertex);
328        }
329        return vertex;
330    }
331
332    private static <E> List<GraphVertex<E>> topologicalSort(Collection<GraphVertex<E>> vertexes) {
333        List<GraphVertex<E>> res = new ArrayList<>();
334        dfs(vertexes, vertex -> res.add(0, vertex), null);
335        return res;
336    }
337
338    private static <E> void dfsVisit(GraphVertex<E> vertex, Consumer<GraphVertex<E>> dfsFinishedVertex,
339                    DfsEdgeFound<E> dfsEdgeFound) {
340        vertex.color = GraphVertex.VertexColor.grey;
341
342        final int totalEdgeCount = vertex.getOutgoingEdges().size();
343
344        int edgeCount = 0;
345
346        for (GraphVertex<E> successorVertex : vertex.getOutgoingEdges()) {
347            edgeCount++;
348            if (dfsEdgeFound != null) {
349                dfsEdgeFound.onEdgeFound(vertex, successorVertex, edgeCount, totalEdgeCount);
350            }
351            if (successorVertex.color == GraphVertex.VertexColor.white) {
352                dfsVisit(successorVertex, dfsFinishedVertex, dfsEdgeFound);
353            }
354        }
355
356        vertex.color = GraphVertex.VertexColor.black;
357        if (dfsFinishedVertex != null) {
358            dfsFinishedVertex.accept(vertex);
359        }
360    }
361
362    private static <E> void dfs(Collection<GraphVertex<E>> vertexes, Consumer<GraphVertex<E>> dfsFinishedVertex,
363                    DfsEdgeFound<E> dfsEdgeFound) {
364        for (GraphVertex<E> vertex : vertexes) {
365            if (vertex.color == GraphVertex.VertexColor.white) {
366                dfsVisit(vertex, dfsFinishedVertex, dfsEdgeFound);
367            }
368        }
369    }
370
371    public static <E> void stateDescriptorGraphToDot(Collection<GraphVertex<StateDescriptor>> vertexes,
372                    PrintWriter dotOut, boolean breakStateName) {
373        dotOut.append("digraph {\n");
374        dfs(vertexes,
375                finishedVertex -> {
376                   boolean isMultiVisitState = finishedVertex.element.isMultiVisitState();
377                   boolean isFinalState = finishedVertex.element.isFinalState();
378                   boolean isNotImplemented = finishedVertex.element.isNotImplemented();
379
380                   String style = null;
381                   if (isMultiVisitState) {
382                       style = "bold";
383                   } else if (isFinalState) {
384                       style = "filled";
385                   } else if (isNotImplemented) {
386                       style = "dashed";
387                   }
388
389                   if (style == null) {
390                       return;
391                   }
392
393                   dotOut.append('"')
394                       .append(finishedVertex.element.getFullStateName(breakStateName))
395                       .append("\" [ ")
396                       .append("style=")
397                       .append(style)
398                       .append(" ]\n");
399               },
400               (from, to, edgeId, totalEdgeCount) -> {
401                   dotOut.append("  \"")
402                       .append(from.element.getFullStateName(breakStateName))
403                       .append("\" -> \"")
404                       .append(to.element.getFullStateName(breakStateName))
405                       .append('"');
406                   if (totalEdgeCount > 1) {
407                       // Note that 'dot' requires *double* quotes to enclose the value.
408                       dotOut.append(" [xlabel=\"")
409                       .append(Integer.toString(edgeId))
410                       .append("\"]");
411                   }
412                   dotOut.append(";\n");
413               });
414        dotOut.append("}\n");
415    }
416
417    private interface DfsEdgeFound<E> {
418        void onEdgeFound(GraphVertex<E> from, GraphVertex<E> to, int edgeId, int totalEdgeCount);
419    }
420}