diff --git a/src/main/java/io/kamax/mxisd/Mxisd.java b/src/main/java/io/kamax/mxisd/Mxisd.java
index 00022c8..2410f70 100644
--- a/src/main/java/io/kamax/mxisd/Mxisd.java
+++ b/src/main/java/io/kamax/mxisd/Mxisd.java
@@ -115,7 +115,7 @@ public class Mxisd {
invMgr = new InvitationManager(cfg, store, idStrategy, keyMgr, signMgr, fedDns, notifMgr, pMgr);
authMgr = new AuthManager(cfg, AuthProviders.get(), idStrategy, invMgr, clientDns, httpClient);
dirMgr = new DirectoryManager(cfg.getDirectory(), clientDns, httpClient, DirectoryProviders.get());
- regMgr = new RegistrationManager(httpClient, clientDns, idStrategy, invMgr);
+ regMgr = new RegistrationManager(cfg.getRegister(), httpClient, clientDns, invMgr);
asHander = new AppSvcManager(cfg, store, pMgr, notifMgr, synapse);
}
diff --git a/src/main/java/io/kamax/mxisd/config/MxisdConfig.java b/src/main/java/io/kamax/mxisd/config/MxisdConfig.java
index 787171e..3fb5ffa 100644
--- a/src/main/java/io/kamax/mxisd/config/MxisdConfig.java
+++ b/src/main/java/io/kamax/mxisd/config/MxisdConfig.java
@@ -97,6 +97,7 @@ public class MxisdConfig {
private MemoryStoreConfig memory = new MemoryStoreConfig();
private NotificationConfig notification = new NotificationConfig();
private NetIqLdapConfig netiq = new NetIqLdapConfig();
+ private RegisterConfig register = new RegisterConfig();
private ServerConfig server = new ServerConfig();
private SessionConfig session = new SessionConfig();
private StorageConfig storage = new StorageConfig();
@@ -219,6 +220,14 @@ public class MxisdConfig {
this.netiq = netiq;
}
+ public RegisterConfig getRegister() {
+ return register;
+ }
+
+ public void setRegister(RegisterConfig register) {
+ this.register = register;
+ }
+
public ServerConfig getServer() {
return server;
}
@@ -310,6 +319,7 @@ public class MxisdConfig {
getMemory().build();
getNetiq().build();
getNotification().build();
+ getRegister().build();
getRest().build();
getSession().build();
getServer().build();
diff --git a/src/main/java/io/kamax/mxisd/config/RegisterConfig.java b/src/main/java/io/kamax/mxisd/config/RegisterConfig.java
new file mode 100644
index 0000000..515aab4
--- /dev/null
+++ b/src/main/java/io/kamax/mxisd/config/RegisterConfig.java
@@ -0,0 +1,201 @@
+/*
+ * mxisd - Matrix Identity Server Daemon
+ * Copyright (C) 2019 Kamax Sarl
+ *
+ * https://www.kamax.io/
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package io.kamax.mxisd.config;
+
+import io.kamax.matrix.ThreePidMedium;
+import io.kamax.matrix.json.GsonUtil;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+public class RegisterConfig {
+
+ private static final Logger log = LoggerFactory.getLogger(RegisterConfig.class);
+
+ public static class ThreepidPolicyPattern {
+
+ private List blacklist = new ArrayList<>();
+ private List whitelist = new ArrayList<>();
+
+ public List getBlacklist() {
+ return blacklist;
+ }
+
+ public void setBlacklist(List blacklist) {
+ this.blacklist = blacklist;
+ }
+
+ public List getWhitelist() {
+ return whitelist;
+ }
+
+ public void setWhitelist(List whitelist) {
+ this.whitelist = whitelist;
+ }
+
+ }
+
+ public static class EmailPolicy extends ThreepidPolicy {
+
+ private ThreepidPolicyPattern domain = new ThreepidPolicyPattern();
+
+ public ThreepidPolicyPattern getDomain() {
+ return domain;
+ }
+
+ public void setDomain(ThreepidPolicyPattern domain) {
+ this.domain = domain;
+ }
+
+ private List buildPatterns(List domains) {
+ log.info("Building email policy");
+ return domains.stream().map(d -> {
+ if (StringUtils.startsWith(d, "*")) {
+ log.info("Found domain and subdomain policy");
+ d = "(.*)" + d.substring(1);
+ } else if (StringUtils.startsWith(d, ".")) {
+ log.info("Found subdomain-only policy");
+ d = "(.*)" + d;
+ } else {
+ log.info("Found domain-only policy");
+ }
+
+ return "([^@]+)@" + d.replace(".", "\\.");
+ }).collect(Collectors.toList());
+ }
+
+ @Override
+ public void build() {
+ if (Objects.isNull(getDomain())) {
+ return;
+ }
+
+ if (Objects.nonNull(getDomain().getBlacklist())) {
+ if (Objects.isNull(getPattern().getBlacklist())) {
+ getPattern().setBlacklist(new ArrayList<>());
+ }
+
+ List domains = buildPatterns(getDomain().getBlacklist());
+ getPattern().getBlacklist().addAll(domains);
+ }
+
+ if (Objects.nonNull(getDomain().getWhitelist())) {
+ if (Objects.isNull(getPattern().getWhitelist())) {
+ getPattern().setWhitelist(new ArrayList<>());
+ }
+
+ List domains = buildPatterns(getDomain().getWhitelist());
+ getPattern().getWhitelist().addAll(domains);
+ }
+
+ setDomain(null);
+ }
+
+ }
+
+ public static class ThreepidPolicy {
+
+ private ThreepidPolicyPattern pattern = new ThreepidPolicyPattern();
+
+ public ThreepidPolicyPattern getPattern() {
+ return pattern;
+ }
+
+ public void setPattern(ThreepidPolicyPattern pattern) {
+ this.pattern = pattern;
+ }
+
+ public void build() {
+ // no-op
+ }
+
+ }
+
+ public static class Policy {
+
+ private boolean allowed;
+ private boolean invite = true;
+ private Map threepid = new HashMap<>();
+
+ public boolean isAllowed() {
+ return allowed;
+ }
+
+ public void setAllowed(boolean allowed) {
+ this.allowed = allowed;
+ }
+
+ public boolean forInvite() {
+ return invite;
+ }
+
+ public void setInvite(boolean invite) {
+ this.invite = invite;
+ }
+
+ public Map getThreepid() {
+ return threepid;
+ }
+
+ public void setThreepid(Map threepid) {
+ this.threepid = threepid;
+ }
+
+ }
+
+ private Policy policy = new Policy();
+
+ public Policy getPolicy() {
+ return policy;
+ }
+
+ public void setPolicy(Policy policy) {
+ this.policy = policy;
+ }
+
+ public void build() {
+ log.info("--- Registration config ---");
+
+ log.info("Before Build");
+ log.info(GsonUtil.getPrettyForLog(this));
+
+ new HashMap<>(getPolicy().getThreepid()).forEach((medium, policy) -> {
+ if (ThreePidMedium.Email.is(medium)) {
+ EmailPolicy pPolicy = GsonUtil.get().fromJson(GsonUtil.get().toJson(policy), EmailPolicy.class);
+ pPolicy.build();
+ policy = GsonUtil.makeObj(pPolicy);
+ } else {
+ ThreepidPolicy pPolicy = GsonUtil.get().fromJson(GsonUtil.get().toJson(policy), ThreepidPolicy.class);
+ pPolicy.build();
+ policy = GsonUtil.makeObj(pPolicy);
+ }
+
+ getPolicy().getThreepid().put(medium, policy);
+ });
+
+ log.info("After Build");
+ log.info(GsonUtil.getPrettyForLog(this));
+ }
+
+}
diff --git a/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java b/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java
index b64613a..f98f80d 100644
--- a/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java
+++ b/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java
@@ -23,11 +23,11 @@ package io.kamax.mxisd.registration;
import com.google.gson.JsonObject;
import io.kamax.matrix.ThreePid;
import io.kamax.matrix.json.GsonUtil;
+import io.kamax.mxisd.config.RegisterConfig;
import io.kamax.mxisd.dns.ClientDnsOverwrite;
import io.kamax.mxisd.exception.NotImplementedException;
import io.kamax.mxisd.exception.RemoteHomeServerException;
import io.kamax.mxisd.invitation.InvitationManager;
-import io.kamax.mxisd.lookup.strategy.LookupStrategy;
import io.kamax.mxisd.util.RestClientUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.client.methods.CloseableHttpResponse;
@@ -40,24 +40,23 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.URI;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.Objects;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
public class RegistrationManager {
private static final Logger log = LoggerFactory.getLogger(RegistrationManager.class);
+ private final RegisterConfig cfg;
private final CloseableHttpClient client;
private final ClientDnsOverwrite dns;
- private final LookupStrategy lookup;
private final InvitationManager invMgr;
- private Map sessions = new ConcurrentHashMap<>();
-
- public RegistrationManager(CloseableHttpClient client, ClientDnsOverwrite dns, LookupStrategy lookup, InvitationManager invMgr) {
+ public RegistrationManager(RegisterConfig cfg, CloseableHttpClient client, ClientDnsOverwrite dns, InvitationManager invMgr) {
+ this.cfg = cfg;
this.client = client;
this.dns = dns;
- this.lookup = lookup;
this.invMgr = invMgr;
}
@@ -96,7 +95,48 @@ public class RegistrationManager {
}
public boolean isAllowed(ThreePid tpid) {
- return invMgr.hasInvite(tpid);
+ // We check if the policy allows registration for invites, and if there is an invite for the 3PID
+ if (cfg.getPolicy().forInvite() && invMgr.hasInvite(tpid)) {
+ log.info("Registration allowed for pending invite");
+ return true;
+ }
+
+ // The following section deals with patterns which can either be built at startup time, or for each invite at runtime.
+ // Registration is a very rare occurrence relatively speaking, so we make the choice to build the patterns each time
+ // at runtime to save on RAM.
+
+ Object policy = cfg.getPolicy().getThreepid().get(tpid.getMedium());
+ if (Objects.nonNull(policy)) {
+ RegisterConfig.ThreepidPolicy tpidPolicy = GsonUtil.get().fromJson(GsonUtil.get().toJson(policy), RegisterConfig.ThreepidPolicy.class);
+ log.info("Found registration policy for {}", tpid.getMedium());
+
+ log.info("Processing pattern blacklist");
+ for (String pattern : tpidPolicy.getPattern().getBlacklist()) {
+ log.info("Processing pattern {}", pattern);
+
+ // We compile the pattern
+ Matcher m = Pattern.compile(pattern).matcher(tpid.getAddress());
+ if (m.matches()) { // We only care about those who match...
+ log.info("Found matching blacklist entry, denying registration");
+ return false; // ... and get denied as per blacklist
+ }
+ }
+
+ log.info("Processing pattern whitelist");
+ for (String pattern : tpidPolicy.getPattern().getWhitelist()) {
+ log.info("Processing pattern {}", pattern);
+
+ // We compile the pattern
+ Matcher m = Pattern.compile(pattern).matcher(tpid.getAddress());
+ if (m.matches()) { // We only care about those who match...
+ log.info("Found matching whitelist entry, allowing registration");
+ return true; // ... and get accepted as per whitelist
+ }
+ }
+ }
+
+ log.info("Returning default registration policy: {}", cfg.getPolicy().isAllowed());
+ return cfg.getPolicy().isAllowed();
}
}