CachingOmemoStore.java

  1. /**
  2.  *
  3.  * Copyright 2017 Paul Schaub
  4.  *
  5.  * Licensed under the Apache License, Version 2.0 (the "License");
  6.  * you may not use this file except in compliance with the License.
  7.  * You may obtain a copy of the License at
  8.  *
  9.  *     http://www.apache.org/licenses/LICENSE-2.0
  10.  *
  11.  * Unless required by applicable law or agreed to in writing, software
  12.  * distributed under the License is distributed on an "AS IS" BASIS,
  13.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14.  * See the License for the specific language governing permissions and
  15.  * limitations under the License.
  16.  */
  17. package org.jivesoftware.smackx.omemo;

  18. import java.io.IOException;
  19. import java.util.Date;
  20. import java.util.HashMap;
  21. import java.util.SortedSet;
  22. import java.util.TreeMap;
  23. import java.util.TreeSet;

  24. import org.jivesoftware.smackx.omemo.exceptions.CorruptedOmemoKeyException;
  25. import org.jivesoftware.smackx.omemo.internal.OmemoCachedDeviceList;
  26. import org.jivesoftware.smackx.omemo.internal.OmemoDevice;
  27. import org.jivesoftware.smackx.omemo.util.OmemoKeyUtil;

  28. import org.jxmpp.jid.BareJid;

  29. /**
  30.  * This class implements the Proxy Pattern in order to wrap an OmemoStore with a caching layer.
  31.  * This reduces access to the underlying storage layer (eg. database, filesystem) by only accessing it for
  32.  * missing/updated values.
  33.  *
  34.  * Alternatively this implementation can be used as an ephemeral keystore without a persisting backend.
  35.  *
  36.  * @param <T_IdKeyPair> the type of the id key pair.
  37.  * @param <T_IdKey> the type of the id key.
  38.  * @param <T_PreKey> the prekey type
  39.  * @param <T_SigPreKey> the signed prekey type.
  40.  * @param <T_Sess> the session type.
  41.  * @param <T_Addr> the address type.
  42.  * @param <T_ECPub> the EC pub type.
  43.  * @param <T_Bundle> the bundle type.
  44.  * @param <T_Ciph> the cipher type.
  45.  */
  46. public class CachingOmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph>
  47.         extends OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> {

  48.     private final HashMap<OmemoDevice, KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess>> caches = new HashMap<>();
  49.     private final OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> persistent;
  50.     private final OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil;

  51.     public CachingOmemoStore(OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil) {
  52.         if (keyUtil == null) {
  53.             throw new IllegalArgumentException("KeyUtil MUST NOT be null!");
  54.         }
  55.         this.keyUtil = keyUtil;
  56.         persistent = null;
  57.     }

  58.     public CachingOmemoStore(OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> wrappedStore) {
  59.         if (wrappedStore == null) {
  60.             throw new NullPointerException("Wrapped OmemoStore MUST NOT be null!");
  61.         }
  62.         this.keyUtil = null;
  63.         persistent = wrappedStore;
  64.     }

  65.     @Override
  66.     public SortedSet<Integer> localDeviceIdsOf(BareJid localUser) {
  67.         if (persistent != null) {
  68.             return persistent.localDeviceIdsOf(localUser);
  69.         } else {
  70.             SortedSet<Integer> deviceIds = new TreeSet<>();
  71.             for (OmemoDevice device : caches.keySet()) {
  72.                 if (device.getJid().equals(localUser)) {
  73.                     deviceIds.add(device.getDeviceId());
  74.                 }
  75.             }
  76.             return deviceIds;
  77.         }
  78.     }

  79.     @Override
  80.     public T_IdKeyPair loadOmemoIdentityKeyPair(OmemoDevice userDevice)
  81.             throws CorruptedOmemoKeyException, IOException {
  82.         T_IdKeyPair pair = getCache(userDevice).identityKeyPair;

  83.         if (pair == null && persistent != null) {
  84.             pair = persistent.loadOmemoIdentityKeyPair(userDevice);
  85.             if (pair != null) {
  86.                 getCache(userDevice).identityKeyPair = pair;
  87.             }
  88.         }

  89.         return pair;
  90.     }

  91.     @Override
  92.     public void storeOmemoIdentityKeyPair(OmemoDevice userDevice, T_IdKeyPair identityKeyPair) throws IOException {
  93.         getCache(userDevice).identityKeyPair = identityKeyPair;
  94.         if (persistent != null) {
  95.             persistent.storeOmemoIdentityKeyPair(userDevice, identityKeyPair);
  96.         }
  97.     }

  98.     @Override
  99.     public void removeOmemoIdentityKeyPair(OmemoDevice userDevice) {
  100.         getCache(userDevice).identityKeyPair = null;
  101.         if (persistent != null) {
  102.             persistent.removeOmemoIdentityKeyPair(userDevice);
  103.         }
  104.     }

  105.     @Override
  106.     public T_IdKey loadOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice)
  107.             throws CorruptedOmemoKeyException, IOException {
  108.         T_IdKey idKey = getCache(userDevice).identityKeys.get(contactsDevice);

  109.         if (idKey == null && persistent != null) {
  110.             idKey = persistent.loadOmemoIdentityKey(userDevice, contactsDevice);
  111.             if (idKey != null) {
  112.                 getCache(userDevice).identityKeys.put(contactsDevice, idKey);
  113.             }
  114.         }

  115.         return idKey;
  116.     }

  117.     @Override
  118.     public void storeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice device, T_IdKey t_idKey) throws IOException {
  119.         getCache(userDevice).identityKeys.put(device, t_idKey);
  120.         if (persistent != null) {
  121.             persistent.storeOmemoIdentityKey(userDevice, device, t_idKey);
  122.         }
  123.     }

  124.     @Override
  125.     public void removeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice) {
  126.         getCache(userDevice).identityKeys.remove(contactsDevice);
  127.         if (persistent != null) {
  128.             persistent.removeOmemoIdentityKey(userDevice, contactsDevice);
  129.         }
  130.     }

  131.     @Override
  132.     public void storeOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice, int counter) throws IOException {
  133.         getCache(userDevice).messageCounters.put(contactsDevice, counter);
  134.         if (persistent != null) {
  135.             persistent.storeOmemoMessageCounter(userDevice, contactsDevice, counter);
  136.         }
  137.     }

  138.     @Override
  139.     public int loadOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
  140.         Integer counter = getCache(userDevice).messageCounters.get(contactsDevice);
  141.         if (counter == null && persistent != null) {
  142.             counter = persistent.loadOmemoMessageCounter(userDevice, contactsDevice);
  143.         }

  144.         if (counter == null) {
  145.             counter = 0;
  146.         }

  147.         getCache(userDevice).messageCounters.put(contactsDevice, counter);

  148.         return counter;
  149.     }

  150.     @Override
  151.     public void setDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from, Date date) throws IOException {
  152.         getCache(userDevice).lastMessagesDates.put(from, date);
  153.         if (persistent != null) {
  154.             persistent.setDateOfLastReceivedMessage(userDevice, from, date);
  155.         }
  156.     }

  157.     @Override
  158.     public Date getDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from) throws IOException {
  159.         Date last = getCache(userDevice).lastMessagesDates.get(from);

  160.         if (last == null && persistent != null) {
  161.             last = persistent.getDateOfLastReceivedMessage(userDevice, from);
  162.             if (last != null) {
  163.                 getCache(userDevice).lastMessagesDates.put(from, last);
  164.             }
  165.         }

  166.         return last;
  167.     }

  168.     @Override
  169.     public void setDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice, Date date) throws IOException {
  170.         getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, date);
  171.         if (persistent != null) {
  172.             persistent.setDateOfLastReceivedMessage(userDevice, contactsDevice, date);
  173.         }
  174.     }

  175.     @Override
  176.     public Date getDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
  177.         Date last = getCache(userDevice).lastDeviceIdPublicationDates.get(contactsDevice);

  178.         if (last == null && persistent != null) {
  179.             last = persistent.getDateOfLastDeviceIdPublication(userDevice, contactsDevice);
  180.             if (last != null) {
  181.                 getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, last);
  182.             }
  183.         }

  184.         return last;
  185.     }

  186.     @Override
  187.     public void setDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice, Date date) throws IOException {
  188.         getCache(userDevice).lastRenewalDate = date;
  189.         if (persistent != null) {
  190.             persistent.setDateOfLastSignedPreKeyRenewal(userDevice, date);
  191.         }
  192.     }

  193.     @Override
  194.     public Date getDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice) throws IOException {
  195.         Date lastRenewal = getCache(userDevice).lastRenewalDate;

  196.         if (lastRenewal == null && persistent != null) {
  197.             lastRenewal = persistent.getDateOfLastSignedPreKeyRenewal(userDevice);
  198.             if (lastRenewal != null) {
  199.                 getCache(userDevice).lastRenewalDate = lastRenewal;
  200.             }
  201.         }

  202.         return lastRenewal;
  203.     }

  204.     @Override
  205.     public T_PreKey loadOmemoPreKey(OmemoDevice userDevice, int preKeyId) throws IOException {
  206.         T_PreKey preKey = getCache(userDevice).preKeys.get(preKeyId);

  207.         if (preKey == null && persistent != null) {
  208.             preKey = persistent.loadOmemoPreKey(userDevice, preKeyId);
  209.             if (preKey != null) {
  210.                 getCache(userDevice).preKeys.put(preKeyId, preKey);
  211.             }
  212.         }

  213.         return preKey;
  214.     }

  215.     @Override
  216.     public void storeOmemoPreKey(OmemoDevice userDevice, int preKeyId, T_PreKey t_preKey) throws IOException {
  217.         getCache(userDevice).preKeys.put(preKeyId, t_preKey);
  218.         if (persistent != null) {
  219.             persistent.storeOmemoPreKey(userDevice, preKeyId, t_preKey);
  220.         }
  221.     }

  222.     @Override
  223.     public void removeOmemoPreKey(OmemoDevice userDevice, int preKeyId) {
  224.         getCache(userDevice).preKeys.remove(preKeyId);
  225.         if (persistent != null) {
  226.             persistent.removeOmemoPreKey(userDevice, preKeyId);
  227.         }
  228.     }

  229.     @Override
  230.     public TreeMap<Integer, T_PreKey> loadOmemoPreKeys(OmemoDevice userDevice) throws IOException {
  231.         TreeMap<Integer, T_PreKey> preKeys = getCache(userDevice).preKeys;

  232.         if (preKeys.isEmpty() && persistent != null) {
  233.             preKeys.putAll(persistent.loadOmemoPreKeys(userDevice));
  234.         }

  235.         return new TreeMap<>(preKeys);
  236.     }

  237.     @Override
  238.     public T_SigPreKey loadOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) throws IOException {
  239.         T_SigPreKey sigPreKey = getCache(userDevice).signedPreKeys.get(signedPreKeyId);

  240.         if (sigPreKey == null && persistent != null) {
  241.             sigPreKey = persistent.loadOmemoSignedPreKey(userDevice, signedPreKeyId);
  242.             if (sigPreKey != null) {
  243.                 getCache(userDevice).signedPreKeys.put(signedPreKeyId, sigPreKey);
  244.             }
  245.         }

  246.         return sigPreKey;
  247.     }

  248.     @Override
  249.     public TreeMap<Integer, T_SigPreKey> loadOmemoSignedPreKeys(OmemoDevice userDevice) throws IOException {
  250.         TreeMap<Integer, T_SigPreKey> sigPreKeys = getCache(userDevice).signedPreKeys;

  251.         if (sigPreKeys.isEmpty() && persistent != null) {
  252.             sigPreKeys.putAll(persistent.loadOmemoSignedPreKeys(userDevice));
  253.         }

  254.         return new TreeMap<>(sigPreKeys);
  255.     }

  256.     @Override
  257.     public void storeOmemoSignedPreKey(OmemoDevice userDevice,
  258.                                        int signedPreKeyId,
  259.                                        T_SigPreKey signedPreKey) throws IOException {
  260.         getCache(userDevice).signedPreKeys.put(signedPreKeyId, signedPreKey);
  261.         if (persistent != null) {
  262.             persistent.storeOmemoSignedPreKey(userDevice, signedPreKeyId, signedPreKey);
  263.         }
  264.     }

  265.     @Override
  266.     public void removeOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) {
  267.         getCache(userDevice).signedPreKeys.remove(signedPreKeyId);
  268.         if (persistent != null) {
  269.             persistent.removeOmemoSignedPreKey(userDevice, signedPreKeyId);
  270.         }
  271.     }

  272.     @Override
  273.     public T_Sess loadRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
  274.         HashMap<Integer, T_Sess> contactSessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
  275.         if (contactSessions == null) {
  276.             contactSessions = new HashMap<>();
  277.             getCache(userDevice).sessions.put(contactsDevice.getJid(), contactSessions);
  278.         }

  279.         T_Sess session = contactSessions.get(contactsDevice.getDeviceId());
  280.         if (session == null && persistent != null) {
  281.             session = persistent.loadRawSession(userDevice, contactsDevice);
  282.             if (session != null) {
  283.                 contactSessions.put(contactsDevice.getDeviceId(), session);
  284.             }
  285.         }

  286.         return session;
  287.     }

  288.     @Override
  289.     public HashMap<Integer, T_Sess> loadAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) throws IOException {
  290.         HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contact);
  291.         if (sessions == null) {
  292.             sessions = new HashMap<>();
  293.             getCache(userDevice).sessions.put(contact, sessions);
  294.         }

  295.         if (sessions.isEmpty() && persistent != null) {
  296.             sessions.putAll(persistent.loadAllRawSessionsOf(userDevice, contact));
  297.         }

  298.         return new HashMap<>(sessions);
  299.     }

  300.     @Override
  301.     public void storeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevicece, T_Sess session) throws IOException {
  302.         HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevicece.getJid());
  303.         if (sessions == null) {
  304.             sessions = new HashMap<>();
  305.             getCache(userDevice).sessions.put(contactsDevicece.getJid(), sessions);
  306.         }

  307.         sessions.put(contactsDevicece.getDeviceId(), session);
  308.         if (persistent != null) {
  309.             persistent.storeRawSession(userDevice, contactsDevicece, session);
  310.         }
  311.     }

  312.     @Override
  313.     public void removeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
  314.         HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
  315.         if (sessions != null) {
  316.             sessions.remove(contactsDevice.getDeviceId());
  317.         }

  318.         if (persistent != null) {
  319.             persistent.removeRawSession(userDevice, contactsDevice);
  320.         }
  321.     }

  322.     @Override
  323.     public void removeAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) {
  324.         getCache(userDevice).sessions.remove(contact);
  325.         if (persistent != null) {
  326.             persistent.removeAllRawSessionsOf(userDevice, contact);
  327.         }
  328.     }

  329.     @Override
  330.     public boolean containsRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
  331.         HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());

  332.         return (sessions != null && sessions.get(contactsDevice.getDeviceId()) != null) ||
  333.                 (persistent != null && persistent.containsRawSession(userDevice, contactsDevice));
  334.     }

  335.     @Override
  336.     public OmemoCachedDeviceList loadCachedDeviceList(OmemoDevice userDevice, BareJid contact) throws IOException {
  337.         OmemoCachedDeviceList list = getCache(userDevice).deviceLists.get(contact);

  338.         if (list == null && persistent != null) {
  339.             list = persistent.loadCachedDeviceList(userDevice, contact);
  340.             if (list != null) {
  341.                 getCache(userDevice).deviceLists.put(contact, list);
  342.             }
  343.         }

  344.         return list == null ? new OmemoCachedDeviceList() : new OmemoCachedDeviceList(list);
  345.     }

  346.     @Override
  347.     public void storeCachedDeviceList(OmemoDevice userDevice,
  348.                                       BareJid contact,
  349.                                       OmemoCachedDeviceList deviceList) throws IOException {
  350.         getCache(userDevice).deviceLists.put(contact, new OmemoCachedDeviceList(deviceList));

  351.         if (persistent != null) {
  352.             persistent.storeCachedDeviceList(userDevice, contact, deviceList);
  353.         }
  354.     }

  355.     @Override
  356.     public void purgeOwnDeviceKeys(OmemoDevice userDevice) {
  357.         caches.remove(userDevice);

  358.         if (persistent != null) {
  359.             persistent.purgeOwnDeviceKeys(userDevice);
  360.         }
  361.     }

  362.     @Override
  363.     public OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle>
  364.     keyUtil() {
  365.         if (persistent != null) {
  366.             return persistent.keyUtil();
  367.         } else {
  368.             return keyUtil;
  369.         }
  370.     }

  371.     /**
  372.      * Return the {@link KeyCache} object of an {@link OmemoManager}.
  373.      *
  374.      * @param device OMEMO device of which we want to have the cache.
  375.      * @return key cache of the device
  376.      */
  377.     private KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> getCache(OmemoDevice device) {
  378.         KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> cache = caches.get(device);
  379.         if (cache == null) {
  380.             cache = new KeyCache<>();
  381.             caches.put(device, cache);
  382.         }
  383.         return cache;
  384.     }

  385.     /**
  386.      * Cache that stores values for an {@link OmemoManager}.
  387.      *
  388.      * @param <T_IdKeyPair> type of the identity key pair
  389.      * @param <T_IdKey> type of the public identity key
  390.      * @param <T_PreKey> type of a public preKey
  391.      * @param <T_SigPreKey> type of the public signed preKey
  392.      * @param <T_Sess> type of the OMEMO session
  393.      */
  394.     private static class KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> {
  395.         private T_IdKeyPair identityKeyPair;
  396.         private final TreeMap<Integer, T_PreKey> preKeys = new TreeMap<>();
  397.         private final TreeMap<Integer, T_SigPreKey> signedPreKeys = new TreeMap<>();
  398.         private final HashMap<BareJid, HashMap<Integer, T_Sess>> sessions = new HashMap<>();
  399.         private final HashMap<OmemoDevice, T_IdKey> identityKeys = new HashMap<>();
  400.         private final HashMap<OmemoDevice, Date> lastMessagesDates = new HashMap<>();
  401.         private final HashMap<OmemoDevice, Date> lastDeviceIdPublicationDates = new HashMap<>();
  402.         private final HashMap<BareJid, OmemoCachedDeviceList> deviceLists = new HashMap<>();
  403.         private Date lastRenewalDate = null;
  404.         private final HashMap<OmemoDevice, Integer> messageCounters = new HashMap<>();
  405.     }
  406. }