Skeleton for invitation policies (#130)

This commit is contained in:
Max Dor
2019-02-14 23:02:55 +01:00
parent 2f7e5e4025
commit aadfae2965
13 changed files with 326 additions and 41 deletions

View File

@@ -31,6 +31,7 @@ import io.kamax.mxisd.http.undertow.handler.auth.v1.LoginHandler;
import io.kamax.mxisd.http.undertow.handler.auth.v1.LoginPostHandler; import io.kamax.mxisd.http.undertow.handler.auth.v1.LoginPostHandler;
import io.kamax.mxisd.http.undertow.handler.directory.v1.UserDirectorySearchHandler; import io.kamax.mxisd.http.undertow.handler.directory.v1.UserDirectorySearchHandler;
import io.kamax.mxisd.http.undertow.handler.identity.v1.*; import io.kamax.mxisd.http.undertow.handler.identity.v1.*;
import io.kamax.mxisd.http.undertow.handler.invite.v1.RoomInviteHandler;
import io.kamax.mxisd.http.undertow.handler.profile.v1.InternalProfileHandler; import io.kamax.mxisd.http.undertow.handler.profile.v1.InternalProfileHandler;
import io.kamax.mxisd.http.undertow.handler.profile.v1.ProfileHandler; import io.kamax.mxisd.http.undertow.handler.profile.v1.ProfileHandler;
import io.kamax.mxisd.http.undertow.handler.register.v1.Register3pidRequestTokenHandler; import io.kamax.mxisd.http.undertow.handler.register.v1.Register3pidRequestTokenHandler;
@@ -101,6 +102,9 @@ public class HttpMxisd {
// Registration endpoints // Registration endpoints
.post(Register3pidRequestTokenHandler.Path, SaneHandler.around(new Register3pidRequestTokenHandler(m.getReg(), m.getClientDns(), m.getHttpClient()))) .post(Register3pidRequestTokenHandler.Path, SaneHandler.around(new Register3pidRequestTokenHandler(m.getReg(), m.getClientDns(), m.getHttpClient())))
// Invite endpoints
.post(RoomInviteHandler.Path, SaneHandler.around(new RoomInviteHandler(m.getHttpClient(), m.getClientDns(), m.getInvitationManager())))
// Application Service endpoints // Application Service endpoints
.get("/_matrix/app/v1/users/**", asNotFoundHandler) .get("/_matrix/app/v1/users/**", asNotFoundHandler)
.get("/users/**", asNotFoundHandler) // Legacy endpoint .get("/users/**", asNotFoundHandler) // Legacy endpoint

View File

@@ -109,7 +109,7 @@ public class Mxisd {
pMgr = new ProfileManager(ProfileProviders.get(), clientDns, httpClient); pMgr = new ProfileManager(ProfileProviders.get(), clientDns, httpClient);
notifMgr = new NotificationManager(cfg.getNotification(), NotificationHandlers.get()); notifMgr = new NotificationManager(cfg.getNotification(), NotificationHandlers.get());
sessMgr = new SessionManager(cfg.getSession(), cfg.getMatrix(), store, notifMgr, idStrategy, httpClient); sessMgr = new SessionManager(cfg.getSession(), cfg.getMatrix(), store, notifMgr, idStrategy, httpClient);
invMgr = new InvitationManager(cfg, store, idStrategy, keyMgr, signMgr, fedDns, notifMgr); invMgr = new InvitationManager(cfg, store, idStrategy, keyMgr, signMgr, fedDns, notifMgr, pMgr);
authMgr = new AuthManager(cfg, AuthProviders.get(), idStrategy, invMgr, clientDns, httpClient); authMgr = new AuthManager(cfg, AuthProviders.get(), idStrategy, invMgr, clientDns, httpClient);
dirMgr = new DirectoryManager(cfg.getDirectory(), clientDns, httpClient, DirectoryProviders.get()); dirMgr = new DirectoryManager(cfg.getDirectory(), clientDns, httpClient, DirectoryProviders.get());
regMgr = new RegistrationManager(httpClient, clientDns, idStrategy, invMgr); regMgr = new RegistrationManager(httpClient, clientDns, idStrategy, invMgr);

View File

@@ -23,7 +23,9 @@ package io.kamax.mxisd.backend.sql;
import io.kamax.matrix.ThreePid; import io.kamax.matrix.ThreePid;
import io.kamax.matrix._MatrixID; import io.kamax.matrix._MatrixID;
import io.kamax.matrix._ThreePid; import io.kamax.matrix._ThreePid;
import io.kamax.mxisd.UserIdType;
import io.kamax.mxisd.config.sql.SqlConfig; import io.kamax.mxisd.config.sql.SqlConfig;
import io.kamax.mxisd.exception.InternalServerError;
import io.kamax.mxisd.profile.ProfileProvider; import io.kamax.mxisd.profile.ProfileProvider;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -33,16 +35,14 @@ import java.sql.PreparedStatement;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
public abstract class SqlProfileProvider implements ProfileProvider { public abstract class SqlProfileProvider implements ProfileProvider {
private transient final Logger log = LoggerFactory.getLogger(SqlProfileProvider.class); private static final Logger log = LoggerFactory.getLogger(SqlProfileProvider.class);
private SqlConfig.Profile cfg; private SqlConfig.Profile cfg;
private SqlConnectionPool pool; private SqlConnectionPool pool;
public SqlProfileProvider(SqlConfig cfg) { public SqlProfileProvider(SqlConfig cfg) {
@@ -50,6 +50,12 @@ public abstract class SqlProfileProvider implements ProfileProvider {
this.pool = new SqlConnectionPool(cfg); this.pool = new SqlConnectionPool(cfg);
} }
private void setParameters(PreparedStatement stmt, String value) throws SQLException {
for (int i = 1; i <= stmt.getParameterMetaData().getParameterCount(); i++) {
stmt.setString(i, value);
}
}
@Override @Override
public Optional<String> getDisplayName(_MatrixID user) { public Optional<String> getDisplayName(_MatrixID user) {
String stmtSql = cfg.getDisplayName().getQuery(); String stmtSql = cfg.getDisplayName().getQuery();
@@ -94,7 +100,33 @@ public abstract class SqlProfileProvider implements ProfileProvider {
@Override @Override
public List<String> getRoles(_MatrixID user) { public List<String> getRoles(_MatrixID user) {
return Collections.emptyList(); log.info("Querying roles for {}", user.getId());
List<String> roles = new ArrayList<>();
String stmtSql = cfg.getRole().getQuery();
try (Connection conn = pool.get()) {
PreparedStatement stmt = conn.prepareStatement(stmtSql);
if (UserIdType.Localpart.is(cfg.getRole().getType())) {
setParameters(stmt, user.getLocalPart());
} else if (UserIdType.MatrixID.is(cfg.getRole().getType())) {
setParameters(stmt, user.getId());
} else {
throw new InternalServerError("Unsupported user type in SQL Role fetching: " + cfg.getRole().getType());
}
ResultSet rSet = stmt.executeQuery();
while (rSet.next()) {
String role = rSet.getString(1);
roles.add(role);
log.debug("Found role {}", role);
}
log.info("Got {} roles", roles.size());
return roles;
} catch (SQLException e) {
throw new RuntimeException(e);
}
} }
} }

View File

@@ -43,6 +43,10 @@ public class SynapseQueries {
return "SELECT medium, address FROM user_threepids WHERE user_id = ?"; return "SELECT medium, address FROM user_threepids WHERE user_id = ?";
} }
public static String getRoles() {
return "SELECT DISTINCT(group_id) FROM group_users WHERE user_id = ?";
}
public static String findByDisplayName(String type, String domain) { public static String findByDisplayName(String type, String domain) {
if (StringUtils.equals("sqlite", type)) { if (StringUtils.equals("sqlite", type)) {
return "select " + getUserId(type, domain) + ", displayname from profiles p where displayname like ?"; return "select " + getUserId(type, domain) + ", displayname from profiles p where displayname like ?";

View File

@@ -26,9 +26,13 @@ import io.kamax.mxisd.config.MxisdConfig;
import io.kamax.mxisd.directory.DirectoryProviders; import io.kamax.mxisd.directory.DirectoryProviders;
import io.kamax.mxisd.lookup.ThreePidProviders; import io.kamax.mxisd.lookup.ThreePidProviders;
import io.kamax.mxisd.profile.ProfileProviders; import io.kamax.mxisd.profile.ProfileProviders;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class SynapseSqlStoreSupplier implements IdentityStoreSupplier { public class SynapseSqlStoreSupplier implements IdentityStoreSupplier {
private static final Logger log = LoggerFactory.getLogger(SynapseSqlStoreSupplier.class);
@Override @Override
public void accept(Mxisd mxisd) { public void accept(Mxisd mxisd) {
accept(mxisd.getConfig()); accept(mxisd.getConfig());
@@ -44,6 +48,7 @@ public class SynapseSqlStoreSupplier implements IdentityStoreSupplier {
} }
if (cfg.getSynapseSql().getProfile().isEnabled()) { if (cfg.getSynapseSql().getProfile().isEnabled()) {
log.debug("Profile is enabled, registering provider");
ProfileProviders.register(() -> new SynapseSqlProfileProvider(cfg.getSynapseSql())); ProfileProviders.register(() -> new SynapseSqlProfileProvider(cfg.getSynapseSql()));
} }
} }

View File

@@ -20,13 +20,16 @@
package io.kamax.mxisd.config; package io.kamax.mxisd.config;
import io.kamax.mxisd.util.GsonUtil; import io.kamax.matrix.json.GsonUtil;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
public class InvitationConfig { public class InvitationConfig {
private transient final Logger log = LoggerFactory.getLogger(InvitationConfig.class); private static final Logger log = LoggerFactory.getLogger(InvitationConfig.class);
public static class Resolution { public static class Resolution {
@@ -51,7 +54,34 @@ public class InvitationConfig {
} }
public static class SenderPolicy {
private List<String> hasRole = new ArrayList<>();
public List<String> getHasRole() {
return hasRole;
}
public void setHasRole(List<String> hasRole) {
this.hasRole = hasRole;
}
}
public static class Policies {
private SenderPolicy ifSender = new SenderPolicy();
public SenderPolicy getIfSender() {
return ifSender;
}
public void setIfSender(SenderPolicy ifSender) {
this.ifSender = ifSender;
}
}
private Resolution resolution = new Resolution(); private Resolution resolution = new Resolution();
private Policies policy = new Policies();
public Resolution getResolution() { public Resolution getResolution() {
return resolution; return resolution;
@@ -61,9 +91,18 @@ public class InvitationConfig {
this.resolution = resolution; this.resolution = resolution;
} }
public Policies getPolicy() {
return policy;
}
public void setPolicy(Policies policy) {
this.policy = policy;
}
public void build() { public void build() {
log.info("--- Invite config ---"); log.info("--- Invite config ---");
log.info("Resolution: {}", GsonUtil.build().toJson(resolution)); log.info("Resolution: {}", GsonUtil.get().toJson(getResolution()));
log.info("Policies: {}", GsonUtil.get().toJson(getPolicy()));
} }
} }

View File

@@ -193,11 +193,35 @@ public abstract class SqlConfig {
} }
public static class ProfileRoles {
private String type;
private String query;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public String getQuery() {
return query;
}
public void setQuery(String query) {
this.query = query;
}
}
public static class Profile { public static class Profile {
private Boolean enabled; private Boolean enabled;
private ProfileDisplayName displayName = new ProfileDisplayName(); private ProfileDisplayName displayName = new ProfileDisplayName();
private ProfileThreepids threepid = new ProfileThreepids(); private ProfileThreepids threepid = new ProfileThreepids();
private ProfileRoles role = new ProfileRoles();
public Boolean isEnabled() { public Boolean isEnabled() {
return enabled; return enabled;
@@ -223,6 +247,14 @@ public abstract class SqlConfig {
this.threepid = threepid; this.threepid = threepid;
} }
public ProfileRoles getRole() {
return role;
}
public void setRole(ProfileRoles role) {
this.role = role;
}
} }
private boolean enabled; private boolean enabled;
@@ -323,10 +355,11 @@ public abstract class SqlConfig {
log.info("3PID mapping query: {}", getIdentity().getQuery()); log.info("3PID mapping query: {}", getIdentity().getQuery());
log.info("Identity medium queries: {}", GsonUtil.build().toJson(getIdentity().getMedium())); log.info("Identity medium queries: {}", GsonUtil.build().toJson(getIdentity().getMedium()));
log.info("Profile:"); log.info("Profile:");
log.info("\tEnabled: {}", getProfile().isEnabled()); log.info(" Enabled: {}", getProfile().isEnabled());
if (getProfile().isEnabled()) { if (getProfile().isEnabled()) {
log.info("\tDisplay name query: {}", getProfile().getDisplayName().getQuery()); log.info(" Display name query: {}", getProfile().getDisplayName().getQuery());
log.info("\tProfile 3PID query: {}", getProfile().getThreepid().getQuery()); log.info(" Profile 3PID query: {}", getProfile().getThreepid().getQuery());
log.info(" Role query: {}", getProfile().getRole().getQuery());
} }
} }
} }

View File

@@ -20,6 +20,7 @@
package io.kamax.mxisd.config.sql.synapse; package io.kamax.mxisd.config.sql.synapse;
import io.kamax.mxisd.UserIdType;
import io.kamax.mxisd.backend.sql.synapse.SynapseQueries; import io.kamax.mxisd.backend.sql.synapse.SynapseQueries;
import io.kamax.mxisd.config.sql.SqlConfig; import io.kamax.mxisd.config.sql.SqlConfig;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.StringUtils;
@@ -48,9 +49,17 @@ public class SynapseSqlProviderConfig extends SqlConfig {
if (StringUtils.isBlank(getProfile().getDisplayName().getQuery())) { if (StringUtils.isBlank(getProfile().getDisplayName().getQuery())) {
getProfile().getDisplayName().setQuery(SynapseQueries.getDisplayName()); getProfile().getDisplayName().setQuery(SynapseQueries.getDisplayName());
} }
if (StringUtils.isBlank(getProfile().getThreepid().getQuery())) { if (StringUtils.isBlank(getProfile().getThreepid().getQuery())) {
getProfile().getThreepid().setQuery(SynapseQueries.getThreepids()); getProfile().getThreepid().setQuery(SynapseQueries.getThreepids());
} }
if (StringUtils.isBlank(getProfile().getRole().getType())) {
getProfile().getRole().setType(UserIdType.MatrixID.getId());
}
if (StringUtils.isBlank(getProfile().getRole().getQuery())) {
getProfile().getRole().setQuery(SynapseQueries.getRoles());
}
} }
printConfig(); printConfig();

View File

@@ -23,20 +23,29 @@ package io.kamax.mxisd.http.undertow.handler;
import com.google.gson.JsonElement; import com.google.gson.JsonElement;
import com.google.gson.JsonObject; import com.google.gson.JsonObject;
import io.kamax.matrix.json.GsonUtil; import io.kamax.matrix.json.GsonUtil;
import io.kamax.mxisd.dns.ClientDnsOverwrite;
import io.kamax.mxisd.exception.AccessTokenNotFoundException;
import io.kamax.mxisd.exception.HttpMatrixException; import io.kamax.mxisd.exception.HttpMatrixException;
import io.kamax.mxisd.exception.InternalServerError; import io.kamax.mxisd.exception.InternalServerError;
import io.kamax.mxisd.proxy.Response; import io.kamax.mxisd.proxy.Response;
import io.kamax.mxisd.util.RestClientUtils;
import io.undertow.server.HttpHandler; import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange; import io.undertow.server.HttpServerExchange;
import io.undertow.util.HttpString; import io.undertow.util.HttpString;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.http.Header;
import org.apache.http.HeaderElement;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URLDecoder; import java.net.URLDecoder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Deque; import java.util.Deque;
@@ -46,7 +55,19 @@ import java.util.Optional;
public abstract class BasicHttpHandler implements HttpHandler { public abstract class BasicHttpHandler implements HttpHandler {
private transient final Logger log = LoggerFactory.getLogger(BasicHttpHandler.class); private static final Logger log = LoggerFactory.getLogger(BasicHttpHandler.class);
protected String getAccessToken(HttpServerExchange exchange) {
return Optional.ofNullable(exchange.getRequestHeaders().getFirst("Authorization"))
.flatMap(v -> {
if (!v.startsWith("Bearer ")) {
return Optional.empty();
}
return Optional.of(v.substring("Bearer ".length()));
}).filter(StringUtils::isNotEmpty)
.orElseThrow(AccessTokenNotFoundException::new);
}
protected String getRemoteHostAddress(HttpServerExchange exchange) { protected String getRemoteHostAddress(HttpServerExchange exchange) {
return ((InetSocketAddress) exchange.getConnection().getPeerAddress()).getAddress().getHostAddress(); return ((InetSocketAddress) exchange.getConnection().getPeerAddress()).getAddress().getHostAddress();
@@ -149,4 +170,34 @@ public abstract class BasicHttpHandler implements HttpHandler {
upstream.getHeaders().forEach((key, value) -> exchange.getResponseHeaders().addAll(HttpString.tryFromString(key), value)); upstream.getHeaders().forEach((key, value) -> exchange.getResponseHeaders().addAll(HttpString.tryFromString(key), value));
writeBodyAsUtf8(exchange, upstream.getBody()); writeBodyAsUtf8(exchange, upstream.getBody());
} }
protected void proxyPost(HttpServerExchange exchange, JsonObject body, CloseableHttpClient client, ClientDnsOverwrite dns) {
String target = dns.transform(URI.create(exchange.getRequestURL())).toString();
log.info("Requesting remote: {}", target);
HttpPost req = RestClientUtils.post(target, GsonUtil.get(), body);
exchange.getRequestHeaders().forEach(header -> {
header.forEach(v -> {
String name = header.getHeaderName().toString();
if (!StringUtils.startsWithIgnoreCase(name, "content-")) {
req.addHeader(name, v);
}
});
});
try (CloseableHttpResponse res = client.execute(req)) {
exchange.setStatusCode(res.getStatusLine().getStatusCode());
for (Header h : res.getAllHeaders()) {
for (HeaderElement el : h.getElements()) {
exchange.getResponseHeaders().add(HttpString.tryFromString(h.getName()), el.getValue());
}
}
res.getEntity().writeTo(exchange.getOutputStream());
exchange.endExchange();
} catch (IOException e) {
log.warn("Unable to make proxy call: {}", e.getMessage(), e);
throw new InternalServerError(e);
}
}
} }

View File

@@ -0,0 +1,103 @@
/*
* mxisd - Matrix Identity Server Daemon
* Copyright (C) 2019 Kamax Sarl
*
* https://www.kamax.io/
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package io.kamax.mxisd.http.undertow.handler.invite.v1;
import com.google.gson.JsonObject;
import io.kamax.matrix.MatrixID;
import io.kamax.matrix._MatrixID;
import io.kamax.matrix.json.GsonUtil;
import io.kamax.mxisd.dns.ClientDnsOverwrite;
import io.kamax.mxisd.exception.InternalServerError;
import io.kamax.mxisd.exception.NotAllowedException;
import io.kamax.mxisd.exception.RemoteHomeServerException;
import io.kamax.mxisd.http.undertow.handler.BasicHttpHandler;
import io.kamax.mxisd.invitation.InvitationManager;
import io.undertow.server.HttpServerExchange;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.URI;
import java.util.Optional;
public class RoomInviteHandler extends BasicHttpHandler {
public static final String Path = "/_matrix/client/r0/rooms/{roomId}/invite";
private static final Logger log = LoggerFactory.getLogger(RoomInviteHandler.class);
private final CloseableHttpClient client;
private final ClientDnsOverwrite dns;
private final InvitationManager invMgr;
public RoomInviteHandler(CloseableHttpClient client, ClientDnsOverwrite dns, InvitationManager invMgr) {
this.client = client;
this.dns = dns;
this.invMgr = invMgr;
}
@Override
public void handleRequest(HttpServerExchange exchange) {
String accessToken = getAccessToken(exchange);
String whoamiUri = dns.transform(URI.create(exchange.getRequestURL()).resolve(URI.create("/_matrix/client/r0/account/whoami"))).toString();
log.info("Who Am I URL: {}", whoamiUri);
HttpGet whoAmIReq = new HttpGet(whoamiUri);
whoAmIReq.addHeader("Authorization", "Bearer " + accessToken);
_MatrixID uId;
try (CloseableHttpResponse whoAmIRes = client.execute(whoAmIReq)) {
int sc = whoAmIRes.getStatusLine().getStatusCode();
String body = EntityUtils.toString(whoAmIRes.getEntity());
if (sc != 200) {
log.warn("Unable to get caller identity from Homeserver - Status code: {}", sc);
log.debug("Body: {}", body);
throw new RemoteHomeServerException(body);
}
JsonObject json = GsonUtil.parseObj(body);
Optional<String> uIdRaw = GsonUtil.findString(json, "user_id");
if (!uIdRaw.isPresent()) {
throw new RemoteHomeServerException("No User ID provided when checking identity");
}
uId = MatrixID.asAcceptable(uIdRaw.get());
} catch (IOException e) {
InternalServerError ex = new InternalServerError(e);
log.error("Ref {}: Unable to fetch caller identity from Homeserver", ex.getReference());
throw ex;
}
log.info("Processing room invite from {}", uId.getId());
JsonObject reqBody = parseJsonObject(exchange);
if (!invMgr.canInvite(uId, reqBody)) {
throw new NotAllowedException("Your account is not allowed to invite that address");
}
log.info("Invite was allowing, relaying to the Homeserver");
proxyPost(exchange, reqBody, client, dns);
}
}

View File

@@ -25,26 +25,16 @@ import io.kamax.matrix.ThreePid;
import io.kamax.matrix.ThreePidMedium; import io.kamax.matrix.ThreePidMedium;
import io.kamax.matrix.json.GsonUtil; import io.kamax.matrix.json.GsonUtil;
import io.kamax.mxisd.dns.ClientDnsOverwrite; import io.kamax.mxisd.dns.ClientDnsOverwrite;
import io.kamax.mxisd.exception.InternalServerError;
import io.kamax.mxisd.exception.NotAllowedException; import io.kamax.mxisd.exception.NotAllowedException;
import io.kamax.mxisd.http.io.identity.SessionEmailTokenRequestJson; import io.kamax.mxisd.http.io.identity.SessionEmailTokenRequestJson;
import io.kamax.mxisd.http.io.identity.SessionPhoneTokenRequestJson; import io.kamax.mxisd.http.io.identity.SessionPhoneTokenRequestJson;
import io.kamax.mxisd.http.undertow.handler.BasicHttpHandler; import io.kamax.mxisd.http.undertow.handler.BasicHttpHandler;
import io.kamax.mxisd.registration.RegistrationManager; import io.kamax.mxisd.registration.RegistrationManager;
import io.kamax.mxisd.util.RestClientUtils;
import io.undertow.server.HttpServerExchange; import io.undertow.server.HttpServerExchange;
import io.undertow.util.HttpString;
import org.apache.http.Header;
import org.apache.http.HeaderElement;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.CloseableHttpClient;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.URI;
public class Register3pidRequestTokenHandler extends BasicHttpHandler { public class Register3pidRequestTokenHandler extends BasicHttpHandler {
public static final String Key = "medium"; public static final String Key = "medium";
@@ -77,25 +67,11 @@ public class Register3pidRequestTokenHandler extends BasicHttpHandler {
} }
ThreePid tpid = new ThreePid(medium, address); ThreePid tpid = new ThreePid(medium, address);
if (!mgr.allow(tpid)) { if (!mgr.isAllowed(tpid)) {
throw new NotAllowedException("Your " + medium + " address cannot be used for registration"); throw new NotAllowedException("Your " + medium + " address cannot be used for registration");
} }
String target = dns.transform(URI.create(exchange.getRequestURL())).toString(); proxyPost(exchange, body, client, dns);
log.info("Requesting remote: {}", target);
HttpPost req = RestClientUtils.post(target, GsonUtil.get(), body);
try (CloseableHttpResponse res = client.execute(req)) {
exchange.setStatusCode(res.getStatusLine().getStatusCode());
for (Header h : res.getAllHeaders()) {
for (HeaderElement el : h.getElements()) {
exchange.getResponseHeaders().add(HttpString.tryFromString(h.getName()), el.getValue());
}
}
res.getEntity().writeTo(exchange.getOutputStream());
exchange.endExchange();
} catch (IOException e) {
throw new InternalServerError(e);
}
} }
} }

View File

@@ -24,6 +24,7 @@ import com.google.gson.JsonArray;
import com.google.gson.JsonObject; import com.google.gson.JsonObject;
import io.kamax.matrix.MatrixID; import io.kamax.matrix.MatrixID;
import io.kamax.matrix.ThreePid; import io.kamax.matrix.ThreePid;
import io.kamax.matrix._MatrixID;
import io.kamax.matrix.json.GsonUtil; import io.kamax.matrix.json.GsonUtil;
import io.kamax.mxisd.config.InvitationConfig; import io.kamax.mxisd.config.InvitationConfig;
import io.kamax.mxisd.config.MxisdConfig; import io.kamax.mxisd.config.MxisdConfig;
@@ -36,6 +37,7 @@ import io.kamax.mxisd.lookup.SingleLookupReply;
import io.kamax.mxisd.lookup.ThreePidMapping; import io.kamax.mxisd.lookup.ThreePidMapping;
import io.kamax.mxisd.lookup.strategy.LookupStrategy; import io.kamax.mxisd.lookup.strategy.LookupStrategy;
import io.kamax.mxisd.notification.NotificationManager; import io.kamax.mxisd.notification.NotificationManager;
import io.kamax.mxisd.profile.ProfileManager;
import io.kamax.mxisd.storage.IStorage; import io.kamax.mxisd.storage.IStorage;
import io.kamax.mxisd.storage.crypto.*; import io.kamax.mxisd.storage.crypto.*;
import io.kamax.mxisd.storage.ormlite.dao.ThreePidInviteIO; import io.kamax.mxisd.storage.ormlite.dao.ThreePidInviteIO;
@@ -78,6 +80,7 @@ public class InvitationManager {
private SignatureManager signMgr; private SignatureManager signMgr;
private FederationDnsOverwrite dns; private FederationDnsOverwrite dns;
private NotificationManager notifMgr; private NotificationManager notifMgr;
private ProfileManager profileMgr;
private CloseableHttpClient client; private CloseableHttpClient client;
private Timer refreshTimer; private Timer refreshTimer;
@@ -91,7 +94,8 @@ public class InvitationManager {
KeyManager keyMgr, KeyManager keyMgr,
SignatureManager signMgr, SignatureManager signMgr,
FederationDnsOverwrite dns, FederationDnsOverwrite dns,
NotificationManager notifMgr NotificationManager notifMgr,
ProfileManager profileMgr
) { ) {
this.cfg = mxisdCfg.getInvite(); this.cfg = mxisdCfg.getInvite();
this.srvCfg = mxisdCfg.getServer(); this.srvCfg = mxisdCfg.getServer();
@@ -101,6 +105,7 @@ public class InvitationManager {
this.signMgr = signMgr; this.signMgr = signMgr;
this.dns = dns; this.dns = dns;
this.notifMgr = notifMgr; this.notifMgr = notifMgr;
this.profileMgr = profileMgr;
log.info("Loading saved invites"); log.info("Loading saved invites");
Collection<ThreePidInviteIO> ioList = storage.getInvites(); Collection<ThreePidInviteIO> ioList = storage.getInvites();
@@ -214,6 +219,30 @@ public class InvitationManager {
return lookupMgr.find(medium, address, cfg.getResolution().isRecursive()); return lookupMgr.find(medium, address, cfg.getResolution().isRecursive());
} }
public boolean canInvite(_MatrixID sender, JsonObject request) {
if (!request.has("medium")) {
log.info("Not a 3PID invite, allowing");
return true;
}
log.info("3PID invite detected, checking policies...");
List<String> allowedRoles = cfg.getPolicy().getIfSender().getHasRole();
if (Objects.isNull(allowedRoles)) {
log.info("No allowed role configured for sender, allowing");
return true;
}
List<String> userRoles = profileMgr.getRoles(sender);
if (Collections.disjoint(userRoles, allowedRoles)) {
log.info("Sender does not have any of the required roles, denying");
return false;
}
log.info("Sender has at least one of the required roles");
log.info("Sender pass all policies to invite, allowing");
return true;
}
public synchronized IThreePidInviteReply storeInvite(IThreePidInvite invitation) { // TODO better sync public synchronized IThreePidInviteReply storeInvite(IThreePidInvite invitation) { // TODO better sync
if (!notifMgr.isMediumSupported(invitation.getMedium())) { if (!notifMgr.isMediumSupported(invitation.getMedium())) {
throw new BadRequestException("Medium type " + invitation.getMedium() + " is not supported"); throw new BadRequestException("Medium type " + invitation.getMedium() + " is not supported");

View File

@@ -95,7 +95,7 @@ public class RegistrationManager {
} }
} }
public boolean allow(ThreePid tpid) { public boolean isAllowed(ThreePid tpid) {
return invMgr.hasInvite(tpid); return invMgr.hasInvite(tpid);
} }