001/**
002 *
003 * Copyright 2017 Paul Schaub
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.jivesoftware.smackx.omemo;
018
import java.io.IOException;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
025
026import org.jivesoftware.smackx.omemo.exceptions.CorruptedOmemoKeyException;
027import org.jivesoftware.smackx.omemo.internal.OmemoCachedDeviceList;
028import org.jivesoftware.smackx.omemo.internal.OmemoDevice;
029import org.jivesoftware.smackx.omemo.util.OmemoKeyUtil;
030
031import org.jxmpp.jid.BareJid;
032
033/**
034 * This class implements the Proxy Pattern in order to wrap an OmemoStore with a caching layer.
035 * This reduces access to the underlying storage layer (eg. database, filesystem) by only accessing it for
036 * missing/updated values.
037 *
038 * Alternatively this implementation can be used as an ephemeral keystore without a persisting backend.
039 *
040 * @param <T_IdKeyPair> the type of the id key pair.
041 * @param <T_IdKey> the type of the id key.
042 * @param <T_PreKey> the prekey type
043 * @param <T_SigPreKey> the signed prekey type.
044 * @param <T_Sess> the session type.
045 * @param <T_Addr> the address type.
046 * @param <T_ECPub> the EC pub type.
047 * @param <T_Bundle> the bundle type.
048 * @param <T_Ciph> the cipher type.
049 */
050public class CachingOmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph>
051        extends OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> {
052
053    private final HashMap<OmemoDevice, KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess>> caches = new HashMap<>();
054    private final OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> persistent;
055    private final OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil;
056
057    public CachingOmemoStore(OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle> keyUtil) {
058        if (keyUtil == null) {
059            throw new IllegalArgumentException("KeyUtil MUST NOT be null!");
060        }
061        this.keyUtil = keyUtil;
062        persistent = null;
063    }
064
065    public CachingOmemoStore(OmemoStore<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_Addr, T_ECPub, T_Bundle, T_Ciph> wrappedStore) {
066        if (wrappedStore == null) {
067            throw new NullPointerException("Wrapped OmemoStore MUST NOT be null!");
068        }
069        this.keyUtil = null;
070        persistent = wrappedStore;
071    }
072
073    @Override
074    public SortedSet<Integer> localDeviceIdsOf(BareJid localUser) {
075        if (persistent != null) {
076            return persistent.localDeviceIdsOf(localUser);
077        } else {
078            SortedSet<Integer> deviceIds = new TreeSet<>();
079            for (OmemoDevice device : caches.keySet()) {
080                if (device.getJid().equals(localUser)) {
081                    deviceIds.add(device.getDeviceId());
082                }
083            }
084            return deviceIds;
085        }
086    }
087
088    @Override
089    public T_IdKeyPair loadOmemoIdentityKeyPair(OmemoDevice userDevice)
090            throws CorruptedOmemoKeyException, IOException {
091        T_IdKeyPair pair = getCache(userDevice).identityKeyPair;
092
093        if (pair == null && persistent != null) {
094            pair = persistent.loadOmemoIdentityKeyPair(userDevice);
095            if (pair != null) {
096                getCache(userDevice).identityKeyPair = pair;
097            }
098        }
099
100        return pair;
101    }
102
103    @Override
104    public void storeOmemoIdentityKeyPair(OmemoDevice userDevice, T_IdKeyPair identityKeyPair) throws IOException {
105        getCache(userDevice).identityKeyPair = identityKeyPair;
106        if (persistent != null) {
107            persistent.storeOmemoIdentityKeyPair(userDevice, identityKeyPair);
108        }
109    }
110
111    @Override
112    public void removeOmemoIdentityKeyPair(OmemoDevice userDevice) {
113        getCache(userDevice).identityKeyPair = null;
114        if (persistent != null) {
115            persistent.removeOmemoIdentityKeyPair(userDevice);
116        }
117    }
118
119    @Override
120    public T_IdKey loadOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice)
121            throws CorruptedOmemoKeyException, IOException {
122        T_IdKey idKey = getCache(userDevice).identityKeys.get(contactsDevice);
123
124        if (idKey == null && persistent != null) {
125            idKey = persistent.loadOmemoIdentityKey(userDevice, contactsDevice);
126            if (idKey != null) {
127                getCache(userDevice).identityKeys.put(contactsDevice, idKey);
128            }
129        }
130
131        return idKey;
132    }
133
134    @Override
135    public void storeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice device, T_IdKey t_idKey) throws IOException {
136        getCache(userDevice).identityKeys.put(device, t_idKey);
137        if (persistent != null) {
138            persistent.storeOmemoIdentityKey(userDevice, device, t_idKey);
139        }
140    }
141
142    @Override
143    public void removeOmemoIdentityKey(OmemoDevice userDevice, OmemoDevice contactsDevice) {
144        getCache(userDevice).identityKeys.remove(contactsDevice);
145        if (persistent != null) {
146            persistent.removeOmemoIdentityKey(userDevice, contactsDevice);
147        }
148    }
149
150    @Override
151    public void storeOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice, int counter) throws IOException {
152        getCache(userDevice).messageCounters.put(contactsDevice, counter);
153        if (persistent != null) {
154            persistent.storeOmemoMessageCounter(userDevice, contactsDevice, counter);
155        }
156    }
157
158    @Override
159    public int loadOmemoMessageCounter(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
160        Integer counter = getCache(userDevice).messageCounters.get(contactsDevice);
161        if (counter == null && persistent != null) {
162            counter = persistent.loadOmemoMessageCounter(userDevice, contactsDevice);
163        }
164
165        if (counter == null) {
166            counter = 0;
167        }
168
169        getCache(userDevice).messageCounters.put(contactsDevice, counter);
170
171        return counter;
172    }
173
174    @Override
175    public void setDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from, Date date) throws IOException {
176        getCache(userDevice).lastMessagesDates.put(from, date);
177        if (persistent != null) {
178            persistent.setDateOfLastReceivedMessage(userDevice, from, date);
179        }
180    }
181
182    @Override
183    public Date getDateOfLastReceivedMessage(OmemoDevice userDevice, OmemoDevice from) throws IOException {
184        Date last = getCache(userDevice).lastMessagesDates.get(from);
185
186        if (last == null && persistent != null) {
187            last = persistent.getDateOfLastReceivedMessage(userDevice, from);
188            if (last != null) {
189                getCache(userDevice).lastMessagesDates.put(from, last);
190            }
191        }
192
193        return last;
194    }
195
196    @Override
197    public void setDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice, Date date) throws IOException {
198        getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, date);
199        if (persistent != null) {
200            persistent.setDateOfLastReceivedMessage(userDevice, contactsDevice, date);
201        }
202    }
203
204    @Override
205    public Date getDateOfLastDeviceIdPublication(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
206        Date last = getCache(userDevice).lastDeviceIdPublicationDates.get(contactsDevice);
207
208        if (last == null && persistent != null) {
209            last = persistent.getDateOfLastDeviceIdPublication(userDevice, contactsDevice);
210            if (last != null) {
211                getCache(userDevice).lastDeviceIdPublicationDates.put(contactsDevice, last);
212            }
213        }
214
215        return last;
216    }
217
218    @Override
219    public void setDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice, Date date) throws IOException {
220        getCache(userDevice).lastRenewalDate = date;
221        if (persistent != null) {
222            persistent.setDateOfLastSignedPreKeyRenewal(userDevice, date);
223        }
224    }
225
226    @Override
227    public Date getDateOfLastSignedPreKeyRenewal(OmemoDevice userDevice) throws IOException {
228        Date lastRenewal = getCache(userDevice).lastRenewalDate;
229
230        if (lastRenewal == null && persistent != null) {
231            lastRenewal = persistent.getDateOfLastSignedPreKeyRenewal(userDevice);
232            if (lastRenewal != null) {
233                getCache(userDevice).lastRenewalDate = lastRenewal;
234            }
235        }
236
237        return lastRenewal;
238    }
239
240    @Override
241    public T_PreKey loadOmemoPreKey(OmemoDevice userDevice, int preKeyId) throws IOException {
242        T_PreKey preKey = getCache(userDevice).preKeys.get(preKeyId);
243
244        if (preKey == null && persistent != null) {
245            preKey = persistent.loadOmemoPreKey(userDevice, preKeyId);
246            if (preKey != null) {
247                getCache(userDevice).preKeys.put(preKeyId, preKey);
248            }
249        }
250
251        return preKey;
252    }
253
254    @Override
255    public void storeOmemoPreKey(OmemoDevice userDevice, int preKeyId, T_PreKey t_preKey) throws IOException {
256        getCache(userDevice).preKeys.put(preKeyId, t_preKey);
257        if (persistent != null) {
258            persistent.storeOmemoPreKey(userDevice, preKeyId, t_preKey);
259        }
260    }
261
262    @Override
263    public void removeOmemoPreKey(OmemoDevice userDevice, int preKeyId) {
264        getCache(userDevice).preKeys.remove(preKeyId);
265        if (persistent != null) {
266            persistent.removeOmemoPreKey(userDevice, preKeyId);
267        }
268    }
269
270    @Override
271    public TreeMap<Integer, T_PreKey> loadOmemoPreKeys(OmemoDevice userDevice) throws IOException {
272        TreeMap<Integer, T_PreKey> preKeys = getCache(userDevice).preKeys;
273
274        if (preKeys.isEmpty() && persistent != null) {
275            preKeys.putAll(persistent.loadOmemoPreKeys(userDevice));
276        }
277
278        return new TreeMap<>(preKeys);
279    }
280
281    @Override
282    public T_SigPreKey loadOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) throws IOException {
283        T_SigPreKey sigPreKey = getCache(userDevice).signedPreKeys.get(signedPreKeyId);
284
285        if (sigPreKey == null && persistent != null) {
286            sigPreKey = persistent.loadOmemoSignedPreKey(userDevice, signedPreKeyId);
287            if (sigPreKey != null) {
288                getCache(userDevice).signedPreKeys.put(signedPreKeyId, sigPreKey);
289            }
290        }
291
292        return sigPreKey;
293    }
294
295    @Override
296    public TreeMap<Integer, T_SigPreKey> loadOmemoSignedPreKeys(OmemoDevice userDevice) throws IOException {
297        TreeMap<Integer, T_SigPreKey> sigPreKeys = getCache(userDevice).signedPreKeys;
298
299        if (sigPreKeys.isEmpty() && persistent != null) {
300            sigPreKeys.putAll(persistent.loadOmemoSignedPreKeys(userDevice));
301        }
302
303        return new TreeMap<>(sigPreKeys);
304    }
305
306    @Override
307    public void storeOmemoSignedPreKey(OmemoDevice userDevice,
308                                       int signedPreKeyId,
309                                       T_SigPreKey signedPreKey) throws IOException {
310        getCache(userDevice).signedPreKeys.put(signedPreKeyId, signedPreKey);
311        if (persistent != null) {
312            persistent.storeOmemoSignedPreKey(userDevice, signedPreKeyId, signedPreKey);
313        }
314    }
315
316    @Override
317    public void removeOmemoSignedPreKey(OmemoDevice userDevice, int signedPreKeyId) {
318        getCache(userDevice).signedPreKeys.remove(signedPreKeyId);
319        if (persistent != null) {
320            persistent.removeOmemoSignedPreKey(userDevice, signedPreKeyId);
321        }
322    }
323
324    @Override
325    public T_Sess loadRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) throws IOException {
326        HashMap<Integer, T_Sess> contactSessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
327        if (contactSessions == null) {
328            contactSessions = new HashMap<>();
329            getCache(userDevice).sessions.put(contactsDevice.getJid(), contactSessions);
330        }
331
332        T_Sess session = contactSessions.get(contactsDevice.getDeviceId());
333        if (session == null && persistent != null) {
334            session = persistent.loadRawSession(userDevice, contactsDevice);
335            if (session != null) {
336                contactSessions.put(contactsDevice.getDeviceId(), session);
337            }
338        }
339
340        return session;
341    }
342
343    @Override
344    public HashMap<Integer, T_Sess> loadAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) throws IOException {
345        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contact);
346        if (sessions == null) {
347            sessions = new HashMap<>();
348            getCache(userDevice).sessions.put(contact, sessions);
349        }
350
351        if (sessions.isEmpty() && persistent != null) {
352            sessions.putAll(persistent.loadAllRawSessionsOf(userDevice, contact));
353        }
354
355        return new HashMap<>(sessions);
356    }
357
358    @Override
359    public void storeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevicece, T_Sess session) throws IOException {
360        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevicece.getJid());
361        if (sessions == null) {
362            sessions = new HashMap<>();
363            getCache(userDevice).sessions.put(contactsDevicece.getJid(), sessions);
364        }
365
366        sessions.put(contactsDevicece.getDeviceId(), session);
367        if (persistent != null) {
368            persistent.storeRawSession(userDevice, contactsDevicece, session);
369        }
370    }
371
372    @Override
373    public void removeRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
374        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
375        if (sessions != null) {
376            sessions.remove(contactsDevice.getDeviceId());
377        }
378
379        if (persistent != null) {
380            persistent.removeRawSession(userDevice, contactsDevice);
381        }
382    }
383
384    @Override
385    public void removeAllRawSessionsOf(OmemoDevice userDevice, BareJid contact) {
386        getCache(userDevice).sessions.remove(contact);
387        if (persistent != null) {
388            persistent.removeAllRawSessionsOf(userDevice, contact);
389        }
390    }
391
392    @Override
393    public boolean containsRawSession(OmemoDevice userDevice, OmemoDevice contactsDevice) {
394        HashMap<Integer, T_Sess> sessions = getCache(userDevice).sessions.get(contactsDevice.getJid());
395
396        return (sessions != null && sessions.get(contactsDevice.getDeviceId()) != null) ||
397                (persistent != null && persistent.containsRawSession(userDevice, contactsDevice));
398    }
399
400    @Override
401    public OmemoCachedDeviceList loadCachedDeviceList(OmemoDevice userDevice, BareJid contact) throws IOException {
402        OmemoCachedDeviceList list = getCache(userDevice).deviceLists.get(contact);
403
404        if (list == null && persistent != null) {
405            list = persistent.loadCachedDeviceList(userDevice, contact);
406            if (list != null) {
407                getCache(userDevice).deviceLists.put(contact, list);
408            }
409        }
410
411        return list == null ? new OmemoCachedDeviceList() : new OmemoCachedDeviceList(list);
412    }
413
414    @Override
415    public void storeCachedDeviceList(OmemoDevice userDevice,
416                                      BareJid contact,
417                                      OmemoCachedDeviceList deviceList) throws IOException {
418        getCache(userDevice).deviceLists.put(contact, new OmemoCachedDeviceList(deviceList));
419
420        if (persistent != null) {
421            persistent.storeCachedDeviceList(userDevice, contact, deviceList);
422        }
423    }
424
425    @Override
426    public void purgeOwnDeviceKeys(OmemoDevice userDevice) {
427        caches.remove(userDevice);
428
429        if (persistent != null) {
430            persistent.purgeOwnDeviceKeys(userDevice);
431        }
432    }
433
434    @Override
435    public OmemoKeyUtil<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess, T_ECPub, T_Bundle>
436    keyUtil() {
437        if (persistent != null) {
438            return persistent.keyUtil();
439        } else {
440            return keyUtil;
441        }
442    }
443
444    /**
445     * Return the {@link KeyCache} object of an {@link OmemoManager}.
446     *
447     * @param device OMEMO device of which we want to have the cache.
448     * @return key cache of the device
449     */
450    private KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> getCache(OmemoDevice device) {
451        KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> cache = caches.get(device);
452        if (cache == null) {
453            cache = new KeyCache<>();
454            caches.put(device, cache);
455        }
456        return cache;
457    }
458
459    /**
460     * Cache that stores values for an {@link OmemoManager}.
461     *
462     * @param <T_IdKeyPair> type of the identity key pair
463     * @param <T_IdKey> type of the public identity key
464     * @param <T_PreKey> type of a public preKey
465     * @param <T_SigPreKey> type of the public signed preKey
466     * @param <T_Sess> type of the OMEMO session
467     */
468    private static class KeyCache<T_IdKeyPair, T_IdKey, T_PreKey, T_SigPreKey, T_Sess> {
469        private T_IdKeyPair identityKeyPair;
470        private final TreeMap<Integer, T_PreKey> preKeys = new TreeMap<>();
471        private final TreeMap<Integer, T_SigPreKey> signedPreKeys = new TreeMap<>();
472        private final HashMap<BareJid, HashMap<Integer, T_Sess>> sessions = new HashMap<>();
473        private final HashMap<OmemoDevice, T_IdKey> identityKeys = new HashMap<>();
474        private final HashMap<OmemoDevice, Date> lastMessagesDates = new HashMap<>();
475        private final HashMap<OmemoDevice, Date> lastDeviceIdPublicationDates = new HashMap<>();
476        private final HashMap<BareJid, OmemoCachedDeviceList> deviceLists = new HashMap<>();
477        private Date lastRenewalDate = null;
478        private final HashMap<OmemoDevice, Integer> messageCounters = new HashMap<>();
479    }
480}