SCRAMSHA1Mechanism.java

/**
 *
 * Copyright 2014 Florian Schmaus
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jivesoftware.smack.sasl.core;

import java.nio.charset.Charset;
import java.security.InvalidKeyException;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import javax.security.auth.callback.CallbackHandler;

import org.jivesoftware.smack.SmackException;
import org.jivesoftware.smack.sasl.SASLMechanism;
import org.jivesoftware.smack.util.ByteUtils;
import org.jivesoftware.smack.util.MAC;
import org.jivesoftware.smack.util.SHA1;
import org.jivesoftware.smack.util.stringencoder.Base64;
import org.jxmpp.util.cache.Cache;
import org.jxmpp.util.cache.LruCache;

/**
 * SASL SCRAM-SHA-1 mechanism as described in RFC 5802.
 * <p>
 * This implementation always sends the GS2 header {@code "n,,"}, i.e. it does not use channel binding.
 * Derived ClientKey/ServerKey pairs are kept in a small static LRU cache, so that re-authentication with
 * the same password, salt and iteration count can skip the expensive {@code Hi()} computation.
 * </p>
 */
public class SCRAMSHA1Mechanism extends SASLMechanism {

    public static final String NAME = "SCRAM-SHA-1";

    private static final int RANDOM_ASCII_BYTE_COUNT = 32;
    private static final String DEFAULT_GS2_HEADER = "n,,";
    private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
    private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
    // INT(1) from RFC 5802 § 2.2: the 4-octet big-endian encoding of 1, appended to the salt for U1.
    private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };

    // Explicit charset so that encoding/decoding does not depend on the platform default charset.
    private static final Charset UTF8_CHARSET = Charset.forName("UTF-8");

    private static final SecureRandom RANDOM = new SecureRandom();

    /**
     * Cache of derived keys, keyed by password, salt and iteration count.
     */
    private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);

    private enum State {
        INITIAL,
        AUTH_TEXT_SENT,
        RESPONSE_SENT,
        VALID_SERVER_RESPONSE,
    }

    /**
     * The state of this instance of SASL SCRAM-SHA1 authentication.
     */
    private State state = State.INITIAL;

    /**
     * The client's random ASCII which is used as nonce.
     */
    private String clientRandomAscii;

    /**
     * The client-first-message-bare, which is needed again to build the AuthMessage.
     */
    private String clientFirstMessageBare;

    /**
     * The ServerSignature calculated by the client, which is compared with the server-final-message.
     */
    private byte[] serverSignature;

    @Override
    protected void authenticateInternal(CallbackHandler cbh) throws SmackException {
        throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
    }

    /**
     * Builds the client-first-message (GS2 header + client-first-message-bare) and records the nonce.
     *
     * @return the UTF-8 bytes of the client-first-message.
     * @throws SmackException if SASLprep of the authentication id fails.
     */
    @Override
    protected byte[] getAuthenticationText() throws SmackException {
        clientRandomAscii = getRandomAscii();
        String saslPrepedAuthcId = saslPrep(authenticationId);
        clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
        String clientFirstMessage = DEFAULT_GS2_HEADER + clientFirstMessageBare;
        state = State.AUTH_TEXT_SENT;
        return toBytes(clientFirstMessage);
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public int getPriority() {
        return 110;
    }

    @Override
    public SCRAMSHA1Mechanism newInstance() {
        return new SCRAMSHA1Mechanism();
    }


    /**
     * Verifies that the server-final-message was received and matched the calculated ServerSignature.
     *
     * @throws SmackException if no valid server-final-message has been processed.
     */
    @Override
    public void checkIfSuccessfulOrThrow() throws SmackException {
        if (state != State.VALID_SERVER_RESPONSE) {
            throw new SmackException("SCRAM-SHA1 is missing valid server response");
        }
    }

    /**
     * Processes a server challenge according to the current state.
     * <p>
     * After the client-first-message was sent, the challenge is the server-first-message and the
     * client-final-message is returned. After the client-final-message was sent, the challenge is the
     * server-final-message, which is verified against the calculated ServerSignature and {@code null}
     * is returned.
     * </p>
     *
     * @param challenge the raw challenge bytes, decoded as UTF-8.
     * @return the response to send, or {@code null} if there is none.
     * @throws SmackException if the challenge is malformed, does not match, or arrives in an invalid state.
     */
    @Override
    protected byte[] evaluateChallenge(byte[] challenge) throws SmackException {
        // Decode explicitly as UTF-8 (consistent with toBytes()) instead of using the platform charset.
        final String challengeString = new String(challenge, UTF8_CHARSET);
        switch (state) {
        case AUTH_TEXT_SENT:
            final String serverFirstMessage = challengeString;
            Map<Character, String> attributes = parseAttributes(challengeString);

            // RFC 5802 § 5.1 'm': the presence of this attribute MUST cause authentication failure.
            if (attributes.containsKey('m')) {
                throw new SmackException("Server sent unsupported mandatory extension attribute 'm'");
            }

            // Handle server random ASCII (nonce)
            String rvalue = attributes.get('r');
            if (rvalue == null) {
                throw new SmackException("Server random ASCII is null");
            }
            if (rvalue.length() <= clientRandomAscii.length()) {
                throw new SmackException("Server random ASCII is shorter then client random ASCII");
            }
            // The combined nonce must start with the nonce this client sent.
            String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
            if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
                throw new SmackException("Received client random ASCII does not match client random ASCII");
            }

            // Handle iterations
            int iterations;
            String iterationsString = attributes.get('i');
            if (iterationsString == null) {
                throw new SmackException("Iterations attribute not set");
            }
            try {
                iterations = Integer.parseInt(iterationsString);
            }
            catch (NumberFormatException e) {
                throw new SmackException("Exception parsing iterations", e);
            }
            // Hi() is only defined for at least one iteration; reject zero or negative counts.
            if (iterations < 1) {
                throw new SmackException("Invalid iterations attribute: " + iterations);
            }

            // Handle salt
            String salt = attributes.get('s');
            if (salt == null) {
                throw new SmackException("SALT not send");
            }

            // Parsing and error checking is done, we can now begin to calculate the values

            // First the client-final-message-without-proof
            String clientFinalMessageWithoutProof = "c=" + Base64.encode(DEFAULT_GS2_HEADER) + ",r=" + rvalue;

            // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
            // client-final-message-without-proof
            byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
                            + clientFinalMessageWithoutProof);

            // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
            // as it is likely that the server is going to advertise the same salt value upon reauthentication."
            // The keys depend on the iteration count too, so it must be part of the cache key. Salt (base64) and the
            // iteration count contain no ',' characters, so the key is unambiguous even if the password contains ','.
            final String cacheKey = password + ',' + salt + ',' + iterations;
            byte[] serverKey, clientKey;
            Keys keys = CACHE.get(cacheKey);
            if (keys == null) {
                // SaltedPassword := Hi(Normalize(password), salt, i)
                byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);

                // ServerKey := HMAC(SaltedPassword, "Server Key")
                serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);

                // ClientKey := HMAC(SaltedPassword, "Client Key")
                clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);

                keys = new Keys(clientKey, serverKey);
                CACHE.put(cacheKey, keys);
            }
            else {
                serverKey = keys.serverKey;
                clientKey = keys.clientKey;
            }

            // ServerSignature := HMAC(ServerKey, AuthMessage)
            serverSignature = hmac(serverKey, authMessage);

            // StoredKey := H(ClientKey)
            byte[] storedKey = SHA1.bytes(clientKey);

            // ClientSignature := HMAC(StoredKey, AuthMessage)
            byte[] clientSignature = hmac(storedKey, authMessage);

            // ClientProof := ClientKey XOR ClientSignature
            byte[] clientProof = new byte[clientKey.length];
            for (int i = 0; i < clientProof.length; i++) {
                clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
            }

            String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
            state = State.RESPONSE_SENT;
            return toBytes(clientFinalMessage);
        case RESPONSE_SENT:
            // The server-final-message must be exactly "v=" + base64(ServerSignature).
            String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
            if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
                throw new SmackException("Server final message does not match calculated one");
            }
            state = State.VALID_SERVER_RESPONSE;
            break;
        default:
            throw new SmackException("Invalid state");
        }
        return null;
    }

    /**
     * Parses a comma-separated list of {@code key=value} pairs with single-character keys.
     *
     * @param string the attribute list, may be empty.
     * @return a map from attribute key to its (possibly empty) value; later duplicates override earlier ones.
     * @throws SmackException if a pair is shorter than three characters or its second character is not '='.
     */
    private static Map<Character, String> parseAttributes(String string) throws SmackException {
        if (string.length() == 0) {
            return Collections.emptyMap();
        }

        String[] keyValuePairs = string.split(",");
        Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
        for (String keyValuePair : keyValuePairs) {
            if (keyValuePair.length() < 3) {
                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
            }
            char key = keyValuePair.charAt(0);
            if (keyValuePair.charAt(1) != '=') {
                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
            }
            String value = keyValuePair.substring(2);
            res.put(key, value);
        }

        return res;
    }

    /**
     * Generate random ASCII.
     * <p>
     * This method is non-static and package-private for unit testing purposes.
     * </p>
     * @return A String of 32 random printable ASCII characters, none of which is ','.
     */
    String getRandomAscii() {
        int count = 0;
        char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
        while (count < RANDOM_ASCII_BYTE_COUNT) {
            int r = RANDOM.nextInt(128);
            char c = (char) r;
            // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
            if (!isPrintableNonCommaAsciiChar(c)) {
                continue;
            }
            randomAscii[count++] = c;
        }
        return new String(randomAscii);
    }

    /**
     * @return true if {@code c} is in the printable ASCII range 32..126 and is not ','.
     */
    private static boolean isPrintableNonCommaAsciiChar(char c) {
        if (c == ',') {
            return false;
        }
        return c >= 32 && c < 127;
    }

    /**
     * Escapes usernames or passwords for SASL SCRAM-SHA1.
     * <p>
     * According to RFC 5802 § 5.1 'n:'
     * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
     * </p>
     *
     * @param string the string to escape.
     * @return the escaped string
     */
    private static String escape(String string) {
        StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
        for (int i = 0; i < string.length(); i++) {
            char c = string.charAt(i);
            switch (c) {
            case ',':
                sb.append("=2C");
                break;
            case '=':
                sb.append("=3D");
                break;
            default:
                sb.append(c);
                break;
            }
        }
        return sb.toString();
    }

    /**
     * RFC 5802 § 2.2 HMAC(key, str)
     *
     * @param key the HMAC key.
     * @param str the data to authenticate.
     * @return the HMAC-SHA1 value of the input.
     * @throws SmackException if the key is rejected by the MAC implementation.
     */
    private static byte[] hmac(byte[] key, byte[] str) throws SmackException {
        try {
            return MAC.hmacsha1(key, str);
        }
        catch (InvalidKeyException e) {
            throw new SmackException(NAME + " HMAC-SHA1 Exception", e);
        }
    }

    /**
     * RFC 5802 § 2.2 Hi(str, salt, i)
     * <p>
     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
     * (PRF) and with dkLen == output length of HMAC() == output length of H().
     * </p>
     *
     * @param str the (already normalized) password; encoded as UTF-8 to form the HMAC key.
     * @param salt the decoded salt.
     * @param iterations the iteration count; callers are expected to pass at least 1.
     * @return the result of the Hi function.
     * @throws SmackException if an HMAC computation fails.
     */
    private static byte[] hi(String str, byte[] salt, int iterations) throws SmackException {
        // Encode explicitly as UTF-8; the platform default charset would make non-ASCII passwords
        // produce platform-dependent keys.
        byte[] key = str.getBytes(UTF8_CHARSET);
        // U1 := HMAC(str, salt + INT(1))
        byte[] u = hmac(key, ByteUtils.concact(salt, ONE));
        byte[] res = u.clone();
        // Ui := HMAC(str, Ui-1); result := U1 XOR U2 XOR ... XOR Ui
        for (int i = 1; i < iterations; i++) {
            u = hmac(key, u);
            for (int j = 0; j < u.length; j++) {
                res[j] ^= u[j];
            }
        }
        return res;
    }

    /**
     * Immutable holder for a derived ClientKey/ServerKey pair.
     */
    private static final class Keys {
        private final byte[] clientKey;
        private final byte[] serverKey;

        public Keys(byte[] clientKey, byte[] serverKey) {
            this.clientKey = clientKey;
            this.serverKey = serverKey;
        }
    }
}