001/**
002 *
003 * Copyright 2014 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.sasl.core;
018
import java.nio.charset.Charset;
import java.security.InvalidKeyException;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import javax.security.auth.callback.CallbackHandler;

import org.jivesoftware.smack.SmackException;
import org.jivesoftware.smack.sasl.SASLMechanism;
import org.jivesoftware.smack.util.ByteUtils;
import org.jivesoftware.smack.util.MAC;
import org.jivesoftware.smack.util.SHA1;
import org.jivesoftware.smack.util.stringencoder.Base64;
import org.jxmpp.util.cache.Cache;
import org.jxmpp.util.cache.LruCache;
035
036public class SCRAMSHA1Mechanism extends SASLMechanism {
037
038    public static final String NAME = "SCRAM-SHA-1";
039
040    private static final int RANDOM_ASCII_BYTE_COUNT = 32;
041    private static final String DEFAULT_GS2_HEADER = "n,,";
042    private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
043    private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
044    private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };
045
046    private static final SecureRandom RANDOM = new SecureRandom();
047
048    private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);
049
050    private enum State {
051        INITIAL,
052        AUTH_TEXT_SENT,
053        RESPONSE_SENT,
054        VALID_SERVER_RESPONSE,
055    }
056
057    /**
058     * The state of the this instance of SASL SCRAM-SHA1 authentication.
059     */
060    private State state = State.INITIAL;
061
062    /**
063     * The client's random ASCII which is used as nonce
064     */
065    private String clientRandomAscii;
066
067    private String clientFirstMessageBare;
068    private byte[] serverSignature;
069
070    @Override
071    protected void authenticateInternal(CallbackHandler cbh) throws SmackException {
072        throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
073    }
074
075    @Override
076    protected byte[] getAuthenticationText() throws SmackException {
077        clientRandomAscii = getRandomAscii();
078        String saslPrepedAuthcId = saslPrep(authenticationId);
079        clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
080        String clientFirstMessage = DEFAULT_GS2_HEADER + clientFirstMessageBare;
081        state = State.AUTH_TEXT_SENT;
082        return toBytes(clientFirstMessage);
083    }
084
085    @Override
086    public String getName() {
087        return NAME;
088    }
089
090    @Override
091    public int getPriority() {
092        return 110;
093    }
094
095    @Override
096    public SCRAMSHA1Mechanism newInstance() {
097        return new SCRAMSHA1Mechanism();
098    }
099
100
101    @Override
102    public void checkIfSuccessfulOrThrow() throws SmackException {
103        if (state != State.VALID_SERVER_RESPONSE) {
104            throw new SmackException("SCRAM-SHA1 is missing valid server response");
105        }
106    }
107
108    @Override
109    protected byte[] evaluateChallenge(byte[] challenge) throws SmackException {
110        final String challengeString = new String(challenge);
111        switch (state) {
112        case AUTH_TEXT_SENT:
113            final String serverFirstMessage = challengeString;
114            Map<Character, String> attributes = parseAttributes(challengeString);
115
116            // Handle server random ASCII (nonce)
117            String rvalue = attributes.get('r');
118            if (rvalue == null) {
119                throw new SmackException("Server random ASCII is null");
120            }
121            if (rvalue.length() <= clientRandomAscii.length()) {
122                throw new SmackException("Server random ASCII is shorter then client random ASCII");
123            }
124            String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
125            if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
126                throw new SmackException("Received client random ASCII does not match client random ASCII");
127            }
128
129            // Handle iterations
130            int iterations;
131            String iterationsString = attributes.get('i');
132            if (iterationsString == null) {
133                throw new SmackException("Iterations attribute not set");
134            }
135            try {
136                iterations = Integer.parseInt(iterationsString);
137            }
138            catch (NumberFormatException e) {
139                throw new SmackException("Exception parsing iterations", e);
140            }
141
142            // Handle salt
143            String salt = attributes.get('s');
144            if (salt == null) {
145                throw new SmackException("SALT not send");
146            }
147
148            // Parsing and error checking is done, we can now begin to calculate the values
149
150            // First the client-final-message-without-proof
151            String clientFinalMessageWithoutProof = "c=" + Base64.encode(DEFAULT_GS2_HEADER) + ",r=" + rvalue;
152
153            // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
154            // client-final-message-without-proof
155            byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
156                            + clientFinalMessageWithoutProof);
157
158            // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
159            // as it is likely that the server is going to advertise the same salt value upon reauthentication."
160            final String cacheKey = password + ',' + salt;
161            byte[] serverKey, clientKey;
162            Keys keys = CACHE.get(cacheKey);
163            if (keys == null) {
164                // SaltedPassword := Hi(Normalize(password), salt, i)
165                byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);
166
167                // ServerKey := HMAC(SaltedPassword, "Server Key")
168                serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
169
170                // ClientKey := HMAC(SaltedPassword, "Client Key")
171                clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
172
173                keys = new Keys(clientKey, serverKey);
174                CACHE.put(cacheKey, keys);
175            }
176            else {
177                serverKey = keys.serverKey;
178                clientKey = keys.clientKey;
179            }
180
181            // ServerSignature := HMAC(ServerKey, AuthMessage)
182            serverSignature = hmac(serverKey, authMessage);
183
184            // StoredKey := H(ClientKey)
185            byte[] storedKey = SHA1.bytes(clientKey);
186
187            // ClientSignature := HMAC(StoredKey, AuthMessage)
188            byte[] clientSignature = hmac(storedKey, authMessage);
189
190            // ClientProof := ClientKey XOR ClientSignature
191            byte[] clientProof = new byte[clientKey.length];
192            for (int i = 0; i < clientProof.length; i++) {
193                clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
194            }
195
196            String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
197            state = State.RESPONSE_SENT;
198            return toBytes(clientFinalMessage);
199        case RESPONSE_SENT:
200            String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
201            if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
202                throw new SmackException("Server final message does not match calculated one");
203            }
204            state = State.VALID_SERVER_RESPONSE;
205            break;
206        default:
207            throw new SmackException("Invalid state");
208        }
209        return null;
210    }
211
212    private static Map<Character, String> parseAttributes(String string) throws SmackException {
213        if (string.length() == 0) {
214            return Collections.emptyMap();
215        }
216
217        String[] keyValuePairs = string.split(",");
218        Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
219        for (String keyValuePair : keyValuePairs) {
220            if (keyValuePair.length() < 3) {
221                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
222            }
223            char key = keyValuePair.charAt(0);
224            if (keyValuePair.charAt(1) != '=') {
225                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
226            }
227            String value = keyValuePair.substring(2);
228            res.put(key, value);
229        }
230
231        return res;
232    }
233
234    /**
235     * Generate random ASCII.
236     * <p>
237     * This method is non-static and package-private for unit testing purposes.
238     * </p>
239     * @return A String of 32 random printable ASCII characters.
240     */
241    String getRandomAscii() {
242        int count = 0;
243        char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
244        while (count < RANDOM_ASCII_BYTE_COUNT) {
245            int r = RANDOM.nextInt(128);
246            char c = (char) r;
247            // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
248            if (!isPrintableNonCommaAsciiChar(c)) {
249                continue;
250            }
251            randomAscii[count++] = c;
252        }
253        return new String(randomAscii);
254    }
255
256    private static boolean isPrintableNonCommaAsciiChar(char c) {
257        if (c == ',') {
258            return false;
259        }
260        return c >= 32 && c < 127;
261    }
262
263    /**
264     * Escapes usernames or passwords for SASL SCRAM-SHA1.
265     * <p>
266     * According to RFC 5802 § 5.1 'n:'
267     * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
268     * </p>
269     *
270     * @param string
271     * @return the escaped string
272     */
273    private static String escape(String string) {
274        StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
275        for (int i = 0; i < string.length(); i++) {
276            char c = string.charAt(i);
277            switch (c) {
278            case ',':
279                sb.append("=2C");
280                break;
281            case '=':
282                sb.append("=3D");
283                break;
284            default:
285                sb.append(c);
286                break;
287            }
288        }
289        return sb.toString();
290    }
291
292    /**
293     * RFC 5802 § 2.2 HMAC(key, str)
294     * 
295     * @param key
296     * @param str
297     * @return the HMAC-SHA1 value of the input.
298     * @throws SmackException 
299     */
300    private static byte[] hmac(byte[] key, byte[] str) throws SmackException {
301        try {
302            return MAC.hmacsha1(key, str);
303        }
304        catch (InvalidKeyException e) {
305            throw new SmackException(NAME + " HMAC-SHA1 Exception", e);
306        }
307    }
308
309    /**
310     * RFC 5802 § 2.2 Hi(str, salt, i)
311     * <p>
312     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
313     * (PRF) and with dkLen == output length of HMAC() == output length of H().
314     * </p>
315     * 
316     * @param str
317     * @param salt
318     * @param iterations
319     * @return the result of the Hi function.
320     * @throws SmackException 
321     */
322    private static byte[] hi(String str, byte[] salt, int iterations) throws SmackException {
323        byte[] key = str.getBytes();
324        // U1 := HMAC(str, salt + INT(1))
325        byte[] u = hmac(key, ByteUtils.concact(salt, ONE));
326        byte[] res = u.clone();
327        for (int i = 1; i < iterations; i++) {
328            u = hmac(key, u);
329            for (int j = 0; j < u.length; j++) {
330                res[j] ^= u[j];
331            }
332        }
333        return res;
334    }
335
336    private static class Keys {
337        private final byte[] clientKey;
338        private final byte[] serverKey;
339
340        public Keys(byte[] clientKey, byte[] serverKey) {
341            this.clientKey = clientKey;
342            this.serverKey = serverKey;
343        }
344    }
345}