001/**
002 *
003 * Copyright 2017 Paul Schaub
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smackx.omemo;
018
019import java.io.IOException;
020import java.util.Date;
021import java.util.HashMap;
022import java.util.Map;
023import java.util.SortedSet;
024import java.util.TreeMap;
025import java.util.TreeSet;
026
027import org.jivesoftware.smackx.omemo.exceptions.CorruptedOmemoKeyException;
028import org.jivesoftware.smackx.omemo.internal.OmemoCachedDeviceList;
029import org.jivesoftware.smackx.omemo.internal.OmemoDevice;
030import org.jivesoftware.smackx.omemo.util.OmemoKeyUtil;
031
032import org.jxmpp.jid.BareJid;
033
034/**
035 * This class implements the Proxy Pattern in order to wrap an OmemoStore with a caching layer.
036 * This reduces access to the underlying storage layer (eg. database, filesystem) by only accessing it for
037 * missing/updated values.
038 *
039 * Alternatively this implementation can be used as an ephemeral keystore without a persisting backend.
040 *
041 * @param <T_IdKeyPair> the type of the id key pair.
042 * @param <T_IdKey> the type of the id key.
043 * @param <T_PreKey> the prekey type
044 * @param <T_SigPreKey> the signed prekey type.
045 * @param <T_Sess> the session type.
046 * @param <T_Addr> the address type.
047 * @param <T_ECPub> the EC pub type.
048 * @param <T_Bundle> the bundle type.
049 * @param <T_Ciph> the cipher type.
050 */
051public class CachingOmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph>
052        extends OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> {
053
054    private final HashMap<OmemoDevice, KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess>> caches = new HashMap<>();
055    private final OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> persistent;
056    private final OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil;
057
058    public CachingOmemoStore(OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil) {
059        if (keyUtil == null) {
060            throw new IllegalArgumentException("KeyUtil MUST NOT be null!");
061        }
062        this.keyUtil = keyUtil;
063        persistent = null;
064    }
065
066    public CachingOmemoStore(OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> wrappedStore) {
067        if (wrappedStore == null) {
068            throw new NullPointerException("Wrapped OmemoStore MUST NOT be null!");
069        }
070        this.keyUtil = null;
071        persistent = wrappedStore;
072    }
073
074    @Override
075    public SortedSet<Integer> localDeviceIdsOf(BareJid localUser) {
076        if (persistent != null) {
077            return persistent.localDeviceIdsOf(localUser);
078        } else {
079            SortedSet<Integer> deviceIds = new TreeSet<>();
080            for (OmemoDevice device : caches.keySet()) {
081                if (device.getJid().equals(localUser)) {
082                    deviceIds.add(device.getDeviceId());
083                }
084            }
085            return deviceIds;
086        }
087    }
088
089    @Override
090    public T_IdKeyPair loadOmemoIdentityKeyPair(OmemoDevice userDevice)
091            throws CorruptedOmemoKeyException, IOException {
092        T_IdKeyPair pair = getCache(userDevice).identityKeyPair;
093
094        if (pair == null && persistent != null) {
095            pair = persistent.loadOmemoIdentityKeyPair(userDevice);
096            if (pair != null) {
097                getCache(userDevice).identityKeyPair = pair;
098            }
099        }
100
101        return pair;
102    }
103
104    @Override
105    public void storeOmemoIdentityKeyPair(OmemoDevice userDevice, T_IdKeyPair identityKeyPair) throws IOException {
106        getCache(userDevice).identityKeyPair = identityKeyPair;
107        if (persistent != null) {
108            persistent.storeOmemoIdentityKeyPair(userDevice, identityKeyPair);
109        }
110    }
111
112    @Override
113    public void removeOmemoIdentityKeyPair(OmemoDevice userDevice) {
114        getCache(userDevice).identityKeyPair = null;
115        if (persistent != null) {
116            persistent.removeOmemoIdentityKeyPair(userDevice);
117        }
118    }
119
120    @Override
121    public T_IdKey loadOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice)
122            throws CorruptedOmemoKeyException, IOException {
123        T_IdKey idKey = getCache(userDevice).identityKeys.get(contactsDevice);
124
125        if (idKey == null && persistent != null) {
126            idKey = persistent.loadOmemoIdentityKey(userDevice, contactsDevice);
127            if (idKey != null) {
128                getCache(userDevice).identityKeys.put(contactsDevice, idKey);
129            }
130        }
131
132        return idKey;
133    }
134
135    @Override
136    public void storeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice device, T_IdKey t_idKey) throws IOException {
137        getCache(userDevice).identityKeys.put(device, t_idKey);
138        if (persistent != null) {
139            persistent.storeOmemoIdentityKey(userDevice, device, t_idKey);
140        }
141    }
142
143    @Override
144    public void removeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice) {
145        getCache(userDevice).identityKeys.remove(contactsDevice);
146        if (persistent != null) {
147            persistent.removeOmemoIdentityKey(userDevice, contactsDevice);
148        }
149    }
150
151    @Override
152    public void storeOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice, int counter) throws IOException {
153        getCache(userDevice).messageCounters.put(contactsDevice, counter);
154        if (persistent != null) {
155            persistent.storeOmemoMessageCounter(userDevice, contactsDevice, counter);
156        }
157    }
158
159    @Override
160    public int loadOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
161        Integer counter = getCache(userDevice).messageCounters.get(contactsDevice);
162        if (counter == null && persistent != null) {
163            counter = persistent.loadOmemoMessageCounter(userDevice, contactsDevice);
164        }
165
166        if (counter == null) {
167            counter = 0;
168        }
169
170        getCache(userDevice).messageCounters.put(contactsDevice, counter);
171
172        return counter;
173    }
174
175    @Override
176    public void setDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from, Date date) throws IOException {
177        getCache(userDevice).lastMessagesDates.put(from, date);
178        if (persistent != null) {
179            persistent.setDateOfLastReceivedMessage(userDevice, from, date);
180        }
181    }
182
183    @Override
184    public Date getDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from) throws IOException {
185        Date last = getCache(userDevice).lastMessagesDates.get(from);
186
187        if (last == null && persistent != null) {
188            last = persistent.getDateOfLastReceivedMessage(userDevice, from);
189            if (last != null) {
190                getCache(userDevice).lastMessagesDates.put(from, last);
191            }
192        }
193
194        return last;
195    }
196
197    @Override
198    public void setDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice, Date date) throws IOException {
199        getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, date);
200        if (persistent != null) {
201            persistent.setDateOfLastReceivedMessage(userDevice, contactsDevice, date);
202        }
203    }
204
205    @Override
206    public Date getDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
207        Date last = getCache(userDevice).lastDeviceIdPublicationDates.get(contactsDevice);
208
209        if (last == null && persistent != null) {
210            last = persistent.getDateOfLastDeviceIdPublication(userDevice, contactsDevice);
211            if (last != null) {
212                getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, last);
213            }
214        }
215
216        return last;
217    }
218
219    @Override
220    public void setDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice, Date date) throws IOException {
221        getCache(userDevice).lastRenewalDate = date;
222        if (persistent != null) {
223            persistent.setDateOfLastSignedPreKeyRenewal(userDevice, date);
224        }
225    }
226
227    @Override
228    public Date getDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice) throws IOException {
229        Date lastRenewal = getCache(userDevice).lastRenewalDate;
230
231        if (lastRenewal == null && persistent != null) {
232            lastRenewal = persistent.getDateOfLastSignedPreKeyRenewal(userDevice);
233            if (lastRenewal != null) {
234                getCache(userDevice).lastRenewalDate = lastRenewal;
235            }
236        }
237
238        return lastRenewal;
239    }
240
241    @Override
242    public T_PreKey loadOmemoPreKey(OmemoDevice userDevice, int preKeyId) throws IOException {
243        T_PreKey preKey = getCache(userDevice).preKeys.get(preKeyId);
244
245        if (preKey == null && persistent != null) {
246            preKey = persistent.loadOmemoPreKey(userDevice, preKeyId);
247            if (preKey != null) {
248                getCache(userDevice).preKeys.put(preKeyId, preKey);
249            }
250        }
251
252        return preKey;
253    }
254
255    @Override
256    public void storeOmemoPreKey(OmemoDevice userDevice, int preKeyId, T_PreKey t_preKey) throws IOException {
257        getCache(userDevice).preKeys.put(preKeyId, t_preKey);
258        if (persistent != null) {
259            persistent.storeOmemoPreKey(userDevice, preKeyId, t_preKey);
260        }
261    }
262
263    @Override
264    public void removeOmemoPreKey(OmemoDevice userDevice, int preKeyId) {
265        getCache(userDevice).preKeys.remove(preKeyId);
266        if (persistent != null) {
267            persistent.removeOmemoPreKey(userDevice, preKeyId);
268        }
269    }
270
271    @Override
272    @SuppressWarnings("NonApiType")
273    public TreeMap<Integer, T_PreKey> loadOmemoPreKeys(OmemoDevice userDevice) throws IOException {
274        Map<Integer, T_PreKey> preKeys = getCache(userDevice).preKeys;
275
276        if (preKeys.isEmpty() && persistent != null) {
277            preKeys.putAll(persistent.loadOmemoPreKeys(userDevice));
278        }
279
280        return new TreeMap<>(preKeys);
281    }
282
283    @Override
284    public T_SigPreKey loadOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) throws IOException {
285        T_SigPreKey sigPreKey = getCache(userDevice).signedPreKeys.get(signedPreKeyId);
286
287        if (sigPreKey == null && persistent != null) {
288            sigPreKey = persistent.loadOmemoSignedPreKey(userDevice, signedPreKeyId);
289            if (sigPreKey != null) {
290                getCache(userDevice).signedPreKeys.put(signedPreKeyId, sigPreKey);
291            }
292        }
293
294        return sigPreKey;
295    }
296
297    @Override
298    @SuppressWarnings("NonApiType")
299    public TreeMap<Integer, T_SigPreKey> loadOmemoSignedPreKeys(OmemoDevice userDevice) throws IOException {
300        Map<Integer, T_SigPreKey> sigPreKeys = getCache(userDevice).signedPreKeys;
301
302        if (sigPreKeys.isEmpty() && persistent != null) {
303            sigPreKeys.putAll(persistent.loadOmemoSignedPreKeys(userDevice));
304        }
305
306        return new TreeMap<>(sigPreKeys);
307    }
308
309    @Override
310    public void storeOmemoSignedPreKey(OmemoDevice userDevice,
311                                       int signedPreKeyId,
312                                       T_SigPreKey signedPreKey) throws IOException {
313        getCache(userDevice).signedPreKeys.put(signedPreKeyId, signedPreKey);
314        if (persistent != null) {
315            persistent.storeOmemoSignedPreKey(userDevice, signedPreKeyId, signedPreKey);
316        }
317    }
318
319    @Override
320    public void removeOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) {
321        getCache(userDevice).signedPreKeys.remove(signedPreKeyId);
322        if (persistent != null) {
323            persistent.removeOmemoSignedPreKey(userDevice, signedPreKeyId);
324        }
325    }
326
327    @Override
328    public T_Sess loadRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
329        HashMap<Integer, T_Sess> contactSessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
330        if (contactSessions == null) {
331            contactSessions = new HashMap<>();
332            getCache(userDevice).sessions.put(contactsDevice.getJid(), contactSessions);
333        }
334
335        T_Sess session = contactSessions.get(contactsDevice.getDeviceId());
336        if (session == null && persistent != null) {
337            session = persistent.loadRawSession(userDevice, contactsDevice);
338            if (session != null) {
339                contactSessions.put(contactsDevice.getDeviceId(), session);
340            }
341        }
342
343        return session;
344    }
345
346    @Override
347    public Map<Integer, T_Sess> loadAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) throws IOException {
348        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contact);
349        if (sessions == null) {
350            sessions = new HashMap<>();
351            getCache(userDevice).sessions.put(contact, sessions);
352        }
353
354        if (sessions.isEmpty() && persistent != null) {
355            sessions.putAll(persistent.loadAllRawSessionsOf(userDevice, contact));
356        }
357
358        return new HashMap<>(sessions);
359    }
360
361    @Override
362    public void storeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevicece, T_Sess session) throws IOException {
363        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevicece.getJid());
364        if (sessions == null) {
365            sessions = new HashMap<>();
366            getCache(userDevice).sessions.put(contactsDevicece.getJid(), sessions);
367        }
368
369        sessions.put(contactsDevicece.getDeviceId(), session);
370        if (persistent != null) {
371            persistent.storeRawSession(userDevice, contactsDevicece, session);
372        }
373    }
374
375    @Override
376    public void removeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
377        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
378        if (sessions != null) {
379            sessions.remove(contactsDevice.getDeviceId());
380        }
381
382        if (persistent != null) {
383            persistent.removeRawSession(userDevice, contactsDevice);
384        }
385    }
386
387    @Override
388    public void removeAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) {
389        getCache(userDevice).sessions.remove(contact);
390        if (persistent != null) {
391            persistent.removeAllRawSessionsOf(userDevice, contact);
392        }
393    }
394
395    @Override
396    public boolean containsRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
397        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
398
399        return (sessions != null && sessions.get(contactsDevice.getDeviceId()) != null) ||
400                (persistent != null && persistent.containsRawSession(userDevice, contactsDevice));
401    }
402
403    @Override
404    public OmemoCachedDeviceList loadCachedDeviceList(OmemoDevice userDevice, BareJid contact) throws IOException {
405        OmemoCachedDeviceList list = getCache(userDevice).deviceLists.get(contact);
406
407        if (list == null && persistent != null) {
408            list = persistent.loadCachedDeviceList(userDevice, contact);
409            if (list != null) {
410                getCache(userDevice).deviceLists.put(contact, list);
411            }
412        }
413
414        return list == null ? new OmemoCachedDeviceList() : new OmemoCachedDeviceList(list);
415    }
416
417    @Override
418    public void storeCachedDeviceList(OmemoDevice userDevice,
419                                      BareJid contact,
420                                      OmemoCachedDeviceList deviceList) throws IOException {
421        getCache(userDevice).deviceLists.put(contact, new OmemoCachedDeviceList(deviceList));
422
423        if (persistent != null) {
424            persistent.storeCachedDeviceList(userDevice, contact, deviceList);
425        }
426    }
427
428    @Override
429    public void purgeOwnDeviceKeys(OmemoDevice userDevice) {
430        caches.remove(userDevice);
431
432        if (persistent != null) {
433            persistent.purgeOwnDeviceKeys(userDevice);
434        }
435    }
436
437    @Override
438    public OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle>
439    keyUtil() {
440        if (persistent != null) {
441            return persistent.keyUtil();
442        } else {
443            return keyUtil;
444        }
445    }
446
447    /**
448     * Return the {@link KeyCache} object of an {@link OmemoManager}.
449     *
450     * @param device OMEMO device of which we want to have the cache.
451     * @return key cache of the device
452     */
453    private KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> getCache(OmemoDevice device) {
454        KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> cache = caches.get(device);
455        if (cache == null) {
456            cache = new KeyCache<>();
457            caches.put(device, cache);
458        }
459        return cache;
460    }
461
462    /**
463     * Cache that stores values for an {@link OmemoManager}.
464     *
465     * @param <T_IdKeyPair> type of the identity key pair
466     * @param <T_IdKey> type of the public identity key
467     * @param <T_PreKey> type of a public preKey
468     * @param <T_SigPreKey> type of the public signed preKey
469     * @param <T_Sess> type of the OMEMO session
470     */
471    private static final class KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> {
472        private T_IdKeyPair identityKeyPair;
473        private final TreeMap<Integer, T_PreKey> preKeys = new TreeMap<>();
474        private final TreeMap<Integer, T_SigPreKey> signedPreKeys = new TreeMap<>();
475        private final HashMap<BareJid, HashMap<Integer, T_Sess>> sessions = new HashMap<>();
476        private final HashMap<OmemoDevice, T_IdKey> identityKeys = new HashMap<>();
477        private final HashMap<OmemoDevice, Date> lastMessagesDates = new HashMap<>();
478        private final HashMap<OmemoDevice, Date> lastDeviceIdPublicationDates = new HashMap<>();
479        private final HashMap<BareJid, OmemoCachedDeviceList> deviceLists = new HashMap<>();
480        private Date lastRenewalDate = null;
481        private final HashMap<OmemoDevice, Integer> messageCounters = new HashMap<>();
482    }
483}