Compare commits
	
		
			4 Commits
		
	
	
		
			v0.1.4
			...
			fix_userid
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 1010ef4e8f | ||
|  | b6bdebbc4a | ||
|  | 3524b4772a | ||
|  | 893473b236 | 
| @@ -30,15 +30,19 @@ sudo pip install git+https://github.com/ma1uta/matrix-synapse-rest-password-prov | |||||||
| If the command fail, double check that the python version still matches. If not, please let us know by opening an issue. | If the command fail, double check that the python version still matches. If not, please let us know by opening an issue. | ||||||
|  |  | ||||||
| ## Configure | ## Configure | ||||||
| Add or amend the `password_providers` entry like so: | Add or amend the `modules` entry like so: | ||||||
| ```yaml | ```yaml | ||||||
| password_providers: | modules: | ||||||
|   - module: "rest_auth_provider.RestAuthProvider" |   - module: "rest_auth_provider.RestAuthProvider" | ||||||
|     config: |     config: | ||||||
|       endpoint: "http://change.me.example.com:12345" |       endpoint: "http://change.me.example.com:12345" | ||||||
| ``` | ``` | ||||||
| Set `endpoint` to the value documented with the endpoint provider. | Set `endpoint` to the value documented with the endpoint provider. | ||||||
|  |  | ||||||
|  | **NOTE:** This requires Synapse 1.46 or later! If you migrate from the legacy `password_providers`, make sure | ||||||
|  | to remove the old `RestAuthProvider` entry. If the `password_providers` list is empty, you can also remove it completely or | ||||||
|  | comment it out. | ||||||
|  |  | ||||||
| ## Use | ## Use | ||||||
| 1. Install, configure, restart synapse | 1. Install, configure, restart synapse | ||||||
| 2. Try to login with a valid username and password for the endpoint configured | 2. Try to login with a valid username and password for the endpoint configured | ||||||
|   | |||||||
| @@ -20,18 +20,21 @@ | |||||||
| # | # | ||||||
|  |  | ||||||
| import logging | import logging | ||||||
| from twisted.internet import defer | from typing import Tuple, Optional, Callable, Awaitable | ||||||
|  |  | ||||||
| import requests | import requests | ||||||
| import json |  | ||||||
| import time | import time | ||||||
|  | import synapse | ||||||
|  | from synapse import module_api | ||||||
|  | from synapse.types import UserID | ||||||
|  |  | ||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
| class RestAuthProvider(object): | class RestAuthProvider(object): | ||||||
|  |  | ||||||
|     def __init__(self, config, account_handler): |     def __init__(self, config: dict, api: module_api): | ||||||
|         self.account_handler = account_handler |         self.account_handler = api | ||||||
|  |  | ||||||
|         if not config.endpoint: |         if not config.endpoint: | ||||||
|             raise RuntimeError('Missing endpoint config') |             raise RuntimeError('Missing endpoint config') | ||||||
| @@ -43,8 +46,37 @@ class RestAuthProvider(object): | |||||||
|         logger.info('Endpoint: %s', self.endpoint) |         logger.info('Endpoint: %s', self.endpoint) | ||||||
|         logger.info('Enforce lowercase username during registration: %s', self.regLower) |         logger.info('Enforce lowercase username during registration: %s', self.regLower) | ||||||
|  |  | ||||||
|     @defer.inlineCallbacks |         # register an auth callback handler | ||||||
|     def check_password(self, user_id, password): |         # see https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html | ||||||
|  |         api.register_password_auth_provider_callbacks( | ||||||
|  |             auth_checkers={ | ||||||
|  |                 ("m.login.password", ("password",)): self.check_m_login_password | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     async def check_m_login_password(self, username: str, | ||||||
|  |                                      login_type: str, | ||||||
|  |                                      login_dict: "synapse.module_api.JsonDict") -> Optional[ | ||||||
|  |         Tuple[ | ||||||
|  |             str, | ||||||
|  |             Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], | ||||||
|  |         ] | ||||||
|  |     ]: | ||||||
|  |         if login_type != "m.login.password": | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         # get the complete MXID | ||||||
|  |         mxid = self.account_handler.get_qualified_user_id(username) | ||||||
|  |  | ||||||
|  |         # check if the password is valid with the old function | ||||||
|  |         password_valid = await self.check_password(mxid, login_dict.get("password")) | ||||||
|  |  | ||||||
|  |         if password_valid: | ||||||
|  |             return mxid, None | ||||||
|  |         else: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     async def check_password(self, user_id, password): | ||||||
|         logger.info("Got password check for " + user_id) |         logger.info("Got password check for " + user_id) | ||||||
|         data = {'user': {'id': user_id, 'password': password}} |         data = {'user': {'id': user_id, 'password': password}} | ||||||
|         r = requests.post(self.endpoint + '/_matrix-internal/identity/v1/check_credentials', json=data) |         r = requests.post(self.endpoint + '/_matrix-internal/identity/v1/check_credentials', json=data) | ||||||
| @@ -58,20 +90,22 @@ class RestAuthProvider(object): | |||||||
|         auth = r["auth"] |         auth = r["auth"] | ||||||
|         if not auth["success"]: |         if not auth["success"]: | ||||||
|             logger.info("User not authenticated") |             logger.info("User not authenticated") | ||||||
|             defer.returnValue(False) |             return False | ||||||
|  |  | ||||||
|  |         types_user_id = UserID.from_string(user_id) | ||||||
|         localpart = user_id.split(":", 1)[0][1:] |         localpart = user_id.split(":", 1)[0][1:] | ||||||
|  |         domain = user_id.split(":", 1)[1][1:] | ||||||
|         logger.info("User %s authenticated", user_id) |         logger.info("User %s authenticated", user_id) | ||||||
|  |  | ||||||
|         registration = False |         registration = False | ||||||
|         if not (yield self.account_handler.check_user_exists(user_id)): |         if not (await self.account_handler.check_user_exists(user_id)): | ||||||
|             logger.info("User %s does not exist yet, creating...", user_id) |             logger.info("User %s does not exist yet, creating...", user_id) | ||||||
|  |  | ||||||
|             if localpart != localpart.lower() and self.regLower: |             if localpart != localpart.lower() and self.regLower: | ||||||
|                 logger.info('User %s was cannot be created due to username lowercase policy', localpart) |                 logger.info('User %s was cannot be created due to username lowercase policy', localpart) | ||||||
|                 defer.returnValue(False) |                 return False | ||||||
|  |  | ||||||
|             user_id, access_token = (yield self.account_handler.register(localpart=localpart)) |             user_id, access_token = (await self.account_handler.register(localpart=localpart)) | ||||||
|             registration = True |             registration = True | ||||||
|             logger.info("Registration based on REST data was successful for %s", user_id) |             logger.info("Registration based on REST data was successful for %s", user_id) | ||||||
|         else: |         else: | ||||||
| @@ -81,16 +115,12 @@ class RestAuthProvider(object): | |||||||
|             logger.info("Handling profile data") |             logger.info("Handling profile data") | ||||||
|             profile = auth["profile"] |             profile = auth["profile"] | ||||||
|  |  | ||||||
|             # fixme: temporary fix |             store = self.account_handler._hs.get_profile_handler().store | ||||||
|             try: |  | ||||||
|                 store = yield self.account_handler._hs.get_profile_handler().store  # for synapse >= 1.9.0 |  | ||||||
|             except AttributeError: |  | ||||||
|                 store = yield self.account_handler.hs.get_profile_handler().store   # for synapse < 1.9.0 |  | ||||||
|  |  | ||||||
|             if "display_name" in profile and ((registration and self.config.setNameOnRegister) or (self.config.setNameOnLogin)): |             if "display_name" in profile and ((registration and self.config.setNameOnRegister) or (self.config.setNameOnLogin)): | ||||||
|                 display_name = profile["display_name"] |                 display_name = profile["display_name"] | ||||||
|                 logger.info("Setting display name to '%s' based on profile data", display_name) |                 logger.info("Setting display name to '%s' based on profile data", display_name) | ||||||
|                 yield store.set_profile_displayname(localpart, display_name) |                 await store.set_profile_displayname(types_user_id, display_name) | ||||||
|             else: |             else: | ||||||
|                 logger.info("Display name was not set because it was not given or policy restricted it") |                 logger.info("Display name was not set because it was not given or policy restricted it") | ||||||
|  |  | ||||||
| @@ -106,9 +136,9 @@ class RestAuthProvider(object): | |||||||
|                         logger.info("Looking for 3PID %s:%s in user profile", medium, address) |                         logger.info("Looking for 3PID %s:%s in user profile", medium, address) | ||||||
|  |  | ||||||
|                         validated_at = time_msec() |                         validated_at = time_msec() | ||||||
|                         if not (yield store.get_user_id_by_threepid(medium, address)): |                         if not (await store.get_user_id_by_threepid(medium, address)): | ||||||
|                             logger.info("3PID is not present, adding") |                             logger.info("3PID is not present, adding") | ||||||
|                             yield store.user_add_threepid( |                             await store.user_add_threepid( | ||||||
|                                 user_id, |                                 user_id, | ||||||
|                                 medium, |                                 medium, | ||||||
|                                 address, |                                 address, | ||||||
| @@ -119,12 +149,12 @@ class RestAuthProvider(object): | |||||||
|                             logger.info("3PID is present, skipping") |                             logger.info("3PID is present, skipping") | ||||||
|  |  | ||||||
|                     if (self.config.replaceThreepid): |                     if (self.config.replaceThreepid): | ||||||
|                         for threepid in (yield store.user_get_threepids(user_id)): |                         for threepid in (await store.user_get_threepids(user_id)): | ||||||
|                             medium = threepid["medium"].lower() |                             medium = threepid["medium"].lower() | ||||||
|                             address = threepid["address"].lower() |                             address = threepid["address"].lower() | ||||||
|                             if {"medium": medium, "address": address} not in external_3pids: |                             if {"medium": medium, "address": address} not in external_3pids: | ||||||
|                                 logger.info("3PID is not present in external datastore, deleting") |                                 logger.info("3PID is not present in external datastore, deleting") | ||||||
|                                 yield store.user_delete_threepid( |                                 await store.user_delete_threepid( | ||||||
|                                     user_id, |                                     user_id, | ||||||
|                                     medium, |                                     medium, | ||||||
|                                     address |                                     address | ||||||
| @@ -135,7 +165,7 @@ class RestAuthProvider(object): | |||||||
|         else: |         else: | ||||||
|             logger.info("No profile data") |             logger.info("No profile data") | ||||||
|  |  | ||||||
|         defer.returnValue(True) |         return True | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def parse_config(config): |     def parse_config(config): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user