CachingOmemoStore.java

/**
 *
 * Copyright 2017 Paul Schaub
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jivesoftware.smackx.omemo;

import java.io.IOException;
import java.util.Date;
import java.util.HashMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;

import org.jivesoftware.smackx.omemo.exceptions.CorruptedOmemoKeyException;
import org.jivesoftware.smackx.omemo.internal.OmemoCachedDeviceList;
import org.jivesoftware.smackx.omemo.internal.OmemoDevice;
import org.jivesoftware.smackx.omemo.util.OmemoKeyUtil;

import org.jxmpp.jid.BareJid;

/**
 * This class implements the Proxy Pattern in order to wrap an OmemoStore with a caching layer.
 * This reduces access to the underlying storage layer (eg. database, filesystem) by only accessing it for
 * missing/updated values.
 *
 * Alternatively this implementation can be used as an ephemeral keystore without a persisting backend.
 *
 * Write operations update the in-memory cache first and are then forwarded to the wrapped store (if any).
 * Read operations consult the cache first and only fall back to the wrapped store on a cache miss.
 *
 * The caches are plain {@link HashMap}s/{@link TreeMap}s and this class adds no synchronization of its own.
 *
 * @param <T_IdKeyPair> the type of the id key pair.
 * @param <T_IdKey> the type of the id key.
 * @param <T_PreKey> the prekey type
 * @param <T_SigPreKey> the signed prekey type.
 * @param <T_Sess> the session type.
 * @param <T_Addr> the address type.
 * @param <T_ECPub> the EC pub type.
 * @param <T_Bundle> the bundle type.
 * @param <T_Ciph> the cipher type.
 */
public class CachingOmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph>
        extends OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> {

    /** One cache per own device; entries are created lazily by {@link #getCache(OmemoDevice)}. */
    private final HashMap<OmemoDevice, KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess>> caches = new HashMap<>();

    /** Wrapped backend, or {@code null} when this store is used as an ephemeral (memory only) store. */
    private final OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> persistent;

    /** Key util used in ephemeral mode, or {@code null} when a backend is wrapped (its key util is used instead). */
    private final OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil;

    /**
     * Create an ephemeral store that keeps all values in memory only.
     *
     * @param keyUtil key util returned by {@link #keyUtil()}.
     * @throws IllegalArgumentException if {@code keyUtil} is null.
     */
    public CachingOmemoStore(OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil) {
        if (keyUtil == null) {
            throw new IllegalArgumentException("KeyUtil MUST NOT be null!");
        }
        this.keyUtil = keyUtil;
        persistent = null;
    }

    /**
     * Create a caching layer on top of the given store.
     *
     * @param wrappedStore store that all writes are forwarded to and that is consulted on cache misses.
     * @throws NullPointerException if {@code wrappedStore} is null.
     */
    public CachingOmemoStore(OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> wrappedStore) {
        if (wrappedStore == null) {
            throw new NullPointerException("Wrapped OmemoStore MUST NOT be null!");
        }
        this.keyUtil = null;
        persistent = wrappedStore;
    }

    /**
     * Return the device ids known for the given local user.
     * With a wrapped store the call is delegated without caching. In ephemeral mode the ids are derived from
     * the devices that currently have a cache entry; note that cache entries are created on first access of a
     * device by any other method of this class.
     *
     * @param localUser bare JID of the local user.
     * @return sorted set of device ids.
     */
    @Override
    public SortedSet<Integer> localDeviceIdsOf(BareJid localUser) {
        if (persistent != null) {
            return persistent.localDeviceIdsOf(localUser);
        } else {
            SortedSet<Integer> deviceIds = new TreeSet<>();
            for (OmemoDevice device : caches.keySet()) {
                if (device.getJid().equals(localUser)) {
                    deviceIds.add(device.getDeviceId());
                }
            }
            return deviceIds;
        }
    }

    /**
     * Load the identity key pair of the given own device, from cache or, on a miss, from the wrapped store.
     *
     * @param userDevice own device.
     * @return the key pair, or null if neither cache nor wrapped store has one.
     */
    @Override
    public T_IdKeyPair loadOmemoIdentityKeyPair(OmemoDevice userDevice)
            throws CorruptedOmemoKeyException, IOException {
        T_IdKeyPair pair = getCache(userDevice).identityKeyPair;

        if (pair == null && persistent != null) {
            pair = persistent.loadOmemoIdentityKeyPair(userDevice);
            if (pair != null) {
                getCache(userDevice).identityKeyPair = pair;
            }
        }

        return pair;
    }

    /**
     * Cache the identity key pair of the given own device and forward it to the wrapped store.
     */
    @Override
    public void storeOmemoIdentityKeyPair(OmemoDevice userDevice, T_IdKeyPair identityKeyPair) throws IOException {
        getCache(userDevice).identityKeyPair = identityKeyPair;
        if (persistent != null) {
            persistent.storeOmemoIdentityKeyPair(userDevice, identityKeyPair);
        }
    }

    /**
     * Drop the cached identity key pair of the given own device and remove it from the wrapped store.
     */
    @Override
    public void removeOmemoIdentityKeyPair(OmemoDevice userDevice) {
        getCache(userDevice).identityKeyPair = null;
        if (persistent != null) {
            persistent.removeOmemoIdentityKeyPair(userDevice);
        }
    }

    /**
     * Load the identity key of a contact's device, from cache or, on a miss, from the wrapped store.
     *
     * @param userDevice own device.
     * @param contactsDevice device whose identity key is requested.
     * @return the identity key, or null if neither cache nor wrapped store has one.
     */
    @Override
    public T_IdKey loadOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice)
            throws CorruptedOmemoKeyException, IOException {
        T_IdKey idKey = getCache(userDevice).identityKeys.get(contactsDevice);

        if (idKey == null && persistent != null) {
            idKey = persistent.loadOmemoIdentityKey(userDevice, contactsDevice);
            if (idKey != null) {
                getCache(userDevice).identityKeys.put(contactsDevice, idKey);
            }
        }

        return idKey;
    }

    /**
     * Cache the identity key of the given device and forward it to the wrapped store.
     */
    @Override
    public void storeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice device, T_IdKey t_idKey) throws IOException {
        getCache(userDevice).identityKeys.put(device, t_idKey);
        if (persistent != null) {
            persistent.storeOmemoIdentityKey(userDevice, device, t_idKey);
        }
    }

    /**
     * Drop the cached identity key of the given contact's device and remove it from the wrapped store.
     */
    @Override
    public void removeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice) {
        getCache(userDevice).identityKeys.remove(contactsDevice);
        if (persistent != null) {
            persistent.removeOmemoIdentityKey(userDevice, contactsDevice);
        }
    }

    /**
     * Cache the message counter for the given contact's device and forward it to the wrapped store.
     */
    @Override
    public void storeOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice, int counter) throws IOException {
        getCache(userDevice).messageCounters.put(contactsDevice, counter);
        if (persistent != null) {
            persistent.storeOmemoMessageCounter(userDevice, contactsDevice, counter);
        }
    }

    /**
     * Load the message counter for the given contact's device.
     * On a cache miss the wrapped store is asked; without a wrapped store the counter defaults to 0.
     * The resulting value is always written back to the cache.
     *
     * @return the message counter.
     */
    @Override
    public int loadOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
        Integer counter = getCache(userDevice).messageCounters.get(contactsDevice);
        if (counter == null && persistent != null) {
            counter = persistent.loadOmemoMessageCounter(userDevice, contactsDevice);
        }

        // Only reachable with a null value in ephemeral mode: no cached counter and no backend to ask.
        if (counter == null) {
            counter = 0;
        }

        getCache(userDevice).messageCounters.put(contactsDevice, counter);

        return counter;
    }

    /**
     * Cache the date of the last message received from the given device and forward it to the wrapped store.
     */
    @Override
    public void setDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from, Date date) throws IOException {
        getCache(userDevice).lastMessagesDates.put(from, date);
        if (persistent != null) {
            persistent.setDateOfLastReceivedMessage(userDevice, from, date);
        }
    }

    /**
     * Load the date of the last message received from the given device, from cache or, on a miss, from the
     * wrapped store.
     *
     * @return the date, or null if unknown.
     */
    @Override
    public Date getDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from) throws IOException {
        Date last = getCache(userDevice).lastMessagesDates.get(from);

        if (last == null && persistent != null) {
            last = persistent.getDateOfLastReceivedMessage(userDevice, from);
            if (last != null) {
                getCache(userDevice).lastMessagesDates.put(from, last);
            }
        }

        return last;
    }

    /**
     * Cache the date of the last device id publication of the given contact's device and forward it to the
     * wrapped store.
     */
    @Override
    public void setDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice, Date date) throws IOException {
        getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, date);
        if (persistent != null) {
            // Must write through to the publication date, which is what getDateOfLastDeviceIdPublication()
            // reads back; writing to the "last received message" date would clobber unrelated data.
            persistent.setDateOfLastDeviceIdPublication(userDevice, contactsDevice, date);
        }
    }

    /**
     * Load the date of the last device id publication of the given contact's device, from cache or, on a miss,
     * from the wrapped store.
     *
     * @return the date, or null if unknown.
     */
    @Override
    public Date getDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
        Date last = getCache(userDevice).lastDeviceIdPublicationDates.get(contactsDevice);

        if (last == null && persistent != null) {
            last = persistent.getDateOfLastDeviceIdPublication(userDevice, contactsDevice);
            if (last != null) {
                getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, last);
            }
        }

        return last;
    }

    /**
     * Cache the date of the last signed prekey renewal of the given own device and forward it to the wrapped
     * store.
     */
    @Override
    public void setDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice, Date date) throws IOException {
        getCache(userDevice).lastRenewalDate = date;
        if (persistent != null) {
            persistent.setDateOfLastSignedPreKeyRenewal(userDevice, date);
        }
    }

    /**
     * Load the date of the last signed prekey renewal of the given own device, from cache or, on a miss, from
     * the wrapped store.
     *
     * @return the date, or null if unknown.
     */
    @Override
    public Date getDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice) throws IOException {
        Date lastRenewal = getCache(userDevice).lastRenewalDate;

        if (lastRenewal == null && persistent != null) {
            lastRenewal = persistent.getDateOfLastSignedPreKeyRenewal(userDevice);
            if (lastRenewal != null) {
                getCache(userDevice).lastRenewalDate = lastRenewal;
            }
        }

        return lastRenewal;
    }

    /**
     * Load a single prekey of the given own device, from cache or, on a miss, from the wrapped store.
     *
     * @return the prekey, or null if neither cache nor wrapped store has it.
     */
    @Override
    public T_PreKey loadOmemoPreKey(OmemoDevice userDevice, int preKeyId) throws IOException {
        T_PreKey preKey = getCache(userDevice).preKeys.get(preKeyId);

        if (preKey == null && persistent != null) {
            preKey = persistent.loadOmemoPreKey(userDevice, preKeyId);
            if (preKey != null) {
                getCache(userDevice).preKeys.put(preKeyId, preKey);
            }
        }

        return preKey;
    }

    /**
     * Cache a prekey of the given own device and forward it to the wrapped store.
     */
    @Override
    public void storeOmemoPreKey(OmemoDevice userDevice, int preKeyId, T_PreKey t_preKey) throws IOException {
        getCache(userDevice).preKeys.put(preKeyId, t_preKey);
        if (persistent != null) {
            persistent.storeOmemoPreKey(userDevice, preKeyId, t_preKey);
        }
    }

    /**
     * Drop a cached prekey of the given own device and remove it from the wrapped store.
     */
    @Override
    public void removeOmemoPreKey(OmemoDevice userDevice, int preKeyId) {
        getCache(userDevice).preKeys.remove(preKeyId);
        if (persistent != null) {
            persistent.removeOmemoPreKey(userDevice, preKeyId);
        }
    }

    /**
     * Load all prekeys of the given own device.
     * The wrapped store is only consulted while the cached prekey map is empty.
     *
     * @return a copy of the cached prekeys; modifying it does not affect the cache.
     */
    @Override
    public TreeMap<Integer, T_PreKey> loadOmemoPreKeys(OmemoDevice userDevice) throws IOException {
        TreeMap<Integer, T_PreKey> preKeys = getCache(userDevice).preKeys;

        if (preKeys.isEmpty() && persistent != null) {
            preKeys.putAll(persistent.loadOmemoPreKeys(userDevice));
        }

        return new TreeMap<>(preKeys);
    }

    /**
     * Load a single signed prekey of the given own device, from cache or, on a miss, from the wrapped store.
     *
     * @return the signed prekey, or null if neither cache nor wrapped store has it.
     */
    @Override
    public T_SigPreKey loadOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) throws IOException {
        T_SigPreKey sigPreKey = getCache(userDevice).signedPreKeys.get(signedPreKeyId);

        if (sigPreKey == null && persistent != null) {
            sigPreKey = persistent.loadOmemoSignedPreKey(userDevice, signedPreKeyId);
            if (sigPreKey != null) {
                getCache(userDevice).signedPreKeys.put(signedPreKeyId, sigPreKey);
            }
        }

        return sigPreKey;
    }

    /**
     * Load all signed prekeys of the given own device.
     * The wrapped store is only consulted while the cached signed prekey map is empty.
     *
     * @return a copy of the cached signed prekeys; modifying it does not affect the cache.
     */
    @Override
    public TreeMap<Integer, T_SigPreKey> loadOmemoSignedPreKeys(OmemoDevice userDevice) throws IOException {
        TreeMap<Integer, T_SigPreKey> sigPreKeys = getCache(userDevice).signedPreKeys;

        if (sigPreKeys.isEmpty() && persistent != null) {
            sigPreKeys.putAll(persistent.loadOmemoSignedPreKeys(userDevice));
        }

        return new TreeMap<>(sigPreKeys);
    }

    /**
     * Cache a signed prekey of the given own device and forward it to the wrapped store.
     */
    @Override
    public void storeOmemoSignedPreKey(OmemoDevice userDevice,
                                       int signedPreKeyId,
                                       T_SigPreKey signedPreKey) throws IOException {
        getCache(userDevice).signedPreKeys.put(signedPreKeyId, signedPreKey);
        if (persistent != null) {
            persistent.storeOmemoSignedPreKey(userDevice, signedPreKeyId, signedPreKey);
        }
    }

    /**
     * Drop a cached signed prekey of the given own device and remove it from the wrapped store.
     */
    @Override
    public void removeOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) {
        getCache(userDevice).signedPreKeys.remove(signedPreKeyId);
        if (persistent != null) {
            persistent.removeOmemoSignedPreKey(userDevice, signedPreKeyId);
        }
    }

    /**
     * Load the session with the given contact's device, from cache or, on a miss, from the wrapped store.
     *
     * @return the session, or null if neither cache nor wrapped store has one.
     */
    @Override
    public T_Sess loadRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
        HashMap<Integer, T_Sess> contactSessions = getSessionsOf(userDevice, contactsDevice.getJid());

        T_Sess session = contactSessions.get(contactsDevice.getDeviceId());
        if (session == null && persistent != null) {
            session = persistent.loadRawSession(userDevice, contactsDevice);
            if (session != null) {
                contactSessions.put(contactsDevice.getDeviceId(), session);
            }
        }

        return session;
    }

    /**
     * Load all sessions with devices of the given contact.
     * The wrapped store is only consulted while the cached session map of that contact is empty.
     *
     * @return a copy of the cached sessions keyed by device id; modifying it does not affect the cache.
     */
    @Override
    public HashMap<Integer, T_Sess> loadAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) throws IOException {
        HashMap<Integer, T_Sess> sessions = getSessionsOf(userDevice, contact);

        if (sessions.isEmpty() && persistent != null) {
            sessions.putAll(persistent.loadAllRawSessionsOf(userDevice, contact));
        }

        return new HashMap<>(sessions);
    }

    /**
     * Cache the session with the given contact's device and forward it to the wrapped store.
     */
    @Override
    public void storeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice, T_Sess session) throws IOException {
        getSessionsOf(userDevice, contactsDevice.getJid()).put(contactsDevice.getDeviceId(), session);
        if (persistent != null) {
            persistent.storeRawSession(userDevice, contactsDevice, session);
        }
    }

    /**
     * Drop the cached session with the given contact's device and remove it from the wrapped store.
     */
    @Override
    public void removeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
        if (sessions != null) {
            sessions.remove(contactsDevice.getDeviceId());
        }

        if (persistent != null) {
            persistent.removeRawSession(userDevice, contactsDevice);
        }
    }

    /**
     * Drop all cached sessions with the given contact and remove them from the wrapped store.
     */
    @Override
    public void removeAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) {
        getCache(userDevice).sessions.remove(contact);
        if (persistent != null) {
            persistent.removeAllRawSessionsOf(userDevice, contact);
        }
    }

    /**
     * Check whether a session with the given contact's device exists in the cache or in the wrapped store.
     *
     * @return true if the cache holds a non-null session or the wrapped store reports one.
     */
    @Override
    public boolean containsRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());

        return (sessions != null && sessions.get(contactsDevice.getDeviceId()) != null) ||
                (persistent != null && persistent.containsRawSession(userDevice, contactsDevice));
    }

    /**
     * Load the cached device list of the given contact, from cache or, on a miss, from the wrapped store.
     *
     * @return a copy of the device list, or a new empty list if none is known; never null.
     */
    @Override
    public OmemoCachedDeviceList loadCachedDeviceList(OmemoDevice userDevice, BareJid contact) throws IOException {
        OmemoCachedDeviceList list = getCache(userDevice).deviceLists.get(contact);

        if (list == null && persistent != null) {
            list = persistent.loadCachedDeviceList(userDevice, contact);
            if (list != null) {
                getCache(userDevice).deviceLists.put(contact, list);
            }
        }

        // Hand out a copy so callers cannot modify the cached instance.
        return list == null ? new OmemoCachedDeviceList() : new OmemoCachedDeviceList(list);
    }

    /**
     * Cache a copy of the given device list for the contact and forward the list to the wrapped store.
     */
    @Override
    public void storeCachedDeviceList(OmemoDevice userDevice,
                                      BareJid contact,
                                      OmemoCachedDeviceList deviceList) throws IOException {
        getCache(userDevice).deviceLists.put(contact, new OmemoCachedDeviceList(deviceList));

        if (persistent != null) {
            persistent.storeCachedDeviceList(userDevice, contact, deviceList);
        }
    }

    /**
     * Discard the whole cache of the given own device and purge its keys from the wrapped store.
     */
    @Override
    public void purgeOwnDeviceKeys(OmemoDevice userDevice) {
        caches.remove(userDevice);

        if (persistent != null) {
            persistent.purgeOwnDeviceKeys(userDevice);
        }
    }

    /**
     * Return the key util of the wrapped store, or the one passed to the constructor in ephemeral mode.
     */
    @Override
    public OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle>
    keyUtil() {
        if (persistent != null) {
            return persistent.keyUtil();
        } else {
            return keyUtil;
        }
    }

    /**
     * Return the {@link KeyCache} object of an {@link OmemoManager}.
     * If the device has no cache yet, an empty one is created and registered.
     *
     * @param device OMEMO device of which we want to have the cache.
     * @return key cache of the device
     */
    private KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> getCache(OmemoDevice device) {
        KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> cache = caches.get(device);
        if (cache == null) {
            cache = new KeyCache<>();
            caches.put(device, cache);
        }
        return cache;
    }

    /**
     * Return the cached session map of the given contact, creating and registering an empty one if absent.
     *
     * @param userDevice own device whose cache is used.
     * @param contact bare JID of the contact.
     * @return the live (not copied) map from device id to session.
     */
    private HashMap<Integer, T_Sess> getSessionsOf(OmemoDevice userDevice, BareJid contact) {
        HashMap<BareJid, HashMap<Integer, T_Sess>> allSessions = getCache(userDevice).sessions;
        HashMap<Integer, T_Sess> sessions = allSessions.get(contact);
        if (sessions == null) {
            sessions = new HashMap<>();
            allSessions.put(contact, sessions);
        }
        return sessions;
    }

    /**
     * Cache that stores values for an {@link OmemoManager}.
     *
     * @param <T_IdKeyPair> type of the identity key pair
     * @param <T_IdKey> type of the public identity key
     * @param <T_PreKey> type of a public preKey
     * @param <T_SigPreKey> type of the public signed preKey
     * @param <T_Sess> type of the OMEMO session
     */
    private static class KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> {
        // Own identity key pair; null until stored or loaded.
        private T_IdKeyPair identityKeyPair;
        // Own prekeys and signed prekeys, keyed by their id.
        private final TreeMap<Integer, T_PreKey> preKeys = new TreeMap<>();
        private final TreeMap<Integer, T_SigPreKey> signedPreKeys = new TreeMap<>();
        // Sessions per contact, keyed by the contact's device id.
        private final HashMap<BareJid, HashMap<Integer, T_Sess>> sessions = new HashMap<>();
        // Per remote device values.
        private final HashMap<OmemoDevice, T_IdKey> identityKeys = new HashMap<>();
        private final HashMap<OmemoDevice, Date> lastMessagesDates = new HashMap<>();
        private final HashMap<OmemoDevice, Date> lastDeviceIdPublicationDates = new HashMap<>();
        // Device lists per contact.
        private final HashMap<BareJid, OmemoCachedDeviceList> deviceLists = new HashMap<>();
        // Date of the last signed prekey renewal; null until stored or loaded.
        private Date lastRenewalDate = null;
        private final HashMap<OmemoDevice, Integer> messageCounters = new HashMap<>();
    }
}