001/**
002 *
003 * Copyright 2014-2020 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.sasl.core;
018
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import javax.security.auth.callback.CallbackHandler;

import org.jivesoftware.smack.SmackException;
import org.jivesoftware.smack.SmackException.SmackSaslException;
import org.jivesoftware.smack.sasl.SASLMechanism;
import org.jivesoftware.smack.util.ByteUtils;
import org.jivesoftware.smack.util.SHA1;
import org.jivesoftware.smack.util.StringUtils;
import org.jivesoftware.smack.util.stringencoder.Base64;

import org.jxmpp.util.cache.Cache;
import org.jxmpp.util.cache.LruCache;
039
040public abstract class ScramMechanism extends SASLMechanism {
041
042    private static final int RANDOM_ASCII_BYTE_COUNT = 32;
043    private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
044    private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
045    private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };
046
047    private static final ThreadLocal<SecureRandom> SECURE_RANDOM = new ThreadLocal<SecureRandom>() {
048        @Override
049        protected SecureRandom initialValue() {
050            return new SecureRandom();
051        }
052    };
053
054    private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);
055
056    private final ScramHmac scramHmac;
057
058    protected ScramMechanism(ScramHmac scramHmac) {
059        this.scramHmac = scramHmac;
060    }
061
062    private enum State {
063        INITIAL,
064        AUTH_TEXT_SENT,
065        RESPONSE_SENT,
066        VALID_SERVER_RESPONSE,
067    }
068
069    /**
070     * The state of the this instance of SASL SCRAM-SHA1 authentication.
071     */
072    private State state = State.INITIAL;
073
074    /**
075     * The client's random ASCII which is used as nonce
076     */
077    private String clientRandomAscii;
078
079    private String clientFirstMessageBare;
080    private byte[] serverSignature;
081
082    @Override
083    protected void authenticateInternal(CallbackHandler cbh) {
084        throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
085    }
086
087    @Override
088    protected byte[] getAuthenticationText() {
089        clientRandomAscii = getRandomAscii();
090        String saslPrepedAuthcId = saslPrep(authenticationId);
091        clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
092        String clientFirstMessage = getGS2Header() + clientFirstMessageBare;
093        state = State.AUTH_TEXT_SENT;
094        return toBytes(clientFirstMessage);
095    }
096
097    @Override
098    public String getName() {
099        String name = "SCRAM-" + scramHmac.getHmacName();
100        return name;
101    }
102
103    @Override
104    public void checkIfSuccessfulOrThrow() throws SmackSaslException {
105        if (state != State.VALID_SERVER_RESPONSE) {
106            throw new SmackSaslException("SCRAM-SHA1 is missing valid server response");
107        }
108    }
109
110    @Override
111    public boolean authzidSupported() {
112        return true;
113    }
114
115    @Override
116    protected byte[] evaluateChallenge(byte[] challenge) throws SmackSaslException {
117        // TODO: Where is it specified that this is an UTF-8 encoded string?
118        String challengeString = new String(challenge, StandardCharsets.UTF_8);
119
120        switch (state) {
121        case AUTH_TEXT_SENT:
122            final String serverFirstMessage = challengeString;
123            Map<Character, String> attributes = parseAttributes(challengeString);
124
125            // Handle server random ASCII (nonce)
126            String rvalue = attributes.get('r');
127            if (rvalue == null) {
128                throw new SmackSaslException("Server random ASCII is null");
129            }
130            if (rvalue.length() <= clientRandomAscii.length()) {
131                throw new SmackSaslException("Server random ASCII is shorter then client random ASCII");
132            }
133            String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
134            if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
135                throw new SmackSaslException("Received client random ASCII does not match client random ASCII");
136            }
137
138            // Handle iterations
139            int iterations;
140            String iterationsString = attributes.get('i');
141            if (iterationsString == null) {
142                throw new SmackSaslException("Iterations attribute not set");
143            }
144            try {
145                iterations = Integer.parseInt(iterationsString);
146            }
147            catch (NumberFormatException e) {
148                throw new SmackSaslException("Exception parsing iterations", e);
149            }
150
151            // Handle salt
152            String salt = attributes.get('s');
153            if (salt == null) {
154                throw new SmackSaslException("SALT not send");
155            }
156
157            // Parsing and error checking is done, we can now begin to calculate the values
158
159            // First the client-final-message-without-proof
160            String channelBinding = "c=" + Base64.encodeToString(getCBindInput());
161            String clientFinalMessageWithoutProof = channelBinding + ",r=" + rvalue;
162
163            // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
164            // client-final-message-without-proof
165            byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
166                            + clientFinalMessageWithoutProof);
167
168            // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
169            // as it is likely that the server is going to advertise the same salt value upon reauthentication."
170            // Note that we also mangle the mechanism's name into the cache key, since the cache is used by multiple
171            // mechanisms.
172            final String cacheKey = password + ',' + salt + ',' + getName();
173            byte[] serverKey, clientKey;
174            Keys keys = CACHE.lookup(cacheKey);
175            if (keys == null) {
176                // SaltedPassword := Hi(Normalize(password), salt, i)
177                byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);
178
179                // ServerKey := HMAC(SaltedPassword, "Server Key")
180                serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
181
182                // ClientKey := HMAC(SaltedPassword, "Client Key")
183                clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
184
185                keys = new Keys(clientKey, serverKey);
186                CACHE.put(cacheKey, keys);
187            }
188            else {
189                serverKey = keys.serverKey;
190                clientKey = keys.clientKey;
191            }
192
193            // ServerSignature := HMAC(ServerKey, AuthMessage)
194            serverSignature = hmac(serverKey, authMessage);
195
196            // StoredKey := H(ClientKey)
197            byte[] storedKey = SHA1.bytes(clientKey);
198
199            // ClientSignature := HMAC(StoredKey, AuthMessage)
200            byte[] clientSignature = hmac(storedKey, authMessage);
201
202            // ClientProof := ClientKey XOR ClientSignature
203            byte[] clientProof = new byte[clientKey.length];
204            for (int i = 0; i < clientProof.length; i++) {
205                clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
206            }
207
208            String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
209            state = State.RESPONSE_SENT;
210            return toBytes(clientFinalMessage);
211        case RESPONSE_SENT:
212            String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
213            if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
214                throw new SmackSaslException("Server final message does not match calculated one");
215            }
216            state = State.VALID_SERVER_RESPONSE;
217            break;
218        default:
219            throw new SmackSaslException("Invalid state");
220        }
221        return null;
222    }
223
224    private String getGS2Header() {
225        String authzidPortion = "";
226        if (authorizationId != null) {
227            authzidPortion = "a=" + authorizationId;
228        }
229
230        String cbName = getGs2CbindFlag();
231        assert StringUtils.isNotEmpty(cbName);
232
233        return cbName + ',' + authzidPortion + ",";
234    }
235
236    private byte[] getCBindInput() throws SmackSaslException {
237        byte[] cbindData = getChannelBindingData();
238        byte[] gs2Header = toBytes(getGS2Header());
239
240        if (cbindData == null) {
241            return gs2Header;
242        }
243
244        return ByteUtils.concat(gs2Header, cbindData);
245    }
246
247    /**
248     * Get the SCRAM GSS-API Channel Binding Flag value.
249     *
250     * @return the gs2-cbind-flag value.
251     * @see <a href="https://tools.ietf.org/html/rfc5802#section-6">RFC 5802 § 6.</a>
252     */
253    protected String getGs2CbindFlag() {
254        // Check if we are using TLS and if a "-PLUS" variant of this mechanism is enabled. Assuming that the "-PLUS"
255        // variants always have precedence before the non-"-PLUS" variants this means that the server did not announce
256        // the "-PLUS" variant, as otherwise we would have tried it.
257        if (sslSession != null && connectionConfiguration.isEnabledSaslMechanism(getName() + "-PLUS")) {
258            // Announce that we support Channel Binding, i.e., the '-PLUS' flavor of this SASL mechanism, but that we
259            // believe the server does not.
260            return "y";
261        }
262        return "n";
263    }
264
265    /**
266     *
267     * @return the Channel Binding data.
268     * @throws SmackSaslException if a SASL specific error occurred.
269     */
270    protected byte[] getChannelBindingData() throws SmackSaslException {
271        return null;
272    }
273
274    private static Map<Character, String> parseAttributes(String string) throws SmackSaslException {
275        if (string.length() == 0) {
276            return Collections.emptyMap();
277        }
278
279        String[] keyValuePairs = string.split(",");
280        Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
281        for (String keyValuePair : keyValuePairs) {
282            if (keyValuePair.length() < 3) {
283                throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
284            }
285            char key = keyValuePair.charAt(0);
286            if (keyValuePair.charAt(1) != '=') {
287                throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
288            }
289            String value = keyValuePair.substring(2);
290            res.put(key, value);
291        }
292
293        return res;
294    }
295
296    /**
297     * Generate random ASCII.
298     * <p>
299     * This method is non-static and package-private for unit testing purposes.
300     * </p>
301     * @return A String of 32 random printable ASCII characters.
302     */
303    String getRandomAscii() {
304        int count = 0;
305        char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
306        final Random random = SECURE_RANDOM.get();
307        while (count < RANDOM_ASCII_BYTE_COUNT) {
308            int r = random.nextInt(128);
309            char c = (char) r;
310            // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
311            if (!isPrintableNonCommaAsciiChar(c)) {
312                continue;
313            }
314            randomAscii[count++] = c;
315        }
316        return new String(randomAscii);
317    }
318
319    private static boolean isPrintableNonCommaAsciiChar(char c) {
320        if (c == ',') {
321            return false;
322        }
323        // RFC 5802 § 7. 'printable': Contains all chars within 0x21 (33d) to 0x2b (43d) and 0x2d (45d) to 0x7e (126)
324        // aka. "Printable ASCII except ','". Since we already filter the ASCII ',' (0x2c, 44d) above, we only have to
325        // ensure that c is within [33, 126].
326        return c > 32 && c < 127;
327    }
328
329    /**
330     * Escapes usernames or passwords for SASL SCRAM-SHA1.
331     * <p>
332     * According to RFC 5802 § 5.1 'n:'
333     * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
334     * </p>
335     *
336     * @param string TODO javadoc me please
337     * @return the escaped string
338     */
339    private static String escape(String string) {
340        StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
341        for (int i = 0; i < string.length(); i++) {
342            char c = string.charAt(i);
343            switch (c) {
344            case ',':
345                sb.append("=2C");
346                break;
347            case '=':
348                sb.append("=3D");
349                break;
350            default:
351                sb.append(c);
352                break;
353            }
354        }
355        return sb.toString();
356    }
357
358    /**
359     * RFC 5802 § 2.2 HMAC(key, str)
360     *
361     * @param key TODO javadoc me please
362     * @param str TODO javadoc me please
363     * @return the HMAC-SHA1 value of the input.
364     * @throws SmackException if Smack detected an exceptional situation.
365     */
366    private byte[] hmac(byte[] key, byte[] str) throws SmackSaslException {
367        try {
368            return scramHmac.hmac(key, str);
369        }
370        catch (InvalidKeyException e) {
371            throw new SmackSaslException(getName() + " Exception", e);
372        }
373    }
374
375    /**
376     * RFC 5802 § 2.2 Hi(str, salt, i)
377     * <p>
378     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
379     * (PRF) and with dkLen == output length of HMAC() == output length of H().
380     * </p>
381     *
382     * @param normalizedPassword the normalized password.
383     * @param salt TODO javadoc me please
384     * @param iterations TODO javadoc me please
385     * @return the result of the Hi function.
386     * @throws SmackSaslException if a SASL related error occurs.
387     */
388    private byte[] hi(String normalizedPassword, byte[] salt, int iterations) throws SmackSaslException {
389        // According to RFC 5802 § 2.2, the resulting string of the normalization is also in UTF-8.
390        byte[] key = normalizedPassword.getBytes(StandardCharsets.UTF_8);
391
392        // U1 := HMAC(str, salt + INT(1))
393        byte[] u = hmac(key, ByteUtils.concat(salt, ONE));
394        byte[] res = u.clone();
395        for (int i = 1; i < iterations; i++) {
396            u = hmac(key, u);
397            for (int j = 0; j < u.length; j++) {
398                res[j] ^= u[j];
399            }
400        }
401        return res;
402    }
403
404    private static class Keys {
405        private final byte[] clientKey;
406        private final byte[] serverKey;
407
408        Keys(byte[] clientKey, byte[] serverKey) {
409            this.clientKey = clientKey;
410            this.serverKey = serverKey;
411        }
412    }
413}