001/**
002 *
003 * Copyright 2014-2021 Florian Schmaus
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smack.sasl.core;
018
019import java.nio.charset.StandardCharsets;
020import java.security.InvalidKeyException;
021import java.security.SecureRandom;
022import java.util.Collections;
023import java.util.HashMap;
024import java.util.Map;
025import java.util.Random;
026
027import javax.security.auth.callback.CallbackHandler;
028
029import org.jivesoftware.smack.SmackException.SmackSaslException;
030import org.jivesoftware.smack.sasl.SASLMechanism;
031import org.jivesoftware.smack.util.ByteUtils;
032import org.jivesoftware.smack.util.SHA1;
033import org.jivesoftware.smack.util.StringUtils;
034import org.jivesoftware.smack.util.stringencoder.Base64;
035
036import org.jxmpp.util.cache.Cache;
037import org.jxmpp.util.cache.LruCache;
038
039public abstract class ScramMechanism extends SASLMechanism {
040
041    private static final int RANDOM_ASCII_BYTE_COUNT = 32;
042    private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
043    private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
044    private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };
045
046    private static final ThreadLocal<SecureRandom> SECURE_RANDOM = new ThreadLocal<SecureRandom>() {
047        @Override
048        protected SecureRandom initialValue() {
049            return new SecureRandom();
050        }
051    };
052
053    private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);
054
055    private final ScramHmac scramHmac;
056
057    protected ScramMechanism(ScramHmac scramHmac) {
058        this.scramHmac = scramHmac;
059    }
060
061    private enum State {
062        INITIAL,
063        AUTH_TEXT_SENT,
064        RESPONSE_SENT,
065        VALID_SERVER_RESPONSE,
066    }
067
068    /**
069     * The state of the this instance of SASL SCRAM-SHA1 authentication.
070     */
071    private State state = State.INITIAL;
072
073    /**
074     * The client's random ASCII which is used as nonce
075     */
076    private String clientRandomAscii;
077
078    private String clientFirstMessageBare;
079    private byte[] serverSignature;
080
081    @Override
082    protected void authenticateInternal(CallbackHandler cbh) {
083        throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
084    }
085
086    @Override
087    protected byte[] getAuthenticationText() {
088        clientRandomAscii = getRandomAscii();
089        String saslPrepedAuthcId = saslPrep(authenticationId);
090        clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
091        String clientFirstMessage = getGS2Header() + clientFirstMessageBare;
092        state = State.AUTH_TEXT_SENT;
093        return toBytes(clientFirstMessage);
094    }
095
096    @Override
097    public String getName() {
098        String name = "SCRAM-" + scramHmac.getHmacName();
099        return name;
100    }
101
102    @Override
103    public void checkIfSuccessfulOrThrow() throws SmackSaslException {
104        if (state != State.VALID_SERVER_RESPONSE) {
105            throw new SmackSaslException("SCRAM-SHA1 is missing valid server response");
106        }
107    }
108
109    @Override
110    public boolean authzidSupported() {
111        return true;
112    }
113
114    @Override
115    protected byte[] evaluateChallenge(byte[] challenge) throws SmackSaslException {
116        // TODO: Where is it specified that this is an UTF-8 encoded string?
117        String challengeString = new String(challenge, StandardCharsets.UTF_8);
118
119        switch (state) {
120        case AUTH_TEXT_SENT:
121            final String serverFirstMessage = challengeString;
122            Map<Character, String> attributes = parseAttributes(challengeString);
123
124            // Handle server random ASCII (nonce)
125            String rvalue = attributes.get('r');
126            if (rvalue == null) {
127                throw new SmackSaslException("Server random ASCII is null");
128            }
129            if (rvalue.length() <= clientRandomAscii.length()) {
130                throw new SmackSaslException("Server random ASCII is shorter then client random ASCII");
131            }
132            String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
133            if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
134                throw new SmackSaslException("Received client random ASCII does not match client random ASCII");
135            }
136
137            // Handle iterations
138            int iterations;
139            String iterationsString = attributes.get('i');
140            if (iterationsString == null) {
141                throw new SmackSaslException("Iterations attribute not set");
142            }
143            try {
144                iterations = Integer.parseInt(iterationsString);
145            }
146            catch (NumberFormatException e) {
147                throw new SmackSaslException("Exception parsing iterations", e);
148            }
149
150            // Handle salt
151            String salt = attributes.get('s');
152            if (salt == null) {
153                throw new SmackSaslException("SALT not send");
154            }
155
156            // Parsing and error checking is done, we can now begin to calculate the values
157
158            // First the client-final-message-without-proof
159            String channelBinding = "c=" + Base64.encodeToString(getCBindInput());
160            String clientFinalMessageWithoutProof = channelBinding + ",r=" + rvalue;
161
162            // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
163            // client-final-message-without-proof
164            byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
165                            + clientFinalMessageWithoutProof);
166
167            // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
168            // as it is likely that the server is going to advertise the same salt value upon reauthentication."
169            // Note that we also mangle the mechanism's name into the cache key, since the cache is used by multiple
170            // mechanisms.
171            final String cacheKey = password + ',' + salt + ',' + getName();
172            byte[] serverKey, clientKey;
173            Keys keys = CACHE.lookup(cacheKey);
174            if (keys == null) {
175                // SaltedPassword := Hi(Normalize(password), salt, i)
176                byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);
177
178                // ServerKey := HMAC(SaltedPassword, "Server Key")
179                serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
180
181                // ClientKey := HMAC(SaltedPassword, "Client Key")
182                clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
183
184                keys = new Keys(clientKey, serverKey);
185                CACHE.put(cacheKey, keys);
186            }
187            else {
188                serverKey = keys.serverKey;
189                clientKey = keys.clientKey;
190            }
191
192            // ServerSignature := HMAC(ServerKey, AuthMessage)
193            serverSignature = hmac(serverKey, authMessage);
194
195            // StoredKey := H(ClientKey)
196            byte[] storedKey = SHA1.bytes(clientKey);
197
198            // ClientSignature := HMAC(StoredKey, AuthMessage)
199            byte[] clientSignature = hmac(storedKey, authMessage);
200
201            // ClientProof := ClientKey XOR ClientSignature
202            byte[] clientProof = new byte[clientKey.length];
203            for (int i = 0; i < clientProof.length; i++) {
204                clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
205            }
206
207            String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
208            state = State.RESPONSE_SENT;
209            return toBytes(clientFinalMessage);
210        case RESPONSE_SENT:
211            String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
212            if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
213                throw new SmackSaslException("Server final message does not match calculated one");
214            }
215            state = State.VALID_SERVER_RESPONSE;
216            break;
217        default:
218            throw new SmackSaslException("Invalid state");
219        }
220        return null;
221    }
222
223    private String getGS2Header() {
224        String authzidPortion = "";
225        if (authorizationId != null) {
226            authzidPortion = "a=" + authorizationId;
227        }
228
229        String cbName = getGs2CbindFlag();
230        assert StringUtils.isNotEmpty(cbName);
231
232        return cbName + ',' + authzidPortion + ",";
233    }
234
235    private byte[] getCBindInput() throws SmackSaslException {
236        byte[] cbindData = getChannelBindingData();
237        byte[] gs2Header = toBytes(getGS2Header());
238
239        if (cbindData == null) {
240            return gs2Header;
241        }
242
243        return ByteUtils.concat(gs2Header, cbindData);
244    }
245
246    /**
247     * Get the SCRAM GSS-API Channel Binding Flag value.
248     *
249     * @return the gs2-cbind-flag value.
250     * @see <a href="https://tools.ietf.org/html/rfc5802#section-6">RFC 5802 § 6.</a>
251     */
252    protected String getGs2CbindFlag() {
253        // Check if we are using TLS and if a "-PLUS" variant of this mechanism is enabled. Assuming that the "-PLUS"
254        // variants always have precedence before the non-"-PLUS" variants this means that the server did not announce
255        // the "-PLUS" variant, as otherwise we would have tried it.
256        if (sslSession != null && connectionConfiguration.isEnabledSaslMechanism(getName() + "-PLUS")) {
257            // Announce that we support Channel Binding, i.e., the '-PLUS' flavor of this SASL mechanism, but that we
258            // believe the server does not.
259            return "y";
260        }
261        return "n";
262    }
263
264    /**
265     * Get the channel binding data.
266     *
267     * @return the Channel Binding data.
268     * @throws SmackSaslException if a SASL specific error occurred.
269     */
270    protected byte[] getChannelBindingData() throws SmackSaslException {
271        return null;
272    }
273
274    private static Map<Character, String> parseAttributes(String string) throws SmackSaslException {
275        if (string.length() == 0) {
276            return Collections.emptyMap();
277        }
278
279        String[] keyValuePairs = string.split(",");
280        Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
281        for (String keyValuePair : keyValuePairs) {
282            if (keyValuePair.length() < 3) {
283                throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
284            }
285            char key = keyValuePair.charAt(0);
286            if (keyValuePair.charAt(1) != '=') {
287                throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
288            }
289            String value = keyValuePair.substring(2);
290            res.put(key, value);
291        }
292
293        return res;
294    }
295
296    /**
297     * Generate random ASCII.
298     * <p>
299     * This method is non-static and package-private for unit testing purposes.
300     * </p>
301     * @return A String of 32 random printable ASCII characters.
302     */
303    String getRandomAscii() {
304        int count = 0;
305        char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
306        final Random random = SECURE_RANDOM.get();
307        while (count < RANDOM_ASCII_BYTE_COUNT) {
308            int r = random.nextInt(128);
309            char c = (char) r;
310            // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
311            if (!isPrintableNonCommaAsciiChar(c)) {
312                continue;
313            }
314            randomAscii[count++] = c;
315        }
316        return new String(randomAscii);
317    }
318
319    private static boolean isPrintableNonCommaAsciiChar(char c) {
320        if (c == ',') {
321            return false;
322        }
323        // RFC 5802 § 7. 'printable': Contains all chars within 0x21 (33d) to 0x2b (43d) and 0x2d (45d) to 0x7e (126)
324        // aka. "Printable ASCII except ','". Since we already filter the ASCII ',' (0x2c, 44d) above, we only have to
325        // ensure that c is within [33, 126].
326        return c > 32 && c < 127;
327    }
328
329    /**
330     * Escapes usernames or passwords for SASL SCRAM-SHA1.
331     * <p>
332     * According to RFC 5802 § 5.1 'n:'
333     * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
334     * </p>
335     *
336     * @param string TODO javadoc me please
337     * @return the escaped string
338     */
339    private static String escape(String string) {
340        StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
341        for (int i = 0; i < string.length(); i++) {
342            char c = string.charAt(i);
343            switch (c) {
344            case ',':
345                sb.append("=2C");
346                break;
347            case '=':
348                sb.append("=3D");
349                break;
350            default:
351                sb.append(c);
352                break;
353            }
354        }
355        return sb.toString();
356    }
357
358    /**
359     * RFC 5802 § 2.2 HMAC(key, str)
360     *
361     * @param key TODO javadoc me please
362     * @param str TODO javadoc me please
363     * @return the HMAC-SHA1 value of the input.
364     * @throws SmackSaslException if Smack detected an exceptional situation.
365     */
366    private byte[] hmac(byte[] key, byte[] str) throws SmackSaslException {
367        try {
368            return scramHmac.hmac(key, str);
369        }
370        catch (InvalidKeyException e) {
371            throw new SmackSaslException(getName() + " Exception", e);
372        }
373    }
374
375    /**
376     * RFC 5802 § 2.2 Hi(str, salt, i)
377     * <p>
378     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
379     * (PRF) and with dkLen == output length of HMAC() == output length of H().
380     * </p>
381     *
382     * @param normalizedPassword the normalized password.
383     * @param salt TODO javadoc me please
384     * @param iterations TODO javadoc me please
385     * @return the result of the Hi function.
386     * @throws SmackSaslException if a SASL related error occurs.
387     */
388    private byte[] hi(String normalizedPassword, byte[] salt, int iterations) throws SmackSaslException {
389        // According to RFC 5802 § 2.2, the resulting string of the normalization is also in UTF-8.
390        byte[] key = normalizedPassword.getBytes(StandardCharsets.UTF_8);
391
392        // U1 := HMAC(str, salt + INT(1))
393        byte[] u = hmac(key, ByteUtils.concat(salt, ONE));
394        byte[] res = u.clone();
395        for (int i = 1; i < iterations; i++) {
396            u = hmac(key, u);
397            for (int j = 0; j < u.length; j++) {
398                res[j] ^= u[j];
399            }
400        }
401        return res;
402    }
403
404    private static class Keys {
405        private final byte[] clientKey;
406        private final byte[] serverKey;
407
408        Keys(byte[] clientKey, byte[] serverKey) {
409            this.clientKey = clientKey;
410            this.serverKey = serverKey;
411        }
412    }
413}