ScramMechanism.java

  1. /**
  2.  *
  3.  * Copyright 2014-2021 Florian Schmaus
  4.  *
  5.  * Licensed under the Apache License, Version 2.0 (the "License");
  6.  * you may not use this file except in compliance with the License.
  7.  * You may obtain a copy of the License at
  8.  *
  9.  *     http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.jivesoftware.smack.sasl.core;

  18. import java.nio.charset.StandardCharsets;
  19. import java.security.InvalidKeyException;
  20. import java.security.SecureRandom;
  21. import java.util.Collections;
  22. import java.util.HashMap;
  23. import java.util.Map;
  24. import java.util.Random;

  25. import javax.security.auth.callback.CallbackHandler;

  26. import org.jivesoftware.smack.SmackException.SmackSaslException;
  27. import org.jivesoftware.smack.sasl.SASLMechanism;
  28. import org.jivesoftware.smack.util.ByteUtils;
  29. import org.jivesoftware.smack.util.SHA1;
  30. import org.jivesoftware.smack.util.StringUtils;
  31. import org.jivesoftware.smack.util.stringencoder.Base64;

  32. import org.jxmpp.util.cache.Cache;
  33. import org.jxmpp.util.cache.LruCache;

  34. public abstract class ScramMechanism extends SASLMechanism {

  35.     private static final int RANDOM_ASCII_BYTE_COUNT = 32;
  36.     private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
  37.     private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
  38.     private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };

  39.     private static final ThreadLocal<SecureRandom> SECURE_RANDOM = new ThreadLocal<SecureRandom>() {
  40.         @Override
  41.         protected SecureRandom initialValue() {
  42.             return new SecureRandom();
  43.         }
  44.     };

  45.     private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);

  46.     private final ScramHmac scramHmac;

  47.     protected ScramMechanism(ScramHmac scramHmac) {
  48.         this.scramHmac = scramHmac;
  49.     }

  50.     private enum State {
  51.         INITIAL,
  52.         AUTH_TEXT_SENT,
  53.         RESPONSE_SENT,
  54.         VALID_SERVER_RESPONSE,
  55.     }

  56.     /**
  57.      * The state of the this instance of SASL SCRAM-SHA1 authentication.
  58.      */
  59.     private State state = State.INITIAL;

  60.     /**
  61.      * The client's random ASCII which is used as nonce
  62.      */
  63.     private String clientRandomAscii;

  64.     private String clientFirstMessageBare;
  65.     private byte[] serverSignature;

  66.     @Override
  67.     protected void authenticateInternal(CallbackHandler cbh) {
  68.         throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
  69.     }

  70.     @Override
  71.     protected byte[] getAuthenticationText() {
  72.         clientRandomAscii = getRandomAscii();
  73.         String saslPrepedAuthcId = saslPrep(authenticationId);
  74.         clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
  75.         String clientFirstMessage = getGS2Header() + clientFirstMessageBare;
  76.         state = State.AUTH_TEXT_SENT;
  77.         return toBytes(clientFirstMessage);
  78.     }

  79.     @Override
  80.     public String getName() {
  81.         String name = "SCRAM-" + scramHmac.getHmacName();
  82.         return name;
  83.     }

  84.     @Override
  85.     public void checkIfSuccessfulOrThrow() throws SmackSaslException {
  86.         if (state != State.VALID_SERVER_RESPONSE) {
  87.             throw new SmackSaslException("SCRAM-SHA1 is missing valid server response");
  88.         }
  89.     }

  90.     @Override
  91.     public boolean authzidSupported() {
  92.         return true;
  93.     }

  94.     @Override
  95.     protected byte[] evaluateChallenge(byte[] challenge) throws SmackSaslException {
  96.         // TODO: Where is it specified that this is an UTF-8 encoded string?
  97.         String challengeString = new String(challenge, StandardCharsets.UTF_8);

  98.         switch (state) {
  99.         case AUTH_TEXT_SENT:
  100.             final String serverFirstMessage = challengeString;
  101.             Map<Character, String> attributes = parseAttributes(challengeString);

  102.             // Handle server random ASCII (nonce)
  103.             String rvalue = attributes.get('r');
  104.             if (rvalue == null) {
  105.                 throw new SmackSaslException("Server random ASCII is null");
  106.             }
  107.             if (rvalue.length() <= clientRandomAscii.length()) {
  108.                 throw new SmackSaslException("Server random ASCII is shorter then client random ASCII");
  109.             }
  110.             String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
  111.             if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
  112.                 throw new SmackSaslException("Received client random ASCII does not match client random ASCII");
  113.             }

  114.             // Handle iterations
  115.             int iterations;
  116.             String iterationsString = attributes.get('i');
  117.             if (iterationsString == null) {
  118.                 throw new SmackSaslException("Iterations attribute not set");
  119.             }
  120.             try {
  121.                 iterations = Integer.parseInt(iterationsString);
  122.             }
  123.             catch (NumberFormatException e) {
  124.                 throw new SmackSaslException("Exception parsing iterations", e);
  125.             }

  126.             // Handle salt
  127.             String salt = attributes.get('s');
  128.             if (salt == null) {
  129.                 throw new SmackSaslException("SALT not send");
  130.             }

  131.             // Parsing and error checking is done, we can now begin to calculate the values

  132.             // First the client-final-message-without-proof
  133.             String channelBinding = "c=" + Base64.encodeToString(getCBindInput());
  134.             String clientFinalMessageWithoutProof = channelBinding + ",r=" + rvalue;

  135.             // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
  136.             // client-final-message-without-proof
  137.             byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
  138.                             + clientFinalMessageWithoutProof);

  139.             // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
  140.             // as it is likely that the server is going to advertise the same salt value upon reauthentication."
  141.             // Note that we also mangle the mechanism's name into the cache key, since the cache is used by multiple
  142.             // mechanisms.
  143.             final String cacheKey = password + ',' + salt + ',' + getName();
  144.             byte[] serverKey, clientKey;
  145.             Keys keys = CACHE.lookup(cacheKey);
  146.             if (keys == null) {
  147.                 // SaltedPassword := Hi(Normalize(password), salt, i)
  148.                 byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);

  149.                 // ServerKey := HMAC(SaltedPassword, "Server Key")
  150.                 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);

  151.                 // ClientKey := HMAC(SaltedPassword, "Client Key")
  152.                 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);

  153.                 keys = new Keys(clientKey, serverKey);
  154.                 CACHE.put(cacheKey, keys);
  155.             }
  156.             else {
  157.                 serverKey = keys.serverKey;
  158.                 clientKey = keys.clientKey;
  159.             }

  160.             // ServerSignature := HMAC(ServerKey, AuthMessage)
  161.             serverSignature = hmac(serverKey, authMessage);

  162.             // StoredKey := H(ClientKey)
  163.             byte[] storedKey = SHA1.bytes(clientKey);

  164.             // ClientSignature := HMAC(StoredKey, AuthMessage)
  165.             byte[] clientSignature = hmac(storedKey, authMessage);

  166.             // ClientProof := ClientKey XOR ClientSignature
  167.             byte[] clientProof = new byte[clientKey.length];
  168.             for (int i = 0; i < clientProof.length; i++) {
  169.                 clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
  170.             }

  171.             String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
  172.             state = State.RESPONSE_SENT;
  173.             return toBytes(clientFinalMessage);
  174.         case RESPONSE_SENT:
  175.             String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
  176.             if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
  177.                 throw new SmackSaslException("Server final message does not match calculated one");
  178.             }
  179.             state = State.VALID_SERVER_RESPONSE;
  180.             break;
  181.         default:
  182.             throw new SmackSaslException("Invalid state");
  183.         }
  184.         return null;
  185.     }

  186.     private String getGS2Header() {
  187.         String authzidPortion = "";
  188.         if (authorizationId != null) {
  189.             authzidPortion = "a=" + authorizationId;
  190.         }

  191.         String cbName = getGs2CbindFlag();
  192.         assert StringUtils.isNotEmpty(cbName);

  193.         return cbName + ',' + authzidPortion + ",";
  194.     }

  195.     private byte[] getCBindInput() throws SmackSaslException {
  196.         byte[] cbindData = getChannelBindingData();
  197.         byte[] gs2Header = toBytes(getGS2Header());

  198.         if (cbindData == null) {
  199.             return gs2Header;
  200.         }

  201.         return ByteUtils.concat(gs2Header, cbindData);
  202.     }

  203.     /**
  204.      * Get the SCRAM GSS-API Channel Binding Flag value.
  205.      *
  206.      * @return the gs2-cbind-flag value.
  207.      * @see <a href="https://tools.ietf.org/html/rfc5802#section-6">RFC 5802 § 6.</a>
  208.      */
  209.     protected String getGs2CbindFlag() {
  210.         // Check if we are using TLS and if a "-PLUS" variant of this mechanism is enabled. Assuming that the "-PLUS"
  211.         // variants always have precedence before the non-"-PLUS" variants this means that the server did not announce
  212.         // the "-PLUS" variant, as otherwise we would have tried it.
  213.         if (sslSession != null && connectionConfiguration.isEnabledSaslMechanism(getName() + "-PLUS")) {
  214.             // Announce that we support Channel Binding, i.e., the '-PLUS' flavor of this SASL mechanism, but that we
  215.             // believe the server does not.
  216.             return "y";
  217.         }
  218.         return "n";
  219.     }

  220.     /**
  221.      * Get the channel binding data.
  222.      *
  223.      * @return the Channel Binding data.
  224.      * @throws SmackSaslException if a SASL specific error occurred.
  225.      */
  226.     protected byte[] getChannelBindingData() throws SmackSaslException {
  227.         return null;
  228.     }

  229.     private static Map<Character, String> parseAttributes(String string) throws SmackSaslException {
  230.         if (string.length() == 0) {
  231.             return Collections.emptyMap();
  232.         }

  233.         String[] keyValuePairs = string.split(",");
  234.         Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
  235.         for (String keyValuePair : keyValuePairs) {
  236.             if (keyValuePair.length() < 3) {
  237.                 throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
  238.             }
  239.             char key = keyValuePair.charAt(0);
  240.             if (keyValuePair.charAt(1) != '=') {
  241.                 throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
  242.             }
  243.             String value = keyValuePair.substring(2);
  244.             res.put(key, value);
  245.         }

  246.         return res;
  247.     }

  248.     /**
  249.      * Generate random ASCII.
  250.      * <p>
  251.      * This method is non-static and package-private for unit testing purposes.
  252.      * </p>
  253.      * @return A String of 32 random printable ASCII characters.
  254.      */
  255.     String getRandomAscii() {
  256.         int count = 0;
  257.         char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
  258.         final Random random = SECURE_RANDOM.get();
  259.         while (count < RANDOM_ASCII_BYTE_COUNT) {
  260.             int r = random.nextInt(128);
  261.             char c = (char) r;
  262.             // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
  263.             if (!isPrintableNonCommaAsciiChar(c)) {
  264.                 continue;
  265.             }
  266.             randomAscii[count++] = c;
  267.         }
  268.         return new String(randomAscii);
  269.     }

  270.     private static boolean isPrintableNonCommaAsciiChar(char c) {
  271.         if (c == ',') {
  272.             return false;
  273.         }
  274.         // RFC 5802 § 7. 'printable': Contains all chars within 0x21 (33d) to 0x2b (43d) and 0x2d (45d) to 0x7e (126)
  275.         // aka. "Printable ASCII except ','". Since we already filter the ASCII ',' (0x2c, 44d) above, we only have to
  276.         // ensure that c is within [33, 126].
  277.         return c > 32 && c < 127;
  278.     }

  279.     /**
  280.      * Escapes usernames or passwords for SASL SCRAM-SHA1.
  281.      * <p>
  282.      * According to RFC 5802 § 5.1 'n:'
  283.      * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
  284.      * </p>
  285.      *
  286.      * @param string TODO javadoc me please
  287.      * @return the escaped string
  288.      */
  289.     private static String escape(String string) {
  290.         StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
  291.         for (int i = 0; i < string.length(); i++) {
  292.             char c = string.charAt(i);
  293.             switch (c) {
  294.             case ',':
  295.                 sb.append("=2C");
  296.                 break;
  297.             case '=':
  298.                 sb.append("=3D");
  299.                 break;
  300.             default:
  301.                 sb.append(c);
  302.                 break;
  303.             }
  304.         }
  305.         return sb.toString();
  306.     }

  307.     /**
  308.      * RFC 5802 § 2.2 HMAC(key, str)
  309.      *
  310.      * @param key TODO javadoc me please
  311.      * @param str TODO javadoc me please
  312.      * @return the HMAC-SHA1 value of the input.
  313.      * @throws SmackSaslException if Smack detected an exceptional situation.
  314.      */
  315.     private byte[] hmac(byte[] key, byte[] str) throws SmackSaslException {
  316.         try {
  317.             return scramHmac.hmac(key, str);
  318.         }
  319.         catch (InvalidKeyException e) {
  320.             throw new SmackSaslException(getName() + " Exception", e);
  321.         }
  322.     }

  323.     /**
  324.      * RFC 5802 § 2.2 Hi(str, salt, i)
  325.      * <p>
  326.      * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
  327.      * (PRF) and with dkLen == output length of HMAC() == output length of H().
  328.      * </p>
  329.      *
  330.      * @param normalizedPassword the normalized password.
  331.      * @param salt TODO javadoc me please
  332.      * @param iterations TODO javadoc me please
  333.      * @return the result of the Hi function.
  334.      * @throws SmackSaslException if a SASL related error occurs.
  335.      */
  336.     private byte[] hi(String normalizedPassword, byte[] salt, int iterations) throws SmackSaslException {
  337.         // According to RFC 5802 § 2.2, the resulting string of the normalization is also in UTF-8.
  338.         byte[] key = normalizedPassword.getBytes(StandardCharsets.UTF_8);

  339.         // U1 := HMAC(str, salt + INT(1))
  340.         byte[] u = hmac(key, ByteUtils.concat(salt, ONE));
  341.         byte[] res = u.clone();
  342.         for (int i = 1; i < iterations; i++) {
  343.             u = hmac(key, u);
  344.             for (int j = 0; j < u.length; j++) {
  345.                 res[j] ^= u[j];
  346.             }
  347.         }
  348.         return res;
  349.     }

  350.     private static class Keys {
  351.         private final byte[] clientKey;
  352.         private final byte[] serverKey;

  353.         Keys(byte[] clientKey, byte[] serverKey) {
  354.             this.clientKey = clientKey;
  355.             this.serverKey = serverKey;
  356.         }
  357.     }
  358. }