diff --git a/src/main/java/io/kamax/mxisd/Mxisd.java b/src/main/java/io/kamax/mxisd/Mxisd.java index 9ada30a..0e54fef 100644 --- a/src/main/java/io/kamax/mxisd/Mxisd.java +++ b/src/main/java/io/kamax/mxisd/Mxisd.java @@ -120,7 +120,7 @@ public class Mxisd { ServiceLoader.load(NotificationHandlerSupplier.class).iterator().forEachRemaining(p -> p.accept(this)); hashManager = new HashManager(); - hashManager.init(cfg.getHashing(), ThreePidProviders.get()); + hashManager.init(cfg.getHashing(), ThreePidProviders.get(), store); idStrategy = new RecursivePriorityLookupStrategy(cfg.getLookup(), ThreePidProviders.get(), bridgeFetcher, hashManager); pMgr = new ProfileManager(ProfileProviders.get(), clientDns, httpClient); diff --git a/src/main/java/io/kamax/mxisd/config/HashingConfig.java b/src/main/java/io/kamax/mxisd/config/HashingConfig.java index ae1882c..8ef749c 100644 --- a/src/main/java/io/kamax/mxisd/config/HashingConfig.java +++ b/src/main/java/io/kamax/mxisd/config/HashingConfig.java @@ -11,6 +11,7 @@ public class HashingConfig { private int pepperLength = 10; private RotationPolicyEnum rotationPolicy; private HashStorageEnum hashStorageType; + private long delay = 10; public void build() { if (isEnabled()) { @@ -18,17 +19,22 @@ public class HashingConfig { LOGGER.info(" Pepper length: {}", getPepperLength()); LOGGER.info(" Rotation policy: {}", getRotationPolicy()); LOGGER.info(" Hash storage type: {}", getHashStorageType()); + if (RotationPolicyEnum.PER_SECONDS == rotationPolicy) { + LOGGER.info(" Rotation delay: {}", delay); + } } else { LOGGER.info("Hash configuration disabled, used only `none` pepper."); } } public enum RotationPolicyEnum { - PER_REQUESTS + PER_REQUESTS, + PER_SECONDS } public enum HashStorageEnum { - IN_MEMORY + IN_MEMORY, + SQL } public boolean isEnabled() { @@ -62,4 +68,12 @@ public class HashingConfig { public void setHashStorageType(HashStorageEnum hashStorageType) { this.hashStorageType = hashStorageType; } + + public long getDelay() { + return delay; + } + + public void setDelay(long delay) { + this.delay = delay; + } } diff --git a/src/main/java/io/kamax/mxisd/hash/HashManager.java b/src/main/java/io/kamax/mxisd/hash/HashManager.java index df89904..7af2e00 100644 --- a/src/main/java/io/kamax/mxisd/hash/HashManager.java +++ b/src/main/java/io/kamax/mxisd/hash/HashManager.java @@ -4,14 +4,16 @@ import io.kamax.mxisd.config.HashingConfig; import io.kamax.mxisd.hash.rotation.HashRotationStrategy; import io.kamax.mxisd.hash.rotation.NoOpRotationStrategy; import io.kamax.mxisd.hash.rotation.RotationPerRequests; +import io.kamax.mxisd.hash.rotation.TimeBasedRotation; import io.kamax.mxisd.hash.storage.EmptyStorage; import io.kamax.mxisd.hash.storage.HashStorage; import io.kamax.mxisd.hash.storage.InMemoryHashStorage; +import io.kamax.mxisd.hash.storage.SqlHashStorage; import io.kamax.mxisd.lookup.provider.IThreePidProvider; +import io.kamax.mxisd.storage.IStorage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Arrays; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -23,10 +25,12 @@ public class HashManager { private HashRotationStrategy rotationStrategy; private HashStorage hashStorage; private HashingConfig config; + private IStorage storage; private AtomicBoolean configured = new AtomicBoolean(false); - public void init(HashingConfig config, List providers) { + public void init(HashingConfig config, List providers, IStorage storage) { this.config = config; + this.storage = storage; initStorage(); hashEngine = new HashEngine(providers, getHashStorage(), config); initRotationStrategy(); @@ -39,6 +43,9 @@ public class HashManager { case IN_MEMORY: this.hashStorage = new InMemoryHashStorage(); break; + case SQL: + this.hashStorage = new SqlHashStorage(storage); + break; default: throw new IllegalArgumentException("Unknown storage type: " + config.getHashStorageType()); } @@ -53,6 +60,9 @@ public class HashManager { case PER_REQUESTS: this.rotationStrategy = new RotationPerRequests(); break; + case PER_SECONDS: + this.rotationStrategy = new TimeBasedRotation(config.getDelay()); + break; default: throw new IllegalArgumentException("Unknown rotation type: " + config.getHashStorageType()); } diff --git a/src/main/java/io/kamax/mxisd/hash/rotation/TimeBasedRotation.java b/src/main/java/io/kamax/mxisd/hash/rotation/TimeBasedRotation.java new file mode 100644 index 0000000..92032cf --- /dev/null +++ b/src/main/java/io/kamax/mxisd/hash/rotation/TimeBasedRotation.java @@ -0,0 +1,34 @@ +package io.kamax.mxisd.hash.rotation; + +import io.kamax.mxisd.hash.HashEngine; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +public class TimeBasedRotation implements HashRotationStrategy { + + private final long delay; + private HashEngine hashEngine; + private final ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor(); + + public TimeBasedRotation(long delay) { + this.delay = delay; + } + + @Override + public void register(HashEngine hashEngine) { + this.hashEngine = hashEngine; + Runtime.getRuntime().addShutdownHook(new Thread(executorService::shutdown)); + executorService.scheduleWithFixedDelay(this::trigger, 0, delay, TimeUnit.SECONDS); + } + + @Override + public HashEngine getHashEngine() { + return hashEngine; + } + + @Override + public void newRequest() { + } +} diff --git a/src/main/java/io/kamax/mxisd/hash/storage/SqlHashStorage.java b/src/main/java/io/kamax/mxisd/hash/storage/SqlHashStorage.java new file mode 100644 index 0000000..43123d1 --- /dev/null +++ b/src/main/java/io/kamax/mxisd/hash/storage/SqlHashStorage.java @@ -0,0 +1,31 @@ +package io.kamax.mxisd.hash.storage; + +import io.kamax.mxisd.lookup.ThreePidMapping; +import io.kamax.mxisd.storage.IStorage; +import org.apache.commons.lang3.tuple.Pair; + +import java.util.Collection; + +public class SqlHashStorage implements HashStorage { + + private final IStorage storage; + + public SqlHashStorage(IStorage storage) { + this.storage = storage; + } + + @Override + public Collection> find(Iterable hashes) { + return storage.findHashes(hashes); + } + + @Override + public void add(ThreePidMapping pidMapping, String hash) { + storage.addHash(pidMapping.getMxid(), pidMapping.getMedium(), pidMapping.getValue(), hash); + } + + @Override + public void clear() { + storage.clearHashes(); + } +} diff --git a/src/main/java/io/kamax/mxisd/storage/IStorage.java b/src/main/java/io/kamax/mxisd/storage/IStorage.java index 36123d2..f0a1904 100644 --- a/src/main/java/io/kamax/mxisd/storage/IStorage.java +++ b/src/main/java/io/kamax/mxisd/storage/IStorage.java @@ -23,10 +23,12 @@ package io.kamax.mxisd.storage; import io.kamax.matrix.ThreePid; import io.kamax.mxisd.config.PolicyConfig; import io.kamax.mxisd.invitation.IThreePidInviteReply; +import io.kamax.mxisd.lookup.ThreePidMapping; import io.kamax.mxisd.storage.dao.IThreePidSessionDao; import io.kamax.mxisd.storage.ormlite.dao.ASTransactionDao; import io.kamax.mxisd.storage.ormlite.dao.AccountDao; import io.kamax.mxisd.storage.ormlite.dao.ThreePidInviteIO; +import org.apache.commons.lang3.tuple.Pair; import java.time.Instant; import java.util.Collection; @@ -66,4 +68,10 @@ public interface IStorage { void deleteAccepts(String token); boolean isTermAccepted(String token, List policies); + + void clearHashes(); + + void addHash(String mxid, String medium, String address, String hash); + + Collection> findHashes(Iterable hashes); } diff --git a/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java b/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java index da91f8a..ebb021c 100644 --- a/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java +++ b/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java @@ -24,6 +24,7 @@ import com.j256.ormlite.dao.CloseableWrappedIterable; import com.j256.ormlite.dao.Dao; import com.j256.ormlite.dao.DaoManager; import com.j256.ormlite.jdbc.JdbcConnectionSource; +import com.j256.ormlite.stmt.QueryBuilder; import com.j256.ormlite.support.ConnectionSource; import com.j256.ormlite.table.TableUtils; import io.kamax.matrix.ThreePid; @@ -33,15 +34,18 @@ import io.kamax.mxisd.exception.ConfigurationException; import io.kamax.mxisd.exception.InternalServerError; import io.kamax.mxisd.exception.InvalidCredentialsException; import io.kamax.mxisd.invitation.IThreePidInviteReply; +import io.kamax.mxisd.lookup.ThreePidMapping; import io.kamax.mxisd.storage.IStorage; import io.kamax.mxisd.storage.dao.IThreePidSessionDao; import io.kamax.mxisd.storage.ormlite.dao.ASTransactionDao; import io.kamax.mxisd.storage.ormlite.dao.AccountDao; +import io.kamax.mxisd.storage.ormlite.dao.HashDao; import io.kamax.mxisd.storage.ormlite.dao.HistoricalThreePidInviteIO; import io.kamax.mxisd.storage.ormlite.dao.AcceptedDao; import io.kamax.mxisd.storage.ormlite.dao.ThreePidInviteIO; import io.kamax.mxisd.storage.ormlite.dao.ThreePidSessionDao; import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import java.io.IOException; import java.sql.SQLException; @@ -51,6 +55,7 @@ import java.util.Collection; import java.util.List; import java.util.Optional; import java.util.UUID; +import java.util.stream.Collectors; public class OrmLiteSqlStorage implements IStorage { @@ -74,6 +79,7 @@ public class OrmLiteSqlStorage implements IStorage { private Dao asTxnDao; private Dao accountDao; private Dao acceptedDao; + private Dao hashDao; public OrmLiteSqlStorage(MxisdConfig cfg) { this(cfg.getStorage().getBackend(), cfg.getStorage().getProvider().getSqlite().getDatabase()); @@ -96,6 +102,7 @@ public class OrmLiteSqlStorage implements IStorage { asTxnDao = createDaoAndTable(connPool, ASTransactionDao.class); accountDao = createDaoAndTable(connPool, AccountDao.class); acceptedDao = createDaoAndTable(connPool, AcceptedDao.class); + hashDao = createDaoAndTable(connPool, HashDao.class); }); } @@ -319,4 +326,33 @@ public class OrmLiteSqlStorage implements IStorage { return false; }); } + + @Override + public void clearHashes() { + withCatcher(() -> { + List allHashes = hashDao.queryForAll(); + int deleted = hashDao.delete(allHashes); + if (deleted != allHashes.size()) { + throw new RuntimeException("Not all hashes deleted: " + deleted); + } + }); + } + + @Override + public void addHash(String mxid, String medium, String address, String hash) { + withCatcher(() -> { + hashDao.create(new HashDao(mxid, medium, address, hash)); + }); + } + + @Override + public Collection> findHashes(Iterable hashes) { + return withCatcher(() -> { + QueryBuilder builder = hashDao.queryBuilder(); + builder.where().in("hash", hashes); + return hashDao.query(builder.prepare()).stream() + .map(dao -> Pair.of(dao.getHash(), new ThreePidMapping(dao.getMedium(), dao.getAddress(), dao.getMxid()))).collect( + Collectors.toList()); + }); + } } diff --git a/src/main/java/io/kamax/mxisd/storage/ormlite/dao/HashDao.java b/src/main/java/io/kamax/mxisd/storage/ormlite/dao/HashDao.java new file mode 100644 index 0000000..3714178 --- /dev/null +++ b/src/main/java/io/kamax/mxisd/storage/ormlite/dao/HashDao.java @@ -0,0 +1,62 @@ +package io.kamax.mxisd.storage.ormlite.dao; + +import com.j256.ormlite.field.DatabaseField; +import com.j256.ormlite.table.DatabaseTable; + +@DatabaseTable(tableName = "hashes") +public class HashDao { + + @DatabaseField(canBeNull = false, id = true) + private String mxid; + + @DatabaseField(canBeNull = false) + private String medium; + + @DatabaseField(canBeNull = false) + private String address; + + @DatabaseField(canBeNull = false) + private String hash; + + public HashDao() { + } + + public HashDao(String mxid, String medium, String address, String hash) { + this.mxid = mxid; + this.medium = medium; + this.address = address; + this.hash = hash; + } + + public String getMxid() { + return mxid; + } + + public void setMxid(String mxid) { + this.mxid = mxid; + } + + public String getMedium() { + return medium; + } + + public void setMedium(String medium) { + this.medium = medium; + } + + public String getAddress() { + return address; + } + + public void setAddress(String address) { + this.address = address; + } + + public String getHash() { + return hash; + } + + public void setHash(String hash) { + this.hash = hash; + } +}