separate notification thread into own class, and test start_notification_threads

This commit is contained in:
holger krekel
2024-03-31 01:23:02 +01:00
parent d313bea97f
commit a31d998e67
2 changed files with 70 additions and 47 deletions

View File

@@ -29,8 +29,9 @@ METADATA_TOKEN_KEY = "devicetoken"
class Notifier: class Notifier:
CONNECTION_TIMEOUT = 60.0 # seconds URL = "https://notifications.delta.chat/notify"
NOTIFICATION_RETRY_DELAY = 8.0 # seconds, with exponential backoff CONNECTION_TIMEOUT = 60.0 # seconds until http-request is given up
NOTIFICATION_RETRY_DELAY = 8.0 # seconds with exponential backoff
MAX_NUMBER_OF_TRIES = 6 MAX_NUMBER_OF_TRIES = 6
# exponential backoff means we try for 8^5 seconds, approximately 10 hours # exponential backoff means we try for 8^5 seconds, approximately 10 hours
@@ -39,7 +40,7 @@ class Notifier:
self.notification_dir = vmail_dir / "pending_notifications" self.notification_dir = vmail_dir / "pending_notifications"
if not self.notification_dir.exists(): if not self.notification_dir.exists():
self.notification_dir.mkdir() self.notification_dir.mkdir()
self.retry_queues = [PriorityQueue() for i in range(self.MAX_NUMBER_OF_TRIES)] self.retry_queues = [PriorityQueue() for _ in range(self.MAX_NUMBER_OF_TRIES)]
def get_metadata_dict(self, addr): def get_metadata_dict(self, addr):
return FileDict(self.vmail_dir / addr / "metadata.json") return FileDict(self.vmail_dir / addr / "metadata.json")
@@ -55,10 +56,8 @@ class Notifier:
def remove_token_from_addr(self, addr, token): def remove_token_from_addr(self, addr, token):
with self.get_metadata_dict(addr).modify() as data: with self.get_metadata_dict(addr).modify() as data:
tokens = data.get(METADATA_TOKEN_KEY, []) tokens = data.get(METADATA_TOKEN_KEY, [])
try: if token in tokens:
tokens.remove(token) tokens.remove(token)
except ValueError:
pass
def get_tokens_for_addr(self, addr): def get_tokens_for_addr(self, addr):
return self.get_metadata_dict(addr).read().get(METADATA_TOKEN_KEY, []) return self.get_metadata_dict(addr).read().get(METADATA_TOKEN_KEY, [])
@@ -68,15 +67,15 @@ class Notifier:
self.notification_dir.joinpath(token).write_text(addr) self.notification_dir.joinpath(token).write_text(addr)
self.add_token_for_retry(token) self.add_token_for_retry(token)
def add_token_for_retry(self, token, numtries=0): def add_token_for_retry(self, token, retry_num=0):
if numtries >= self.MAX_NUMBER_OF_TRIES: if retry_num >= self.MAX_NUMBER_OF_TRIES:
return False return False
when = time.time() when = time.time()
if numtries > 0: if retry_num > 0:
# backup exponentially with number of retries # backup exponentially with number of retries
when += pow(self.NOTIFICATION_RETRY_DELAY, numtries) when += pow(self.NOTIFICATION_RETRY_DELAY, retry_num)
self.retry_queues[numtries].put((when, token)) self.retry_queues[retry_num].put((when, token))
return True return True
def requeue_persistent_pending_tokens(self): def requeue_persistent_pending_tokens(self):
@@ -85,53 +84,69 @@ class Notifier:
def start_notification_threads(self): def start_notification_threads(self):
self.requeue_persistent_pending_tokens() self.requeue_persistent_pending_tokens()
threads = {}
for retry_num in range(len(self.retry_queues)):
num_threads = {0: 4}.get(retry_num, 2)
threads[retry_num] = []
for _ in range(num_threads):
threads[retry_num].append(NotifyThread(self, retry_num))
threads[retry_num][-1].start()
return threads
# start a thread for each retry-queue bucket
for numtries in range(len(self.retry_queues)):
t = Thread(target=self.thread_retry_loop, args=(numtries,))
t.setDaemon(True)
t.start()
def thread_retry_loop(self, numtries): class NotifyThread(Thread):
def __init__(self, notifier, retry_num):
super().__init__(daemon=True)
self.notifier = notifier
self.retry_num = retry_num
def stop(self):
self.notifier.retry_queues[self.retry_num].put((None, None))
def run(self):
requests_session = requests.Session() requests_session = requests.Session()
while True: while self.retry_one(requests_session):
self.thread_retry_one(requests_session, numtries) pass
def thread_retry_one(self, requests_session, numtries, sleepfunc=time.sleep): def retry_one(self, requests_session, sleep=time.sleep):
retry_queue = self.retry_queues[numtries] # takes the next token from the per-retry-number PriorityQueue
when, token = retry_queue.get() # which is ordered by "when" (as set by add_token_for_retry()).
# If the request to notification server fails the token is
# queued to the next retry-number's PriorityQueue
# until it finally is dropped if MAX_NUMBER_OF_TRIES is exceeded
when, token = self.notifier.retry_queues[self.retry_num].get()
if when is None:
return False
wait_time = when - time.time() wait_time = when - time.time()
if wait_time > 0: if wait_time > 0:
sleepfunc(wait_time) sleep(wait_time)
self.notify_one(requests_session, token, numtries) self.perform_request_to_notification_server(requests_session, token)
return True
def notify_one(self, requests_session, token, numtries=0): def perform_request_to_notification_server(self, requests_session, token):
token_path = self.notification_dir.joinpath(token) token_path = self.notifier.notification_dir.joinpath(token)
try: try:
response = requests_session.post( timeout = self.notifier.CONNECTION_TIMEOUT
"https://notifications.delta.chat/notify", res = requests_session.post(self.notifier.URL, data=token, timeout=timeout)
data=token,
timeout=self.CONNECTION_TIMEOUT,
)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
response = e res = e
else: else:
if response.status_code in (200, 410): if res.status_code in (200, 410):
if response.status_code == 410: if res.status_code == 410:
# 410 Gone: means the token is no longer valid. # 410 Gone: means the token is no longer valid.
try: try:
addr = token_path.read_text() addr = token_path.read_text()
except FileNotFoundError: except FileNotFoundError:
logging.warning("no address for token %r:", token) logging.warning("no address for token %r:", token)
return return
self.remove_token_from_addr(addr, token) self.notifier.remove_token_from_addr(addr, token)
token_path.unlink(missing_ok=True) token_path.unlink(missing_ok=True)
return return
logging.warning("Notification request failed: %r", response) logging.warning("Notification request failed: %r", res)
if not self.add_token_for_retry(token, numtries=numtries + 1): if not self.notifier.add_token_for_retry(token, retry_num=self.retry_num + 1):
token_path.unlink(missing_ok=True) token_path.unlink(missing_ok=True)
logging.warning("dropping token after %d tries: %r", numtries, token) logging.warning("dropping token after %d tries: %r", self.retry_num, token)
def handle_dovecot_protocol(rfile, wfile, notifier): def handle_dovecot_protocol(rfile, wfile, notifier):
@@ -214,6 +229,7 @@ def main():
return 1 return 1
notifier = Notifier(vmail_dir) notifier = Notifier(vmail_dir)
notifier.start_notification_threads()
class Handler(StreamRequestHandler): class Handler(StreamRequestHandler):
def handle(self): def handle(self):
@@ -228,8 +244,6 @@ def main():
except FileNotFoundError: except FileNotFoundError:
pass pass
notifier.start_notification_threads()
with ThreadedUnixStreamServer(socket, Handler) as server: with ThreadedUnixStreamServer(socket, Handler) as server:
try: try:
server.serve_forever() server.serve_forever()

View File

@@ -6,6 +6,7 @@ from chatmaild.metadata import (
handle_dovecot_request, handle_dovecot_request,
handle_dovecot_protocol, handle_dovecot_protocol,
Notifier, Notifier,
NotifyThread,
) )
@@ -173,7 +174,7 @@ def test_notifier_thread_firstrun(notifier, testaddr):
reqmock = get_mocked_requests([200]) reqmock = get_mocked_requests([200])
notifier.add_token_to_addr(testaddr, "01234") notifier.add_token_to_addr(testaddr, "01234")
notifier.new_message_for_addr(testaddr) notifier.new_message_for_addr(testaddr)
notifier.thread_retry_one(reqmock, numtries=0) NotifyThread(notifier, retry_num=0).retry_one(reqmock)
url, data, timeout = reqmock.requests[0] url, data, timeout = reqmock.requests[0]
assert data == "01234" assert data == "01234"
assert notifier.get_tokens_for_addr(testaddr) == ["01234"] assert notifier.get_tokens_for_addr(testaddr) == ["01234"]
@@ -185,7 +186,7 @@ def test_notifier_thread_run(notifier, testaddr):
notifier.add_token_to_addr(testaddr, "01234") notifier.add_token_to_addr(testaddr, "01234")
notifier.new_message_for_addr(testaddr) notifier.new_message_for_addr(testaddr)
reqmock = get_mocked_requests([200]) reqmock = get_mocked_requests([200])
notifier.thread_retry_one(reqmock, numtries=0) NotifyThread(notifier, retry_num=0).retry_one(reqmock)
url, data, timeout = reqmock.requests[0] url, data, timeout = reqmock.requests[0]
assert data == "01234" assert data == "01234"
assert notifier.get_tokens_for_addr(testaddr) == ["01234"] assert notifier.get_tokens_for_addr(testaddr) == ["01234"]
@@ -203,7 +204,7 @@ def test_notifier_thread_connection_failures(notifier, testaddr, status, caplog)
caplog.clear() caplog.clear()
reqmock = get_mocked_requests([status]) reqmock = get_mocked_requests([status])
sleep_calls = [] sleep_calls = []
notifier.thread_retry_one(reqmock, numtries=i, sleepfunc=sleep_calls.append) NotifyThread(notifier, retry_num=i).retry_one(reqmock, sleep=sleep_calls.append)
assert notifier.retry_queues[i].qsize() == 0 assert notifier.retry_queues[i].qsize() == 0
assert "request failed" in caplog.records[0].msg assert "request failed" in caplog.records[0].msg
if i > 0: if i > 0:
@@ -218,14 +219,22 @@ def test_notifier_thread_connection_failures(notifier, testaddr, status, caplog)
assert notifier.retry_queues[0].qsize() == 0 assert notifier.retry_queues[0].qsize() == 0
def test_start_and_stop_notification_threads(notifier, testaddr):
threads = notifier.start_notification_threads()
for retry_num, threadlist in threads.items():
for t in threadlist:
t.stop()
t.join()
def test_multi_device_notifier(notifier, testaddr): def test_multi_device_notifier(notifier, testaddr):
notifier.add_token_to_addr(testaddr, "01234") notifier.add_token_to_addr(testaddr, "01234")
notifier.add_token_to_addr(testaddr, "56789") notifier.add_token_to_addr(testaddr, "56789")
notifier.new_message_for_addr(testaddr) notifier.new_message_for_addr(testaddr)
reqmock = get_mocked_requests([200, 200]) reqmock = get_mocked_requests([200, 200])
notifier.thread_retry_one(reqmock, numtries=0) NotifyThread(notifier, retry_num=0).retry_one(reqmock)
notifier.thread_retry_one(reqmock, numtries=0) NotifyThread(notifier, retry_num=0).retry_one(reqmock)
assert notifier.retry_queues[0].qsize() == 0 assert notifier.retry_queues[0].qsize() == 0
assert notifier.retry_queues[1].qsize() == 0 assert notifier.retry_queues[1].qsize() == 0
url, data, timeout = reqmock.requests[0] url, data, timeout = reqmock.requests[0]
@@ -240,8 +249,8 @@ def test_notifier_thread_run_gone_removes_token(notifier, testaddr):
notifier.add_token_to_addr(testaddr, "45678") notifier.add_token_to_addr(testaddr, "45678")
notifier.new_message_for_addr(testaddr) notifier.new_message_for_addr(testaddr)
reqmock = get_mocked_requests([410, 200]) reqmock = get_mocked_requests([410, 200])
notifier.thread_retry_one(reqmock, numtries=0) NotifyThread(notifier, retry_num=0).retry_one(reqmock)
notifier.thread_retry_one(reqmock, numtries=0) NotifyThread(notifier, retry_num=0).retry_one(reqmock)
url, data, timeout = reqmock.requests[0] url, data, timeout = reqmock.requests[0]
assert data == "01234" assert data == "01234"
url, data, timeout = reqmock.requests[1] url, data, timeout = reqmock.requests[1]