diff --git a/src/main/java/io/kamax/mxisd/HttpMxisd.java b/src/main/java/io/kamax/mxisd/HttpMxisd.java index c7aec3f..0911c7c 100644 --- a/src/main/java/io/kamax/mxisd/HttpMxisd.java +++ b/src/main/java/io/kamax/mxisd/HttpMxisd.java @@ -22,8 +22,10 @@ package io.kamax.mxisd; import io.kamax.mxisd.config.MatrixConfig; import io.kamax.mxisd.config.MxisdConfig; +import io.kamax.mxisd.config.PolicyConfig; import io.kamax.mxisd.http.undertow.handler.ApiHandler; import io.kamax.mxisd.http.undertow.handler.AuthorizationHandler; +import io.kamax.mxisd.http.undertow.handler.CheckTermsHandler; import io.kamax.mxisd.http.undertow.handler.InternalInfoHandler; import io.kamax.mxisd.http.undertow.handler.OptionsHandler; import io.kamax.mxisd.http.undertow.handler.SaneHandler; @@ -65,7 +67,10 @@ import io.undertow.server.RoutingHandler; import io.undertow.util.HttpString; import io.undertow.util.Methods; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import java.util.regex.Pattern; public class HttpMxisd { @@ -187,13 +192,26 @@ public class HttpMxisd { } } - private void attachHandler(RoutingHandler routingHandler, HttpString method, ApiHandler apiHandler, boolean useAuthorization, HttpHandler httpHandler) { + private void attachHandler(RoutingHandler routingHandler, HttpString method, ApiHandler apiHandler, boolean useAuthorization, + HttpHandler httpHandler) { MatrixConfig matrixConfig = m.getConfig().getMatrix(); if (matrixConfig.isV1()) { routingHandler.add(method, apiHandler.getPath(IdentityServiceAPI.V1), httpHandler); } if (matrixConfig.isV2()) { - HttpHandler wrappedHandler = useAuthorization ? AuthorizationHandler.around(m.getAccMgr(), httpHandler) : httpHandler; + PolicyConfig policyConfig = m.getConfig().getPolicy(); + List policies = new ArrayList<>(); + if (!policyConfig.getPolicies().isEmpty()) { + for (PolicyConfig.PolicyObject policy : policyConfig.getPolicies().values()) { + for (Pattern pattern : policy.getPatterns()) { + if (pattern.matcher(apiHandler.getHandlerPath()).matches()) { + policies.add(policy); + } + } + } + } + HttpHandler handlerWithTerms = CheckTermsHandler.around(m.getAccMgr(), httpHandler, policies); + HttpHandler wrappedHandler = useAuthorization ? AuthorizationHandler.around(m.getAccMgr(), handlerWithTerms) : handlerWithTerms; routingHandler.add(method, apiHandler.getPath(IdentityServiceAPI.V2), wrappedHandler); } } diff --git a/src/main/java/io/kamax/mxisd/auth/AccountManager.java b/src/main/java/io/kamax/mxisd/auth/AccountManager.java index 0fc44a0..8715730 100644 --- a/src/main/java/io/kamax/mxisd/auth/AccountManager.java +++ b/src/main/java/io/kamax/mxisd/auth/AccountManager.java @@ -5,6 +5,7 @@ import io.kamax.matrix.MatrixID; import io.kamax.matrix.json.GsonUtil; import io.kamax.mxisd.config.AccountConfig; import io.kamax.mxisd.config.MatrixConfig; +import io.kamax.mxisd.config.PolicyConfig; import io.kamax.mxisd.exception.BadRequestException; import io.kamax.mxisd.exception.InvalidCredentialsException; import io.kamax.mxisd.exception.NotFoundException; @@ -22,6 +23,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.time.Instant; +import java.util.List; import java.util.Objects; import java.util.UUID; @@ -139,6 +141,14 @@ public class AccountManager { storage.deleteToken(token); } + public void acceptTerm(String token, String url) { + storage.acceptTerm(token, url); + } + + public boolean isTermAccepted(String token, List policies) { + return policies.isEmpty() || storage.isTermAccepted(token, policies); + } + public AccountConfig getAccountConfig() { return accountConfig; } diff --git a/src/main/java/io/kamax/mxisd/config/AcceptingPolicy.java b/src/main/java/io/kamax/mxisd/config/AcceptingPolicy.java new file mode 100644 index 0000000..381d983 --- /dev/null +++ b/src/main/java/io/kamax/mxisd/config/AcceptingPolicy.java @@ -0,0 +1,8 @@ +package io.kamax.mxisd.config; + +public enum AcceptingPolicy { + + ALL, + + ANY +} diff --git a/src/main/java/io/kamax/mxisd/config/PolicyConfig.java b/src/main/java/io/kamax/mxisd/config/PolicyConfig.java index 6c23e2f..a0bba22 100644 --- a/src/main/java/io/kamax/mxisd/config/PolicyConfig.java +++ b/src/main/java/io/kamax/mxisd/config/PolicyConfig.java @@ -3,22 +3,25 @@ package io.kamax.mxisd.config; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.regex.Pattern; public class PolicyConfig { private static final Logger LOGGER = LoggerFactory.getLogger(PolicyConfig.class); - public static class PolicyObject { + private Map policies = new HashMap<>(); + + private AcceptingPolicy acceptingPolicy = AcceptingPolicy.ANY; + + public static class TermObject { private String name; - private String version; - - private Map urls; - - private boolean required = true; + private String url; public String getName() { return name; @@ -28,6 +31,25 @@ public class PolicyConfig { this.name = name; } + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + } + + public static class PolicyObject { + + private String version; + + private Map terms; + + private List regexp = new ArrayList<>(); + + private transient List patterns = new ArrayList<>(); + public String getVersion() { return version; } @@ -36,25 +58,27 @@ public class PolicyConfig { this.version = version; } - public Map getUrls() { - return urls; + public Map getTerms() { + return terms; } - public void setUrls(Map urls) { - this.urls = urls; + public void setTerms(Map terms) { + this.terms = terms; } - public boolean isRequired() { - return required; + public List getRegexp() { + return regexp; } - public void setRequired(boolean required) { - this.required = required; + public void setRegexp(List regexp) { + this.regexp = regexp; + } + + public List getPatterns() { + return patterns; } } - private Map policies = new HashMap<>(); - public Map getPolicies() { return policies; } @@ -63,21 +87,33 @@ public class PolicyConfig { this.policies = policies; } + public AcceptingPolicy getAcceptingPolicy() { + return acceptingPolicy; + } + + public void setAcceptingPolicy(AcceptingPolicy acceptingPolicy) { + this.acceptingPolicy = acceptingPolicy; + } + public void build() { LOGGER.info("--- Policy Config ---"); if (getPolicies().isEmpty()) { LOGGER.info("Empty"); } else { - for (Map.Entry policyObjectEntry : getPolicies().entrySet()) { - PolicyObject policyObject = policyObjectEntry.getValue(); + for (Map.Entry policyObjectItem : getPolicies().entrySet()) { + PolicyObject policyObject = policyObjectItem.getValue(); StringBuilder sb = new StringBuilder(); - sb.append("Policy \"").append(policyObjectEntry.getKey()).append("\"\n"); + sb.append("Policy \"").append(policyObjectItem.getKey()).append("\"\n"); sb.append(" version: ").append(policyObject.getVersion()).append("\n"); - sb.append(" required: ").append(policyObject.isRequired()).append("\n"); - sb.append(" urls:\n"); - for (Map.Entry urlEntry : policyObject.getUrls().entrySet()) { - sb.append(" lang: ").append(urlEntry.getKey()).append("\n"); - sb.append(" url: ").append(urlEntry.getValue()); + for (String regexp : policyObjectItem.getValue().getRegexp()) { + sb.append(" - ").append(regexp).append("\n"); + policyObjectItem.getValue().getPatterns().add(Pattern.compile(regexp)); + } + sb.append(" terms:\n"); + for (Map.Entry termItem : policyObject.getTerms().entrySet()) { + sb.append(" - lang: ").append(termItem.getKey()).append("\n"); + sb.append(" name: ").append(termItem.getValue().getName()).append("\n"); + sb.append(" url: ").append(termItem.getValue().getUrl()).append("\n"); } LOGGER.info(sb.toString()); } diff --git a/src/main/java/io/kamax/mxisd/http/undertow/handler/CheckTermsHandler.java b/src/main/java/io/kamax/mxisd/http/undertow/handler/CheckTermsHandler.java new file mode 100644 index 0000000..c30351f --- /dev/null +++ b/src/main/java/io/kamax/mxisd/http/undertow/handler/CheckTermsHandler.java @@ -0,0 +1,70 @@ +/* + * mxisd - Matrix Identity Server Daemon + * Copyright (C) 2018 Kamax Sarl + * + * https://www.kamax.io/ + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package io.kamax.mxisd.http.undertow.handler; + +import io.kamax.mxisd.auth.AccountManager; +import io.kamax.mxisd.config.PolicyConfig; +import io.kamax.mxisd.exception.InvalidCredentialsException; +import io.kamax.mxisd.storage.ormlite.dao.AccountDao; +import io.undertow.server.HttpHandler; +import io.undertow.server.HttpServerExchange; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class CheckTermsHandler extends BasicHttpHandler { + + private static final Logger log = LoggerFactory.getLogger(CheckTermsHandler.class); + + private final AccountManager accountManager; + + private final HttpHandler child; + + private final List policies; + + public static CheckTermsHandler around(AccountManager accountManager, HttpHandler child, List policies) { + return new CheckTermsHandler(accountManager, child, policies); + } + + private CheckTermsHandler(AccountManager accountManager, HttpHandler child, + List policies) { + this.accountManager = accountManager; + this.child = child; + this.policies = policies; + } + + @Override + public void handleRequest(HttpServerExchange exchange) throws Exception { + String token = findAccessToken(exchange).orElse(null); + if (token == null) { + log.error("Unauthorized request from: {}", exchange.getHostAndPort()); + throw new InvalidCredentialsException(); + } + + if (!accountManager.isTermAccepted(token, policies)) { + log.error("Non accepting request from: {}", exchange.getHostAndPort()); + throw new InvalidCredentialsException(); + } + log.trace("Access granted"); + child.handleRequest(exchange); + } +} diff --git a/src/main/java/io/kamax/mxisd/storage/IStorage.java b/src/main/java/io/kamax/mxisd/storage/IStorage.java index eebf9f9..a92c52a 100644 --- a/src/main/java/io/kamax/mxisd/storage/IStorage.java +++ b/src/main/java/io/kamax/mxisd/storage/IStorage.java @@ -21,6 +21,7 @@ package io.kamax.mxisd.storage; import io.kamax.matrix.ThreePid; +import io.kamax.mxisd.config.PolicyConfig; import io.kamax.mxisd.invitation.IThreePidInviteReply; import io.kamax.mxisd.storage.dao.IThreePidSessionDao; import io.kamax.mxisd.storage.ormlite.dao.ASTransactionDao; @@ -29,6 +30,7 @@ import io.kamax.mxisd.storage.ormlite.dao.ThreePidInviteIO; import java.time.Instant; import java.util.Collection; +import java.util.List; import java.util.Optional; public interface IStorage { @@ -57,5 +59,9 @@ public interface IStorage { Optional findAccount(String token); - void deleteToken(String accessToken); + void deleteToken(String token); + + void acceptTerm(String token, String url); + + boolean isTermAccepted(String token, List policies); } diff --git a/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java b/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java index 421ef28..ff0b2ae 100644 --- a/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java +++ b/src/main/java/io/kamax/mxisd/storage/ormlite/OrmLiteSqlStorage.java @@ -28,14 +28,17 @@ import com.j256.ormlite.support.ConnectionSource; import com.j256.ormlite.table.TableUtils; import io.kamax.matrix.ThreePid; import io.kamax.mxisd.config.MxisdConfig; +import io.kamax.mxisd.config.PolicyConfig; import io.kamax.mxisd.exception.ConfigurationException; import io.kamax.mxisd.exception.InternalServerError; +import io.kamax.mxisd.exception.InvalidCredentialsException; import io.kamax.mxisd.invitation.IThreePidInviteReply; import io.kamax.mxisd.storage.IStorage; import io.kamax.mxisd.storage.dao.IThreePidSessionDao; import io.kamax.mxisd.storage.ormlite.dao.ASTransactionDao; import io.kamax.mxisd.storage.ormlite.dao.AccountDao; import io.kamax.mxisd.storage.ormlite.dao.HistoricalThreePidInviteIO; +import io.kamax.mxisd.storage.ormlite.dao.AcceptedDao; import io.kamax.mxisd.storage.ormlite.dao.ThreePidInviteIO; import io.kamax.mxisd.storage.ormlite.dao.ThreePidSessionDao; import org.apache.commons.lang.StringUtils; @@ -70,6 +73,7 @@ public class OrmLiteSqlStorage implements IStorage { private Dao sessionDao; private Dao asTxnDao; private Dao accountDao; + private Dao acceptedDao; public OrmLiteSqlStorage(MxisdConfig cfg) { this(cfg.getStorage().getBackend(), cfg.getStorage().getProvider().getSqlite().getDatabase()); @@ -91,6 +95,7 @@ public class OrmLiteSqlStorage implements IStorage { sessionDao = createDaoAndTable(connPool, ThreePidSessionDao.class); asTxnDao = createDaoAndTable(connPool, ASTransactionDao.class); accountDao = createDaoAndTable(connPool, AccountDao.class); + acceptedDao = createDaoAndTable(connPool, AcceptedDao.class); }); } @@ -277,4 +282,33 @@ public class OrmLiteSqlStorage implements IStorage { } }); } + + @Override + public void acceptTerm(String token, String url) { + withCatcher(() -> { + AccountDao account = findAccount(token).orElseThrow(InvalidCredentialsException::new); + int created = acceptedDao.create(new AcceptedDao(url, account.getUserId(), System.currentTimeMillis())); + if (created != 1) { + throw new RuntimeException("Unexpected row count after DB action: " + created); + } + }); + } + + @Override + public boolean isTermAccepted(String token, List policies) { + return withCatcher(() -> { + AccountDao account = findAccount(token).orElseThrow(InvalidCredentialsException::new); + List acceptedTerms = acceptedDao.queryForEq("userId", account.getUserId()); + for (AcceptedDao acceptedTerm : acceptedTerms) { + for (PolicyConfig.PolicyObject policy : policies) { + for (PolicyConfig.TermObject termObject : policy.getTerms().values()) { + if (termObject.getUrl().equalsIgnoreCase(acceptedTerm.getUrl())) { + return true; + } + } + } + } + return false; + }); + } } diff --git a/src/main/java/io/kamax/mxisd/storage/ormlite/dao/AcceptedDao.java b/src/main/java/io/kamax/mxisd/storage/ormlite/dao/AcceptedDao.java new file mode 100644 index 0000000..eb8c907 --- /dev/null +++ b/src/main/java/io/kamax/mxisd/storage/ormlite/dao/AcceptedDao.java @@ -0,0 +1,71 @@ +/* + * mxisd - Matrix Identity Server Daemon + * Copyright (C) 2018 Kamax Sarl + * + * https://www.kamax.io/ + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package io.kamax.mxisd.storage.ormlite.dao; + +import com.j256.ormlite.field.DatabaseField; +import com.j256.ormlite.table.DatabaseTable; + +@DatabaseTable(tableName = "accepted") +public class AcceptedDao { + + @DatabaseField(canBeNull = false, id = true) + private String url; + + @DatabaseField(canBeNull = false) + private String userId; + + @DatabaseField(canBeNull = false) + private long acceptedAt; + + public AcceptedDao() { + // Needed for ORMLite + } + + public AcceptedDao(String url, String userId, long acceptedAt) { + this.url = url; + this.userId = userId; + this.acceptedAt = acceptedAt; + } + + public String getUrl() { + return url; + } + + public void setUrl(String url) { + this.url = url; + } + + public String getUserId() { + return userId; + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public long getAcceptedAt() { + return acceptedAt; + } + + public void setAcceptedAt(long acceptedAt) { + this.acceptedAt = acceptedAt; + } +}