diff --git a/src/main/java/io/kamax/mxisd/Mxisd.java b/src/main/java/io/kamax/mxisd/Mxisd.java index 00022c8..2410f70 100644 --- a/src/main/java/io/kamax/mxisd/Mxisd.java +++ b/src/main/java/io/kamax/mxisd/Mxisd.java @@ -115,7 +115,7 @@ public class Mxisd { invMgr = new InvitationManager(cfg, store, idStrategy, keyMgr, signMgr, fedDns, notifMgr, pMgr); authMgr = new AuthManager(cfg, AuthProviders.get(), idStrategy, invMgr, clientDns, httpClient); dirMgr = new DirectoryManager(cfg.getDirectory(), clientDns, httpClient, DirectoryProviders.get()); - regMgr = new RegistrationManager(httpClient, clientDns, idStrategy, invMgr); + regMgr = new RegistrationManager(cfg.getRegister(), httpClient, clientDns, invMgr); asHander = new AppSvcManager(cfg, store, pMgr, notifMgr, synapse); } diff --git a/src/main/java/io/kamax/mxisd/config/MxisdConfig.java b/src/main/java/io/kamax/mxisd/config/MxisdConfig.java index 787171e..3fb5ffa 100644 --- a/src/main/java/io/kamax/mxisd/config/MxisdConfig.java +++ b/src/main/java/io/kamax/mxisd/config/MxisdConfig.java @@ -97,6 +97,7 @@ public class MxisdConfig { private MemoryStoreConfig memory = new MemoryStoreConfig(); private NotificationConfig notification = new NotificationConfig(); private NetIqLdapConfig netiq = new NetIqLdapConfig(); + private RegisterConfig register = new RegisterConfig(); private ServerConfig server = new ServerConfig(); private SessionConfig session = new SessionConfig(); private StorageConfig storage = new StorageConfig(); @@ -219,6 +220,14 @@ public class MxisdConfig { this.netiq = netiq; } + public RegisterConfig getRegister() { + return register; + } + + public void setRegister(RegisterConfig register) { + this.register = register; + } + public ServerConfig getServer() { return server; } @@ -310,6 +319,7 @@ public class MxisdConfig { getMemory().build(); getNetiq().build(); getNotification().build(); + getRegister().build(); getRest().build(); getSession().build(); getServer().build(); diff --git a/src/main/java/io/kamax/mxisd/config/RegisterConfig.java b/src/main/java/io/kamax/mxisd/config/RegisterConfig.java new file mode 100644 index 0000000..515aab4 --- /dev/null +++ b/src/main/java/io/kamax/mxisd/config/RegisterConfig.java @@ -0,0 +1,201 @@ +/* + * mxisd - Matrix Identity Server Daemon + * Copyright (C) 2019 Kamax Sarl + * + * https://www.kamax.io/ + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package io.kamax.mxisd.config; + +import io.kamax.matrix.ThreePidMedium; +import io.kamax.matrix.json.GsonUtil; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; +import java.util.stream.Collectors; + +public class RegisterConfig { + + private static final Logger log = LoggerFactory.getLogger(RegisterConfig.class); + + public static class ThreepidPolicyPattern { + + private List blacklist = new ArrayList<>(); + private List whitelist = new ArrayList<>(); + + public List getBlacklist() { + return blacklist; + } + + public void setBlacklist(List blacklist) { + this.blacklist = blacklist; + } + + public List getWhitelist() { + return whitelist; + } + + public void setWhitelist(List whitelist) { + this.whitelist = whitelist; + } + + } + + public static class EmailPolicy extends ThreepidPolicy { + + private ThreepidPolicyPattern domain = new ThreepidPolicyPattern(); + + public ThreepidPolicyPattern getDomain() { + return domain; + } + + public void setDomain(ThreepidPolicyPattern domain) { + this.domain = domain; + } + + private List buildPatterns(List domains) { + log.info("Building email policy"); + return domains.stream().map(d -> { + if (StringUtils.startsWith(d, "*")) { + log.info("Found domain and subdomain policy"); + d = "(.*)" + d.substring(1); + } else if (StringUtils.startsWith(d, ".")) { + log.info("Found subdomain-only policy"); + d = "(.*)" + d; + } else { + log.info("Found domain-only policy"); + } + + return "([^@]+)@" + d.replace(".", "\\."); + }).collect(Collectors.toList()); + } + + @Override + public void build() { + if (Objects.isNull(getDomain())) { + return; + } + + if (Objects.nonNull(getDomain().getBlacklist())) { + if (Objects.isNull(getPattern().getBlacklist())) { + getPattern().setBlacklist(new ArrayList<>()); + } + + List domains = buildPatterns(getDomain().getBlacklist()); + getPattern().getBlacklist().addAll(domains); + } + + if (Objects.nonNull(getDomain().getWhitelist())) { + if (Objects.isNull(getPattern().getWhitelist())) { + getPattern().setWhitelist(new ArrayList<>()); + } + + List domains = buildPatterns(getDomain().getWhitelist()); + getPattern().getWhitelist().addAll(domains); + } + + setDomain(null); + } + + } + + public static class ThreepidPolicy { + + private ThreepidPolicyPattern pattern = new ThreepidPolicyPattern(); + + public ThreepidPolicyPattern getPattern() { + return pattern; + } + + public void setPattern(ThreepidPolicyPattern pattern) { + this.pattern = pattern; + } + + public void build() { + // no-op + } + + } + + public static class Policy { + + private boolean allowed; + private boolean invite = true; + private Map threepid = new HashMap<>(); + + public boolean isAllowed() { + return allowed; + } + + public void setAllowed(boolean allowed) { + this.allowed = allowed; + } + + public boolean forInvite() { + return invite; + } + + public void setInvite(boolean invite) { + this.invite = invite; + } + + public Map getThreepid() { + return threepid; + } + + public void setThreepid(Map threepid) { + this.threepid = threepid; + } + + } + + private Policy policy = new Policy(); + + public Policy getPolicy() { + return policy; + } + + public void setPolicy(Policy policy) { + this.policy = policy; + } + + public void build() { + log.info("--- Registration config ---"); + + log.info("Before Build"); + log.info(GsonUtil.getPrettyForLog(this)); + + new HashMap<>(getPolicy().getThreepid()).forEach((medium, policy) -> { + if (ThreePidMedium.Email.is(medium)) { + EmailPolicy pPolicy = GsonUtil.get().fromJson(GsonUtil.get().toJson(policy), EmailPolicy.class); + pPolicy.build(); + policy = GsonUtil.makeObj(pPolicy); + } else { + ThreepidPolicy pPolicy = GsonUtil.get().fromJson(GsonUtil.get().toJson(policy), ThreepidPolicy.class); + pPolicy.build(); + policy = GsonUtil.makeObj(pPolicy); + } + + getPolicy().getThreepid().put(medium, policy); + }); + + log.info("After Build"); + log.info(GsonUtil.getPrettyForLog(this)); + } + +} diff --git a/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java b/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java index b64613a..f98f80d 100644 --- a/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java +++ b/src/main/java/io/kamax/mxisd/registration/RegistrationManager.java @@ -23,11 +23,11 @@ package io.kamax.mxisd.registration; import com.google.gson.JsonObject; import io.kamax.matrix.ThreePid; import io.kamax.matrix.json.GsonUtil; +import io.kamax.mxisd.config.RegisterConfig; import io.kamax.mxisd.dns.ClientDnsOverwrite; import io.kamax.mxisd.exception.NotImplementedException; import io.kamax.mxisd.exception.RemoteHomeServerException; import io.kamax.mxisd.invitation.InvitationManager; -import io.kamax.mxisd.lookup.strategy.LookupStrategy; import io.kamax.mxisd.util.RestClientUtils; import org.apache.commons.lang3.StringUtils; import org.apache.http.client.methods.CloseableHttpResponse; @@ -40,24 +40,23 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.URI; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; public class RegistrationManager { private static final Logger log = LoggerFactory.getLogger(RegistrationManager.class); + private final RegisterConfig cfg; private final CloseableHttpClient client; private final ClientDnsOverwrite dns; - private final LookupStrategy lookup; private final InvitationManager invMgr; - private Map sessions = new ConcurrentHashMap<>(); - - public RegistrationManager(CloseableHttpClient client, ClientDnsOverwrite dns, LookupStrategy lookup, InvitationManager invMgr) { + public RegistrationManager(RegisterConfig cfg, CloseableHttpClient client, ClientDnsOverwrite dns, InvitationManager invMgr) { + this.cfg = cfg; this.client = client; this.dns = dns; - this.lookup = lookup; this.invMgr = invMgr; } @@ -96,7 +95,48 @@ public class RegistrationManager { } public boolean isAllowed(ThreePid tpid) { - return invMgr.hasInvite(tpid); + // We check if the policy allows registration for invites, and if there is an invite for the 3PID + if (cfg.getPolicy().forInvite() && invMgr.hasInvite(tpid)) { + log.info("Registration allowed for pending invite"); + return true; + } + + // The following section deals with patterns which can either be built at startup time, or for each invite at runtime. + // Registration is a very rare occurrence relatively speaking, so we make the choice to build the patterns each time + // at runtime to save on RAM. + + Object policy = cfg.getPolicy().getThreepid().get(tpid.getMedium()); + if (Objects.nonNull(policy)) { + RegisterConfig.ThreepidPolicy tpidPolicy = GsonUtil.get().fromJson(GsonUtil.get().toJson(policy), RegisterConfig.ThreepidPolicy.class); + log.info("Found registration policy for {}", tpid.getMedium()); + + log.info("Processing pattern blacklist"); + for (String pattern : tpidPolicy.getPattern().getBlacklist()) { + log.info("Processing pattern {}", pattern); + + // We compile the pattern + Matcher m = Pattern.compile(pattern).matcher(tpid.getAddress()); + if (m.matches()) { // We only care about those who match... + log.info("Found matching blacklist entry, denying registration"); + return false; // ... and get denied as per blacklist + } + } + + log.info("Processing pattern whitelist"); + for (String pattern : tpidPolicy.getPattern().getWhitelist()) { + log.info("Processing pattern {}", pattern); + + // We compile the pattern + Matcher m = Pattern.compile(pattern).matcher(tpid.getAddress()); + if (m.matches()) { // We only care about those who match... + log.info("Found matching whitelist entry, allowing registration"); + return true; // ... and get accepted as per whitelist + } + } + } + + log.info("Returning default registration policy: {}", cfg.getPolicy().isAllowed()); + return cfg.getPolicy().isAllowed(); } }