001/**
002 *
003 * Copyright 2014-2017 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.sasl.core;
018
import java.io.UnsupportedEncodingException;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
026
027import javax.security.auth.callback.CallbackHandler;
028
029import org.jivesoftware.smack.SmackException;
030import org.jivesoftware.smack.sasl.SASLMechanism;
031import org.jivesoftware.smack.util.ByteUtils;
032import org.jivesoftware.smack.util.SHA1;
033import org.jivesoftware.smack.util.StringUtils;
034import org.jivesoftware.smack.util.stringencoder.Base64;
035
036import org.jxmpp.util.cache.Cache;
037import org.jxmpp.util.cache.LruCache;
038
039public abstract class ScramMechanism extends SASLMechanism {
040
041    private static final int RANDOM_ASCII_BYTE_COUNT = 32;
042    private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
043    private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
044    private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };
045
046    private static final ThreadLocal<SecureRandom> SECURE_RANDOM = new ThreadLocal<SecureRandom>() {
047        @Override
048        protected SecureRandom initialValue() {
049            return new SecureRandom();
050        }
051    };
052
053    private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);
054
055    private final ScramHmac scramHmac;
056
057    protected ScramMechanism(ScramHmac scramHmac) {
058        this.scramHmac = scramHmac;
059    }
060
061    private enum State {
062        INITIAL,
063        AUTH_TEXT_SENT,
064        RESPONSE_SENT,
065        VALID_SERVER_RESPONSE,
066    }
067
068    /**
069     * The state of the this instance of SASL SCRAM-SHA1 authentication.
070     */
071    private State state = State.INITIAL;
072
073    /**
074     * The client's random ASCII which is used as nonce
075     */
076    private String clientRandomAscii;
077
078    private String clientFirstMessageBare;
079    private byte[] serverSignature;
080
081    @Override
082    protected void authenticateInternal(CallbackHandler cbh) throws SmackException {
083        throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
084    }
085
086    @Override
087    protected byte[] getAuthenticationText() throws SmackException {
088        clientRandomAscii = getRandomAscii();
089        String saslPrepedAuthcId = saslPrep(authenticationId);
090        clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
091        String clientFirstMessage = getGS2Header() + clientFirstMessageBare;
092        state = State.AUTH_TEXT_SENT;
093        return toBytes(clientFirstMessage);
094    }
095
096    @Override
097    public String getName() {
098        String name = "SCRAM-" + scramHmac.getHmacName();
099        return name;
100    }
101
102    @Override
103    public void checkIfSuccessfulOrThrow() throws SmackException {
104        if (state != State.VALID_SERVER_RESPONSE) {
105            throw new SmackException("SCRAM-SHA1 is missing valid server response");
106        }
107    }
108
109    @Override
110    public boolean authzidSupported() {
111        return true;
112    }
113
114    @Override
115    protected byte[] evaluateChallenge(byte[] challenge) throws SmackException {
116        String challengeString;
117        try {
118            // TODO: Where is it specified that this is an UTF-8 encoded string?
119            challengeString = new String(challenge, StringUtils.UTF8);
120        }
121        catch (UnsupportedEncodingException e) {
122            throw new AssertionError(e);
123        }
124
125        switch (state) {
126        case AUTH_TEXT_SENT:
127            final String serverFirstMessage = challengeString;
128            Map<Character, String> attributes = parseAttributes(challengeString);
129
130            // Handle server random ASCII (nonce)
131            String rvalue = attributes.get('r');
132            if (rvalue == null) {
133                throw new SmackException("Server random ASCII is null");
134            }
135            if (rvalue.length() <= clientRandomAscii.length()) {
136                throw new SmackException("Server random ASCII is shorter then client random ASCII");
137            }
138            String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
139            if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
140                throw new SmackException("Received client random ASCII does not match client random ASCII");
141            }
142
143            // Handle iterations
144            int iterations;
145            String iterationsString = attributes.get('i');
146            if (iterationsString == null) {
147                throw new SmackException("Iterations attribute not set");
148            }
149            try {
150                iterations = Integer.parseInt(iterationsString);
151            }
152            catch (NumberFormatException e) {
153                throw new SmackException("Exception parsing iterations", e);
154            }
155
156            // Handle salt
157            String salt = attributes.get('s');
158            if (salt == null) {
159                throw new SmackException("SALT not send");
160            }
161
162            // Parsing and error checking is done, we can now begin to calculate the values
163
164            // First the client-final-message-without-proof
165            String channelBinding = "c=" + Base64.encodeToString(getCBindInput());
166            String clientFinalMessageWithoutProof = channelBinding + ",r=" + rvalue;
167
168            // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
169            // client-final-message-without-proof
170            byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
171                            + clientFinalMessageWithoutProof);
172
173            // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
174            // as it is likely that the server is going to advertise the same salt value upon reauthentication."
175            // Note that we also mangle the mechanism's name into the cache key, since the cache is used by multiple
176            // mechanisms.
177            final String cacheKey = password + ',' + salt + ',' + getName();
178            byte[] serverKey, clientKey;
179            Keys keys = CACHE.lookup(cacheKey);
180            if (keys == null) {
181                // SaltedPassword := Hi(Normalize(password), salt, i)
182                byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);
183
184                // ServerKey := HMAC(SaltedPassword, "Server Key")
185                serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
186
187                // ClientKey := HMAC(SaltedPassword, "Client Key")
188                clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
189
190                keys = new Keys(clientKey, serverKey);
191                CACHE.put(cacheKey, keys);
192            }
193            else {
194                serverKey = keys.serverKey;
195                clientKey = keys.clientKey;
196            }
197
198            // ServerSignature := HMAC(ServerKey, AuthMessage)
199            serverSignature = hmac(serverKey, authMessage);
200
201            // StoredKey := H(ClientKey)
202            byte[] storedKey = SHA1.bytes(clientKey);
203
204            // ClientSignature := HMAC(StoredKey, AuthMessage)
205            byte[] clientSignature = hmac(storedKey, authMessage);
206
207            // ClientProof := ClientKey XOR ClientSignature
208            byte[] clientProof = new byte[clientKey.length];
209            for (int i = 0; i < clientProof.length; i++) {
210                clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
211            }
212
213            String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
214            state = State.RESPONSE_SENT;
215            return toBytes(clientFinalMessage);
216        case RESPONSE_SENT:
217            String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
218            if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
219                throw new SmackException("Server final message does not match calculated one");
220            }
221            state = State.VALID_SERVER_RESPONSE;
222            break;
223        default:
224            throw new SmackException("Invalid state");
225        }
226        return null;
227    }
228
229    private String getGS2Header() {
230        String authzidPortion = "";
231        if (authorizationId != null) {
232            authzidPortion = "a=" + authorizationId;
233        }
234
235        String cbName = getChannelBindingName();
236        assert (StringUtils.isNotEmpty(cbName));
237
238        return cbName + ',' + authzidPortion + ",";
239    }
240
241    private byte[] getCBindInput() throws SmackException {
242        byte[] cbindData = getChannelBindingData();
243        byte[] gs2Header = toBytes(getGS2Header());
244
245        if (cbindData == null) {
246            return gs2Header;
247        }
248
249        return ByteUtils.concat(gs2Header, cbindData);
250    }
251
252    protected String getChannelBindingName() {
253        // Check if we are using TLS and if a "-PLUS" variant of this mechanism is enabled. Assuming that the "-PLUS"
254        // variants always have precedence before the non-"-PLUS" variants this means that the server did not announce
255        // the "-PLUS" variant, as otherwise we would have tried it.
256        if (sslSession != null && connectionConfiguration.isEnabledSaslMechanism(getName() + "-PLUS")) {
257            // Announce that we support Channel Binding, i.e., the '-PLUS' flavor of this SASL mechanism, but that we
258            // believe the server does not.
259            return "y";
260        }
261        return "n";
262    }
263
264    /**
265     *
266     * @return the Channel Binding data.
267     * @throws SmackException
268     */
269    protected byte[] getChannelBindingData() throws SmackException {
270        return null;
271    }
272
273    private static Map<Character, String> parseAttributes(String string) throws SmackException {
274        if (string.length() == 0) {
275            return Collections.emptyMap();
276        }
277
278        String[] keyValuePairs = string.split(",");
279        Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
280        for (String keyValuePair : keyValuePairs) {
281            if (keyValuePair.length() < 3) {
282                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
283            }
284            char key = keyValuePair.charAt(0);
285            if (keyValuePair.charAt(1) != '=') {
286                throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
287            }
288            String value = keyValuePair.substring(2);
289            res.put(key, value);
290        }
291
292        return res;
293    }
294
295    /**
296     * Generate random ASCII.
297     * <p>
298     * This method is non-static and package-private for unit testing purposes.
299     * </p>
300     * @return A String of 32 random printable ASCII characters.
301     */
302    String getRandomAscii() {
303        int count = 0;
304        char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
305        final Random random = SECURE_RANDOM.get();
306        while (count < RANDOM_ASCII_BYTE_COUNT) {
307            int r = random.nextInt(128);
308            char c = (char) r;
309            // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
310            if (!isPrintableNonCommaAsciiChar(c)) {
311                continue;
312            }
313            randomAscii[count++] = c;
314        }
315        return new String(randomAscii);
316    }
317
318    private static boolean isPrintableNonCommaAsciiChar(char c) {
319        if (c == ',') {
320            return false;
321        }
322        // RFC 5802 § 7. 'printable': Contains all chars within 0x21 (33d) to 0x2b (43d) and 0x2d (45d) to 0x7e (126)
323        // aka. "Printable ASCII except ','". Since we already filter the ASCII ',' (0x2c, 44d) above, we only have to
324        // ensure that c is within [33, 126].
325        return c > 32 && c < 127;
326    }
327
328    /**
329     * Escapes usernames or passwords for SASL SCRAM-SHA1.
330     * <p>
331     * According to RFC 5802 § 5.1 'n:'
332     * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
333     * </p>
334     *
335     * @param string
336     * @return the escaped string
337     */
338    private static String escape(String string) {
339        StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
340        for (int i = 0; i < string.length(); i++) {
341            char c = string.charAt(i);
342            switch (c) {
343            case ',':
344                sb.append("=2C");
345                break;
346            case '=':
347                sb.append("=3D");
348                break;
349            default:
350                sb.append(c);
351                break;
352            }
353        }
354        return sb.toString();
355    }
356
357    /**
358     * RFC 5802 § 2.2 HMAC(key, str)
359     *
360     * @param key
361     * @param str
362     * @return the HMAC-SHA1 value of the input.
363     * @throws SmackException
364     */
365    private byte[] hmac(byte[] key, byte[] str) throws SmackException {
366        try {
367            return scramHmac.hmac(key, str);
368        }
369        catch (InvalidKeyException e) {
370            throw new SmackException(getName() + " Exception", e);
371        }
372    }
373
374    /**
375     * RFC 5802 § 2.2 Hi(str, salt, i)
376     * <p>
377     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
378     * (PRF) and with dkLen == output length of HMAC() == output length of H().
379     * </p>
380     *
381     * @param normalizedPassword the normalized password.
382     * @param salt
383     * @param iterations
384     * @return the result of the Hi function.
385     * @throws SmackException
386     */
387    private byte[] hi(String normalizedPassword, byte[] salt, int iterations) throws SmackException {
388        byte[] key;
389        try {
390            // According to RFC 5802 § 2.2, the resulting string of the normalization is also in UTF-8.
391            key = normalizedPassword.getBytes(StringUtils.UTF8);
392        }
393        catch (UnsupportedEncodingException e) {
394            throw new AssertionError();
395        }
396        // U1 := HMAC(str, salt + INT(1))
397        byte[] u = hmac(key, ByteUtils.concat(salt, ONE));
398        byte[] res = u.clone();
399        for (int i = 1; i < iterations; i++) {
400            u = hmac(key, u);
401            for (int j = 0; j < u.length; j++) {
402                res[j] ^= u[j];
403            }
404        }
405        return res;
406    }
407
408    private static class Keys {
409        private final byte[] clientKey;
410        private final byte[] serverKey;
411
412        Keys(byte[] clientKey, byte[] serverKey) {
413            this.clientKey = clientKey;
414            this.serverKey = serverKey;
415        }
416    }
417}