- /**
- *
- * Copyright 2014-2021 Florian Schmaus
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.jivesoftware.smack.sasl.core;
- import java.nio.charset.StandardCharsets;
- import java.security.InvalidKeyException;
- import java.security.SecureRandom;
- import java.util.Collections;
- import java.util.HashMap;
- import java.util.Map;
- import java.util.Random;
- import javax.security.auth.callback.CallbackHandler;
- import org.jivesoftware.smack.SmackException.SmackSaslException;
- import org.jivesoftware.smack.sasl.SASLMechanism;
- import org.jivesoftware.smack.util.ByteUtils;
- import org.jivesoftware.smack.util.SHA1;
- import org.jivesoftware.smack.util.StringUtils;
- import org.jivesoftware.smack.util.stringencoder.Base64;
- import org.jxmpp.util.cache.Cache;
- import org.jxmpp.util.cache.LruCache;
- public abstract class ScramMechanism extends SASLMechanism {
- private static final int RANDOM_ASCII_BYTE_COUNT = 32;
- private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
- private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
- private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };
- private static final ThreadLocal<SecureRandom> SECURE_RANDOM = new ThreadLocal<SecureRandom>() {
- @Override
- protected SecureRandom initialValue() {
- return new SecureRandom();
- }
- };
- private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);
- private final ScramHmac scramHmac;
- protected ScramMechanism(ScramHmac scramHmac) {
- this.scramHmac = scramHmac;
- }
- private enum State {
- INITIAL,
- AUTH_TEXT_SENT,
- RESPONSE_SENT,
- VALID_SERVER_RESPONSE,
- }
- /**
- * The state of the this instance of SASL SCRAM-SHA1 authentication.
- */
- private State state = State.INITIAL;
- /**
- * The client's random ASCII which is used as nonce
- */
- private String clientRandomAscii;
- private String clientFirstMessageBare;
- private byte[] serverSignature;
- @Override
- protected void authenticateInternal(CallbackHandler cbh) {
- throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
- }
- @Override
- protected byte[] getAuthenticationText() {
- clientRandomAscii = getRandomAscii();
- String saslPrepedAuthcId = saslPrep(authenticationId);
- clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
- String clientFirstMessage = getGS2Header() + clientFirstMessageBare;
- state = State.AUTH_TEXT_SENT;
- return toBytes(clientFirstMessage);
- }
- @Override
- public String getName() {
- String name = "SCRAM-" + scramHmac.getHmacName();
- return name;
- }
- @Override
- public void checkIfSuccessfulOrThrow() throws SmackSaslException {
- if (state != State.VALID_SERVER_RESPONSE) {
- throw new SmackSaslException("SCRAM-SHA1 is missing valid server response");
- }
- }
- @Override
- public boolean authzidSupported() {
- return true;
- }
- @Override
- protected byte[] evaluateChallenge(byte[] challenge) throws SmackSaslException {
- // TODO: Where is it specified that this is an UTF-8 encoded string?
- String challengeString = new String(challenge, StandardCharsets.UTF_8);
- switch (state) {
- case AUTH_TEXT_SENT:
- final String serverFirstMessage = challengeString;
- Map<Character, String> attributes = parseAttributes(challengeString);
- // Handle server random ASCII (nonce)
- String rvalue = attributes.get('r');
- if (rvalue == null) {
- throw new SmackSaslException("Server random ASCII is null");
- }
- if (rvalue.length() <= clientRandomAscii.length()) {
- throw new SmackSaslException("Server random ASCII is shorter then client random ASCII");
- }
- String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
- if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
- throw new SmackSaslException("Received client random ASCII does not match client random ASCII");
- }
- // Handle iterations
- int iterations;
- String iterationsString = attributes.get('i');
- if (iterationsString == null) {
- throw new SmackSaslException("Iterations attribute not set");
- }
- try {
- iterations = Integer.parseInt(iterationsString);
- }
- catch (NumberFormatException e) {
- throw new SmackSaslException("Exception parsing iterations", e);
- }
- // Handle salt
- String salt = attributes.get('s');
- if (salt == null) {
- throw new SmackSaslException("SALT not send");
- }
- // Parsing and error checking is done, we can now begin to calculate the values
- // First the client-final-message-without-proof
- String channelBinding = "c=" + Base64.encodeToString(getCBindInput());
- String clientFinalMessageWithoutProof = channelBinding + ",r=" + rvalue;
- // AuthMessage := client-first-message-bare + "," + server-first-message + "," +
- // client-final-message-without-proof
- byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
- + clientFinalMessageWithoutProof);
- // RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
- // as it is likely that the server is going to advertise the same salt value upon reauthentication."
- // Note that we also mangle the mechanism's name into the cache key, since the cache is used by multiple
- // mechanisms.
- final String cacheKey = password + ',' + salt + ',' + getName();
- byte[] serverKey, clientKey;
- Keys keys = CACHE.lookup(cacheKey);
- if (keys == null) {
- // SaltedPassword := Hi(Normalize(password), salt, i)
- byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);
- // ServerKey := HMAC(SaltedPassword, "Server Key")
- serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
- // ClientKey := HMAC(SaltedPassword, "Client Key")
- clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
- keys = new Keys(clientKey, serverKey);
- CACHE.put(cacheKey, keys);
- }
- else {
- serverKey = keys.serverKey;
- clientKey = keys.clientKey;
- }
- // ServerSignature := HMAC(ServerKey, AuthMessage)
- serverSignature = hmac(serverKey, authMessage);
- // StoredKey := H(ClientKey)
- byte[] storedKey = SHA1.bytes(clientKey);
- // ClientSignature := HMAC(StoredKey, AuthMessage)
- byte[] clientSignature = hmac(storedKey, authMessage);
- // ClientProof := ClientKey XOR ClientSignature
- byte[] clientProof = new byte[clientKey.length];
- for (int i = 0; i < clientProof.length; i++) {
- clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
- }
- String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
- state = State.RESPONSE_SENT;
- return toBytes(clientFinalMessage);
- case RESPONSE_SENT:
- String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
- if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
- throw new SmackSaslException("Server final message does not match calculated one");
- }
- state = State.VALID_SERVER_RESPONSE;
- break;
- default:
- throw new SmackSaslException("Invalid state");
- }
- return null;
- }
- private String getGS2Header() {
- String authzidPortion = "";
- if (authorizationId != null) {
- authzidPortion = "a=" + authorizationId;
- }
- String cbName = getGs2CbindFlag();
- assert StringUtils.isNotEmpty(cbName);
- return cbName + ',' + authzidPortion + ",";
- }
- private byte[] getCBindInput() throws SmackSaslException {
- byte[] cbindData = getChannelBindingData();
- byte[] gs2Header = toBytes(getGS2Header());
- if (cbindData == null) {
- return gs2Header;
- }
- return ByteUtils.concat(gs2Header, cbindData);
- }
- /**
- * Get the SCRAM GSS-API Channel Binding Flag value.
- *
- * @return the gs2-cbind-flag value.
- * @see <a href="https://tools.ietf.org/html/rfc5802#section-6">RFC 5802 § 6.</a>
- */
- protected String getGs2CbindFlag() {
- // Check if we are using TLS and if a "-PLUS" variant of this mechanism is enabled. Assuming that the "-PLUS"
- // variants always have precedence before the non-"-PLUS" variants this means that the server did not announce
- // the "-PLUS" variant, as otherwise we would have tried it.
- if (sslSession != null && connectionConfiguration.isEnabledSaslMechanism(getName() + "-PLUS")) {
- // Announce that we support Channel Binding, i.e., the '-PLUS' flavor of this SASL mechanism, but that we
- // believe the server does not.
- return "y";
- }
- return "n";
- }
- /**
- * Get the channel binding data.
- *
- * @return the Channel Binding data.
- * @throws SmackSaslException if a SASL specific error occurred.
- */
- protected byte[] getChannelBindingData() throws SmackSaslException {
- return null;
- }
- private static Map<Character, String> parseAttributes(String string) throws SmackSaslException {
- if (string.length() == 0) {
- return Collections.emptyMap();
- }
- String[] keyValuePairs = string.split(",");
- Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
- for (String keyValuePair : keyValuePairs) {
- if (keyValuePair.length() < 3) {
- throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
- }
- char key = keyValuePair.charAt(0);
- if (keyValuePair.charAt(1) != '=') {
- throw new SmackSaslException("Invalid Key-Value pair: " + keyValuePair);
- }
- String value = keyValuePair.substring(2);
- res.put(key, value);
- }
- return res;
- }
- /**
- * Generate random ASCII.
- * <p>
- * This method is non-static and package-private for unit testing purposes.
- * </p>
- * @return A String of 32 random printable ASCII characters.
- */
- String getRandomAscii() {
- int count = 0;
- char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
- final Random random = SECURE_RANDOM.get();
- while (count < RANDOM_ASCII_BYTE_COUNT) {
- int r = random.nextInt(128);
- char c = (char) r;
- // RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
- if (!isPrintableNonCommaAsciiChar(c)) {
- continue;
- }
- randomAscii[count++] = c;
- }
- return new String(randomAscii);
- }
- private static boolean isPrintableNonCommaAsciiChar(char c) {
- if (c == ',') {
- return false;
- }
- // RFC 5802 § 7. 'printable': Contains all chars within 0x21 (33d) to 0x2b (43d) and 0x2d (45d) to 0x7e (126)
- // aka. "Printable ASCII except ','". Since we already filter the ASCII ',' (0x2c, 44d) above, we only have to
- // ensure that c is within [33, 126].
- return c > 32 && c < 127;
- }
- /**
- * Escapes usernames or passwords for SASL SCRAM-SHA1.
- * <p>
- * According to RFC 5802 § 5.1 'n:'
- * "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
- * </p>
- *
- * @param string TODO javadoc me please
- * @return the escaped string
- */
- private static String escape(String string) {
- StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
- for (int i = 0; i < string.length(); i++) {
- char c = string.charAt(i);
- switch (c) {
- case ',':
- sb.append("=2C");
- break;
- case '=':
- sb.append("=3D");
- break;
- default:
- sb.append(c);
- break;
- }
- }
- return sb.toString();
- }
- /**
- * RFC 5802 § 2.2 HMAC(key, str)
- *
- * @param key TODO javadoc me please
- * @param str TODO javadoc me please
- * @return the HMAC-SHA1 value of the input.
- * @throws SmackSaslException if Smack detected an exceptional situation.
- */
- private byte[] hmac(byte[] key, byte[] str) throws SmackSaslException {
- try {
- return scramHmac.hmac(key, str);
- }
- catch (InvalidKeyException e) {
- throw new SmackSaslException(getName() + " Exception", e);
- }
- }
- /**
- * RFC 5802 § 2.2 Hi(str, salt, i)
- * <p>
- * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
- * (PRF) and with dkLen == output length of HMAC() == output length of H().
- * </p>
- *
- * @param normalizedPassword the normalized password.
- * @param salt TODO javadoc me please
- * @param iterations TODO javadoc me please
- * @return the result of the Hi function.
- * @throws SmackSaslException if a SASL related error occurs.
- */
- private byte[] hi(String normalizedPassword, byte[] salt, int iterations) throws SmackSaslException {
- // According to RFC 5802 § 2.2, the resulting string of the normalization is also in UTF-8.
- byte[] key = normalizedPassword.getBytes(StandardCharsets.UTF_8);
- // U1 := HMAC(str, salt + INT(1))
- byte[] u = hmac(key, ByteUtils.concat(salt, ONE));
- byte[] res = u.clone();
- for (int i = 1; i < iterations; i++) {
- u = hmac(key, u);
- for (int j = 0; j < u.length; j++) {
- res[j] ^= u[j];
- }
- }
- return res;
- }
- private static class Keys {
- private final byte[] clientKey;
- private final byte[] serverKey;
- Keys(byte[] clientKey, byte[] serverKey) {
- this.clientKey = clientKey;
- this.serverKey = serverKey;
- }
- }
- }