diff --git a/src/main/groovy/io/kamax/mxisd/lookup/provider/DnsLookupProvider.groovy b/src/main/groovy/io/kamax/mxisd/lookup/provider/DnsLookupProvider.groovy index 3d0b328..d8033f7 100644 --- a/src/main/groovy/io/kamax/mxisd/lookup/provider/DnsLookupProvider.groovy +++ b/src/main/groovy/io/kamax/mxisd/lookup/provider/DnsLookupProvider.groovy @@ -33,6 +33,8 @@ import org.xbill.DNS.Lookup import org.xbill.DNS.SRVRecord import org.xbill.DNS.Type +import java.util.concurrent.ForkJoinPool +import java.util.concurrent.RecursiveTask import java.util.function.Function @Component @@ -129,7 +131,6 @@ class DnsLookupProvider extends RemoteIdentityServerProvider { @Override List populate(List mappings) { - List mappingsFound = new ArrayList<>() Map> domains = new HashMap<>() for (ThreePidMapping mapping : mappings) { @@ -157,20 +158,59 @@ class DnsLookupProvider extends RemoteIdentityServerProvider { } log.info("Looking mappings across {} domains", domains.keySet().size()) - for (String domain : domains.keySet()) { - Optional baseUrl = findIdentityServerForDomain(domain) - if (!baseUrl.isPresent()) { - log.info("No usable Identity server for domain {}", domain) - continue + ForkJoinPool pool = new ForkJoinPool() + RecursiveTask> task = new RecursiveTask>() { + + @Override + protected List compute() { + List mappingsFound = new ArrayList<>() + List tasks = new ArrayList<>() + + for (String domain : domains.keySet()) { + DomainBulkLookupTask domainTask = new DomainBulkLookupTask(domain, domains.get(domain)) + domainTask.fork() + tasks.add(domainTask) + } + + for (DomainBulkLookupTask task : tasks) { + mappingsFound.addAll(task.join()) + } + + return mappingsFound } - - List domainMappings = find(baseUrl.get(), domains.get(domain)) - log.info("Found {} mappings in domain {}", domainMappings.size(), domain) - mappingsFound.addAll(domainMappings) } + pool.submit(task) + pool.shutdown() + List mappingsFound = task.join() log.info("Found {} mappings overall", mappingsFound.size()) return mappingsFound } + private class DomainBulkLookupTask extends RecursiveTask> { + + private String domain + private List mappings + + DomainBulkLookupTask(String domain, List mappings) { + this.domain = domain + this.mappings = mappings + } + + @Override + protected List compute() { + List domainMappings = new ArrayList<>() + + Optional baseUrl = findIdentityServerForDomain(domain) + if (!baseUrl.isPresent()) { + log.info("No usable Identity server for domain {}", domain) + } else { + domainMappings.addAll(find(baseUrl.get(), mappings)) + log.info("Found {} mappings in domain {}", domainMappings.size(), domain) + } + + return domainMappings + } + } + } diff --git a/src/main/groovy/io/kamax/mxisd/lookup/provider/RemoteIdentityServerProvider.groovy b/src/main/groovy/io/kamax/mxisd/lookup/provider/RemoteIdentityServerProvider.groovy index 85985da..abcee20 100644 --- a/src/main/groovy/io/kamax/mxisd/lookup/provider/RemoteIdentityServerProvider.groovy +++ b/src/main/groovy/io/kamax/mxisd/lookup/provider/RemoteIdentityServerProvider.groovy @@ -55,6 +55,8 @@ abstract class RemoteIdentityServerProvider implements ThreePidProvider { HttpURLConnection rootSrvConn = (HttpURLConnection) new URL( "${remote}/_matrix/identity/api/v1/lookup?medium=${THREEPID_TEST_MEDIUM}&address=${THREEPID_TEST_ADDRESS}" ).openConnection() + // TODO turn this into a configuration property + rootSrvConn.setConnectTimeout(2000) if (rootSrvConn.getResponseCode() != 200) { return false