StateDescriptorGraph.java

  1. /**
  2.  *
  3.  * Copyright 2018-2021 Florian Schmaus
  4.  *
  5.  * Licensed under the Apache License, Version 2.0 (the "License");
  6.  * you may not use this file except in compliance with the License.
  7.  * You may obtain a copy of the License at
  8.  *
  9.  *     http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.jivesoftware.smack.fsm;

  18. import java.io.PrintWriter;
  19. import java.lang.reflect.Constructor;
  20. import java.lang.reflect.InvocationTargetException;
  21. import java.util.ArrayList;
  22. import java.util.Collection;
  23. import java.util.Collections;
  24. import java.util.HashMap;
  25. import java.util.HashSet;
  26. import java.util.Iterator;
  27. import java.util.List;
  28. import java.util.Map;
  29. import java.util.Set;

  30. import org.jivesoftware.smack.c2s.ModularXmppClientToServerConnection.DisconnectedStateDescriptor;
  31. import org.jivesoftware.smack.c2s.internal.ModularXmppClientToServerConnectionInternal;
  32. import org.jivesoftware.smack.util.Consumer;
  33. import org.jivesoftware.smack.util.MultiMap;

  34. /**
  35.  * Smack's utility API for Finite State Machines (FSM).
  36.  *
  37.  * <p>
  38.  * Thanks to Andreas Fried for the fun and successful bug hunting session.
  39.  * </p>
  40.  *
  41.  * @author Florian Schmaus
  42.  *
  43.  */
  44. public class StateDescriptorGraph {

  45.     private static GraphVertex<StateDescriptor> addNewStateDescriptorGraphVertex(
  46.                     Class<? extends StateDescriptor> stateDescriptorClass,
  47.                     Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes)
  48.                     throws InstantiationException, IllegalAccessException, IllegalArgumentException,
  49.                     InvocationTargetException, NoSuchMethodException, SecurityException {
  50.         Constructor<? extends StateDescriptor> stateDescriptorConstructor = stateDescriptorClass.getDeclaredConstructor();
  51.         stateDescriptorConstructor.setAccessible(true);
  52.         StateDescriptor stateDescriptor = stateDescriptorConstructor.newInstance();
  53.         GraphVertex<StateDescriptor> graphVertexStateDescriptor = new GraphVertex<>(stateDescriptor);

  54.         GraphVertex<StateDescriptor> previous = graphVertexes.put(stateDescriptorClass, graphVertexStateDescriptor);
  55.         assert previous == null;

  56.         return graphVertexStateDescriptor;
  57.     }

  58.     private static final class HandleStateDescriptorGraphVertexContext {
  59.         private final Set<Class<? extends StateDescriptor>> handledStateDescriptors = new HashSet<>();
  60.         Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes;
  61.         MultiMap<Class<? extends StateDescriptor>, Class<? extends StateDescriptor>> inferredForwardEdges;

  62.         private HandleStateDescriptorGraphVertexContext(
  63.                         Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes,
  64.                         MultiMap<Class<? extends StateDescriptor>, Class<? extends StateDescriptor>> inferredForwardEdges) {
  65.             this.graphVertexes = graphVertexes;
  66.             this.inferredForwardEdges = inferredForwardEdges;
  67.         }

  68.         private boolean recurseInto(Class<? extends StateDescriptor> stateDescriptorClass) {
  69.             boolean wasAdded = handledStateDescriptors.add(stateDescriptorClass);
  70.             boolean alreadyHandled = !wasAdded;
  71.             return alreadyHandled;
  72.         }

  73.         private GraphVertex<StateDescriptor> getOrConstruct(Class<? extends StateDescriptor> stateDescriptorClass)
  74.                         throws InstantiationException, IllegalAccessException, IllegalArgumentException,
  75.                         InvocationTargetException, NoSuchMethodException, SecurityException {
  76.             GraphVertex<StateDescriptor> graphVertexStateDescriptor = graphVertexes.get(stateDescriptorClass);

  77.             if (graphVertexStateDescriptor == null) {
  78.                 graphVertexStateDescriptor = addNewStateDescriptorGraphVertex(stateDescriptorClass, graphVertexes);

  79.                 for (Class<? extends StateDescriptor> inferredSuccessor : inferredForwardEdges.getAll(
  80.                                 stateDescriptorClass)) {
  81.                     graphVertexStateDescriptor.getElement().addSuccessor(inferredSuccessor);
  82.                 }
  83.             }

  84.             return graphVertexStateDescriptor;
  85.         }
  86.     }

  87.     private static void handleStateDescriptorGraphVertex(GraphVertex<StateDescriptor> node,
  88.                     HandleStateDescriptorGraphVertexContext context,
  89.                     boolean failOnUnknownStates)
  90.                     throws InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException {
  91.         Class<? extends StateDescriptor> stateDescriptorClass = node.element.getClass();
  92.         boolean alreadyHandled = context.recurseInto(stateDescriptorClass);
  93.         if (alreadyHandled) {
  94.             return;
  95.         }

  96.         Set<Class<? extends StateDescriptor>> successorClasses = node.element.getSuccessors();
  97.         int numSuccessors = successorClasses.size();

  98.         Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> successorStateDescriptors = new HashMap<>(
  99.                         numSuccessors);
  100.         for (Class<? extends StateDescriptor> successorClass : successorClasses) {
  101.             GraphVertex<StateDescriptor> successorGraphNode = context.getOrConstruct(successorClass);
  102.             successorStateDescriptors.put(successorClass, successorGraphNode);
  103.         }

  104.         switch (numSuccessors) {
  105.         case 0:
  106.             throw new IllegalStateException("State " + stateDescriptorClass + " has no successor");
  107.         case 1:
  108.             GraphVertex<StateDescriptor> soleSuccessorNode = successorStateDescriptors.values().iterator().next();
  109.             node.addOutgoingEdge(soleSuccessorNode);
  110.             handleStateDescriptorGraphVertex(soleSuccessorNode, context, failOnUnknownStates);
  111.             return;
  112.         }

  113.         // We hit a state with multiple successors, perform a topological sort on the successors first.
  114.         // Process the information regarding subordinates and superiors states.

  115.         // The preference graph is the graph where the precedence information of all successors is stored, which we will
  116.         // topologically sort to find out which successor we should try first. It is a further new graph we use solely in
  117.         // this step for every node. The graph is represented as map. There is no special marker for the initial node
  118.         // as it is not required for the topological sort performed later.
  119.         Map<Class<? extends StateDescriptor>, GraphVertex<Class<? extends StateDescriptor>>> preferenceGraph = new HashMap<>(numSuccessors);

  120.         // Iterate over all successor states of the current state.
  121.         for (GraphVertex<StateDescriptor> successorStateDescriptorGraphNode : successorStateDescriptors.values()) {
  122.             StateDescriptor successorStateDescriptor = successorStateDescriptorGraphNode.element;
  123.             Class<? extends StateDescriptor> successorStateDescriptorClass = successorStateDescriptor.getClass();
  124.             for (Class<? extends StateDescriptor> subordinateClass : successorStateDescriptor.getSubordinates()) {
  125.                 if (failOnUnknownStates && !successorClasses.contains(subordinateClass)) {
  126.                     throw new IllegalStateException(successorStateDescriptor + " points to a subordinate '" + subordinateClass + "' which is not part of the successor set");
  127.                 }

  128.                 GraphVertex<Class<? extends StateDescriptor>> superiorClassNode = lookupAndCreateIfRequired(
  129.                                 preferenceGraph, successorStateDescriptorClass);
  130.                 GraphVertex<Class<? extends StateDescriptor>> subordinateClassNode = lookupAndCreateIfRequired(
  131.                                 preferenceGraph, subordinateClass);

  132.                 superiorClassNode.addOutgoingEdge(subordinateClassNode);
  133.             }
  134.             for (Class<? extends StateDescriptor> superiorClass : successorStateDescriptor.getSuperiors()) {
  135.                 if (failOnUnknownStates && !successorClasses.contains(superiorClass)) {
  136.                     throw new IllegalStateException(successorStateDescriptor + " points to a superior '" + superiorClass
  137.                                     + "' which is not part of the successor set");
  138.                 }

  139.                 GraphVertex<Class<? extends StateDescriptor>> subordinateClassNode = lookupAndCreateIfRequired(
  140.                                 preferenceGraph, successorStateDescriptorClass);
  141.                 GraphVertex<Class<? extends StateDescriptor>> superiorClassNode = lookupAndCreateIfRequired(
  142.                                 preferenceGraph, superiorClass);

  143.                 superiorClassNode.addOutgoingEdge(subordinateClassNode);
  144.             }
  145.         }

  146.         // Perform a topological sort which returns the state descriptor classes sorted by their priority. Highest
  147.         // priority state descriptors first.
  148.         List<GraphVertex<Class<? extends StateDescriptor>>> sortedSuccessors = topologicalSort(preferenceGraph.values());

  149.         // Handle the successor nodes which have not preference information available. Simply append them to the end of
  150.         // the sorted successor list.
  151.         outerloop: for (Class<? extends StateDescriptor> successorStateDescriptor : successorClasses) {
  152.             for (GraphVertex<Class<? extends StateDescriptor>> sortedSuccessor : sortedSuccessors) {
  153.                 if (sortedSuccessor.getElement() == successorStateDescriptor) {
  154.                     continue outerloop;
  155.                 }
  156.             }

  157.             sortedSuccessors.add(new GraphVertex<>(successorStateDescriptor));
  158.         }

  159.         for (GraphVertex<Class<? extends StateDescriptor>> successor : sortedSuccessors) {
  160.             GraphVertex<StateDescriptor> successorVertex = successorStateDescriptors.get(successor.element);
  161.             if (successorVertex == null) {
  162.                 // The successor does not exist, probably because its module was not enabled.
  163.                 continue;
  164.             }
  165.             node.addOutgoingEdge(successorVertex);

  166.             // Recurse further.
  167.             handleStateDescriptorGraphVertex(successorVertex, context, failOnUnknownStates);
  168.         }
  169.     }

  170.     public static GraphVertex<StateDescriptor> constructStateDescriptorGraph(
  171.                     Set<Class<? extends StateDescriptor>> backwardEdgeStateDescriptors,
  172.                     boolean failOnUnknownStates)
  173.                     throws InstantiationException, IllegalAccessException, IllegalArgumentException,
  174.                     InvocationTargetException, NoSuchMethodException, SecurityException {
  175.         Map<Class<? extends StateDescriptor>, GraphVertex<StateDescriptor>> graphVertexes = new HashMap<>();

  176.         final Class<? extends StateDescriptor> initialStatedescriptorClass = DisconnectedStateDescriptor.class;
  177.         GraphVertex<StateDescriptor> initialNode = addNewStateDescriptorGraphVertex(initialStatedescriptorClass, graphVertexes);

  178.         MultiMap<Class<? extends StateDescriptor>, Class<? extends StateDescriptor>> inferredForwardEdges = new MultiMap<>();
  179.         for (Class<? extends StateDescriptor> backwardsEdge : backwardEdgeStateDescriptors) {
  180.             GraphVertex<StateDescriptor> graphVertexStateDescriptor = addNewStateDescriptorGraphVertex(backwardsEdge, graphVertexes);

  181.             for (Class<? extends StateDescriptor> predecessor : graphVertexStateDescriptor.getElement().getPredeccessors()) {
  182.                 inferredForwardEdges.put(predecessor, backwardsEdge);
  183.             }
  184.         }
  185.         // Ensure that the initial node has their successors inferred.
  186.         for (Class<? extends StateDescriptor> inferredSuccessorOfInitialStateDescriptor : inferredForwardEdges.getAll(initialStatedescriptorClass)) {
  187.             initialNode.getElement().addSuccessor(inferredSuccessorOfInitialStateDescriptor);
  188.         }

  189.         HandleStateDescriptorGraphVertexContext context = new HandleStateDescriptorGraphVertexContext(graphVertexes, inferredForwardEdges);
  190.         handleStateDescriptorGraphVertex(initialNode, context, failOnUnknownStates);

  191.         return initialNode;
  192.     }

  193.     private static GraphVertex<State> convertToStateGraph(GraphVertex<StateDescriptor> stateDescriptorVertex,
  194.                     ModularXmppClientToServerConnectionInternal connectionInternal, Map<StateDescriptor, GraphVertex<State>> handledStateDescriptors) {
  195.         StateDescriptor stateDescriptor = stateDescriptorVertex.getElement();
  196.         GraphVertex<State> stateVertex = handledStateDescriptors.get(stateDescriptor);
  197.         if (stateVertex != null) {
  198.             return stateVertex;
  199.         }

  200.         State state = stateDescriptor.constructState(connectionInternal);
  201.         stateVertex = new GraphVertex<>(state);
  202.         handledStateDescriptors.put(stateDescriptor, stateVertex);
  203.         for (GraphVertex<StateDescriptor> successorStateDescriptorVertex : stateDescriptorVertex.getOutgoingEdges()) {
  204.             GraphVertex<State> successorStateVertex = convertToStateGraph(successorStateDescriptorVertex, connectionInternal, handledStateDescriptors);
  205.             // It is important that we keep the order of the edges. This should do it.
  206.             stateVertex.addOutgoingEdge(successorStateVertex);
  207.         }

  208.         return stateVertex;
  209.     }

  210.     public static GraphVertex<State> convertToStateGraph(GraphVertex<StateDescriptor> initialStateDescriptor,
  211.                     ModularXmppClientToServerConnectionInternal connectionInternal) {
  212.         Map<StateDescriptor, GraphVertex<State>> handledStateDescriptors = new HashMap<>();
  213.         GraphVertex<State> initialState = convertToStateGraph(initialStateDescriptor, connectionInternal,
  214.                         handledStateDescriptors);
  215.         return initialState;
  216.     }

  217.     // Graph API after here.
    // This API could possibly be factored out into an extra package/class, but then we will probably need a builder for
  219.     // the graph vertex in order to keep it immutable.
  220.     public static final class GraphVertex<E> {
  221.         private final E element;
  222.         private final List<GraphVertex<E>> outgoingEdges = new ArrayList<>();

  223.         private VertexColor color = VertexColor.white;

  224.         private GraphVertex(E element) {
  225.             this.element = element;
  226.         }

  227.         private void addOutgoingEdge(GraphVertex<E> vertex) {
  228.             assert vertex != null;
  229.             if (outgoingEdges.contains(vertex)) {
  230.                 throw new IllegalArgumentException("This " + this + " already has an outgoing edge to " + vertex);
  231.             }
  232.             outgoingEdges.add(vertex);
  233.         }

  234.         public E getElement() {
  235.             return element;
  236.         }

  237.         public List<GraphVertex<E>> getOutgoingEdges() {
  238.             return Collections.unmodifiableList(outgoingEdges);
  239.         }

  240.         private enum VertexColor {
  241.             white,
  242.             grey,
  243.             black,
  244.         }

  245.         @Override
  246.         public String toString() {
  247.             return toString(true);
  248.         }

  249.         public String toString(boolean includeOutgoingEdges) {
  250.             StringBuilder sb = new StringBuilder();
  251.             sb.append("GraphVertex " + element + " [color=" + color
  252.                             + ", identityHashCode=" + System.identityHashCode(this)
  253.                             + ", outgoingEdgeCount=" + outgoingEdges.size());

  254.             if (includeOutgoingEdges) {
  255.                 sb.append(", outgoingEdges={");

  256.                 for (Iterator<GraphVertex<E>> it = outgoingEdges.iterator(); it.hasNext();) {
  257.                     GraphVertex<E> outgoingEdgeVertex = it.next();
  258.                     sb.append(outgoingEdgeVertex.toString(false));
  259.                     if (it.hasNext()) {
  260.                         sb.append(", ");
  261.                     }
  262.                 }
  263.                 sb.append('}');
  264.             }

  265.             sb.append(']');
  266.             return sb.toString();
  267.         }
  268.     }

  269.     private static GraphVertex<Class<? extends StateDescriptor>> lookupAndCreateIfRequired(
  270.                     Map<Class<? extends StateDescriptor>, GraphVertex<Class<? extends StateDescriptor>>> map,
  271.                     Class<? extends StateDescriptor> clazz) {
  272.         GraphVertex<Class<? extends StateDescriptor>> vertex = map.get(clazz);
  273.         if (vertex == null) {
  274.             vertex = new GraphVertex<>(clazz);
  275.             map.put(clazz, vertex);
  276.         }
  277.         return vertex;
  278.     }

  279.     private static <E> List<GraphVertex<E>> topologicalSort(Collection<GraphVertex<E>> vertexes) {
  280.         List<GraphVertex<E>> res = new ArrayList<>();
  281.         dfs(vertexes, vertex -> res.add(0, vertex), null);
  282.         return res;
  283.     }

  284.     private static <E> void dfsVisit(GraphVertex<E> vertex, Consumer<GraphVertex<E>> dfsFinishedVertex,
  285.                     DfsEdgeFound<E> dfsEdgeFound) {
  286.         vertex.color = GraphVertex.VertexColor.grey;

  287.         final int totalEdgeCount = vertex.getOutgoingEdges().size();

  288.         int edgeCount = 0;

  289.         for (GraphVertex<E> successorVertex : vertex.getOutgoingEdges()) {
  290.             edgeCount++;
  291.             if (dfsEdgeFound != null) {
  292.                 dfsEdgeFound.onEdgeFound(vertex, successorVertex, edgeCount, totalEdgeCount);
  293.             }
  294.             if (successorVertex.color == GraphVertex.VertexColor.white) {
  295.                 dfsVisit(successorVertex, dfsFinishedVertex, dfsEdgeFound);
  296.             }
  297.         }

  298.         vertex.color = GraphVertex.VertexColor.black;
  299.         if (dfsFinishedVertex != null) {
  300.             dfsFinishedVertex.accept(vertex);
  301.         }
  302.     }

  303.     private static <E> void dfs(Collection<GraphVertex<E>> vertexes, Consumer<GraphVertex<E>> dfsFinishedVertex,
  304.                     DfsEdgeFound<E> dfsEdgeFound) {
  305.         for (GraphVertex<E> vertex : vertexes) {
  306.             if (vertex.color == GraphVertex.VertexColor.white) {
  307.                 dfsVisit(vertex, dfsFinishedVertex, dfsEdgeFound);
  308.             }
  309.         }
  310.     }

  311.     public static <E> void stateDescriptorGraphToDot(Collection<GraphVertex<StateDescriptor>> vertexes,
  312.                     PrintWriter dotOut, boolean breakStateName) {
  313.         dotOut.append("digraph {\n");
  314.         dfs(vertexes,
  315.                 finishedVertex -> {
  316.                    boolean isMultiVisitState = finishedVertex.element.isMultiVisitState();
  317.                    boolean isFinalState = finishedVertex.element.isFinalState();
  318.                    boolean isNotImplemented = finishedVertex.element.isNotImplemented();

  319.                    String style = null;
  320.                    if (isMultiVisitState) {
  321.                        style = "bold";
  322.                    } else if (isFinalState) {
  323.                        style = "filled";
  324.                    } else if (isNotImplemented) {
  325.                        style = "dashed";
  326.                    }

  327.                    if (style == null) {
  328.                        return;
  329.                    }

  330.                    dotOut.append('"')
  331.                        .append(finishedVertex.element.getFullStateName(breakStateName))
  332.                        .append("\" [ ")
  333.                        .append("style=")
  334.                        .append(style)
  335.                        .append(" ]\n");
  336.                },
  337.                (from, to, edgeId, totalEdgeCount) -> {
  338.                    dotOut.append("  \"")
  339.                        .append(from.element.getFullStateName(breakStateName))
  340.                        .append("\" -> \"")
  341.                        .append(to.element.getFullStateName(breakStateName))
  342.                        .append('"');
  343.                    if (totalEdgeCount > 1) {
  344.                        // Note that 'dot' requires *double* quotes to enclose the value.
  345.                        dotOut.append(" [xlabel=\"")
  346.                        .append(Integer.toString(edgeId))
  347.                        .append("\"]");
  348.                    }
  349.                    dotOut.append(";\n");
  350.                });
  351.         dotOut.append("}\n");
  352.     }

  353.     private interface DfsEdgeFound<E> {
  354.         void onEdgeFound(GraphVertex<E> from, GraphVertex<E> to, int edgeId, int totalEdgeCount);
  355.     }
  356. }