diff --git a/chatmaild/src/chatmaild/filedict.py b/chatmaild/src/chatmaild/filedict.py new file mode 100644 index 00000000..bfcf42e9 --- /dev/null +++ b/chatmaild/src/chatmaild/filedict.py @@ -0,0 +1,37 @@ +import os +import logging +import marshal +import filelock +from contextlib import contextmanager + + +class FileDict: + """Concurrency-safe multi-reader-single-writer Persistent Dict.""" + + def __init__(self, path, timeout=5.0): + self.path = path + self.lock_path = path.with_name(path.name + ".lock") + self.timeout = timeout + + @contextmanager + def modify(self): + try: + with filelock.FileLock(self.lock_path, timeout=self.timeout): + data = self.read() + yield data + write_path = self.path.with_suffix(".tmp") + with write_path.open("wb") as f: + marshal.dump(data, f) + os.rename(write_path, self.path) + except filelock.Timeout: + logging.warning("could not obtain lock, removing: %r", self.lock_path) + os.remove(self.lock_path) + with self.modify() as d: + yield d + + def read(self): + try: + with self.path.open("rb") as f: + return marshal.load(f) + except FileNotFoundError: + return {} diff --git a/chatmaild/src/chatmaild/metadata.py b/chatmaild/src/chatmaild/metadata.py index 45f509c5..cf8daf77 100644 --- a/chatmaild/src/chatmaild/metadata.py +++ b/chatmaild/src/chatmaild/metadata.py @@ -12,9 +12,8 @@ import sys import logging import os import requests -import marshal -from contextlib import contextmanager -import filelock + +from .filedict import FileDict DICTPROXY_LOOKUP_CHAR = "L" @@ -27,38 +26,6 @@ DICTPROXY_TRANSACTION_CHARS = "SBC" METADATA_TOKEN_KEY = "devicetoken" -class PersistentDict: - """Concurrency-safe multi-reader-single-writer Persistent Dict.""" - - def __init__(self, path, timeout=5.0): - self.path = path - self.lock_path = path.with_name(path.name + ".lock") - self.timeout = timeout - - @contextmanager - def modify(self): - try: - with filelock.FileLock(self.lock_path, timeout=self.timeout): - data = self.get() - yield data - write_path = self.path.with_suffix(".tmp") - with write_path.open("wb") as f: - marshal.dump(data, f) - os.rename(write_path, self.path) - except filelock.Timeout: - logging.warning("could not obtain lock, removing: %r", self.lock_path) - os.remove(self.lock_path) - with self.modify() as d: - yield d - - def get(self): - try: - with self.path.open("rb") as f: - return marshal.load(f) - except FileNotFoundError: - return {} - - class Notifier: def __init__(self, vmail_dir): self.vmail_dir = vmail_dir @@ -68,7 +35,7 @@ class Notifier: mbox_path = self.vmail_dir.joinpath(mbox) if not mbox_path.exists(): mbox_path.mkdir() - return PersistentDict(mbox_path / "metadata.marshalled") + return FileDict(mbox_path / "metadata.marshalled") def add_token(self, mbox, token): with self.get_metadata_dict(mbox).modify() as data: @@ -78,7 +45,7 @@ class Notifier: if token not in tokens: tokens.append(token) - def del_token(self, mbox, token): + def remove_token(self, mbox, token): with self.get_metadata_dict(mbox).modify() as data: tokens = data.get(METADATA_TOKEN_KEY) if tokens: @@ -88,7 +55,7 @@ class Notifier: pass def get_tokens(self, mbox): - return self.get_metadata_dict(mbox).get().get(METADATA_TOKEN_KEY, []) + return self.get_metadata_dict(mbox).read().get(METADATA_TOKEN_KEY, []) def new_message_for_mbox(self, mbox): self.to_notify_queue.put(mbox) @@ -109,7 +76,7 @@ class Notifier: if response.status_code == 410: # 410 Gone status code # means the token is no longer valid. - self.del_token(mbox, token) + self.remove_token(mbox, token) def handle_dovecot_protocol(rfile, wfile, notifier): diff --git a/chatmaild/src/chatmaild/tests/test_filedict.py b/chatmaild/src/chatmaild/tests/test_filedict.py new file mode 100644 index 00000000..0f585349 --- /dev/null +++ b/chatmaild/src/chatmaild/tests/test_filedict.py @@ -0,0 +1,24 @@ +from chatmaild.filedict import FileDict + + +def test_basic(tmp_path): + fdict = FileDict(tmp_path.joinpath("metadata")) + assert fdict.read() == {} + with fdict.modify() as d: + d["devicetoken"] = [1, 2, 3] + d["456"] = 4.2 + new = fdict.read() + assert new["devicetoken"] == [1, 2, 3] + assert new["456"] == 4.2 + + +def test_dying_lock(tmp_path, caplog): + fdict1 = FileDict(tmp_path.joinpath("metadata")) + fdict2 = FileDict(tmp_path.joinpath("metadata"), timeout=0.1) + with fdict1.modify() as d: + with fdict2.modify() as d2: + d2["1"] = "2" + assert "could not obtain" in caplog.records[0].msg + d["1"] = "3" + assert fdict1.read()["1"] == "3" + assert fdict2.read()["1"] == "3" diff --git a/chatmaild/src/chatmaild/tests/test_metadata.py b/chatmaild/src/chatmaild/tests/test_metadata.py index a423ee2b..233c5846 100644 --- a/chatmaild/src/chatmaild/tests/test_metadata.py +++ b/chatmaild/src/chatmaild/tests/test_metadata.py @@ -5,7 +5,6 @@ from chatmaild.metadata import ( handle_dovecot_request, handle_dovecot_protocol, Notifier, - PersistentDict, ) @@ -30,12 +29,12 @@ def test_notifier_persistence(tmp_path): notifier1.add_token("user3@example.org", "456") assert notifier2.get_tokens("user1@example.org") == ["01234"] assert notifier2.get_tokens("user3@example.org") == ["456"] - notifier2.del_token("user1@example.org", "01234") + notifier2.remove_token("user1@example.org", "01234") assert not notifier1.get_tokens("user1@example.org") def test_notifier_delete_without_set(notifier): - notifier.del_token("user@example.org", "123") + notifier.remove_token("user@example.org", "123") assert not notifier.get_tokens("user@example.org") @@ -217,28 +216,3 @@ def test_notifier_thread_run_gone_removes_token(notifier): url, data, timeout = requests[1] assert data == "45678" assert notifier.get_tokens("user@example.org") == ["45678"] - - -class TestPersistentDict: - @pytest.fixture - def store(self, tmp_path): - return PersistentDict(tmp_path.joinpath("metadata")) - - def test_basic(self, store): - assert store.get() == {} - with store.modify() as d: - d["devicetoken"] = [1, 2, 3] - d["456"] = 4.2 - new = store.get() - assert new["devicetoken"] == [1, 2, 3] - assert new["456"] == 4.2 - - def test_dying_lock(self, tmp_path, caplog): - store1 = PersistentDict(tmp_path.joinpath("metadata")) - store2 = PersistentDict(tmp_path.joinpath("metadata"), timeout=0.1) - with store1.modify() as d: - with store2.modify() as d2: - d2["1"] = "2" - assert "could not obtain" in caplog.records[0].msg - d["1"] = "3" - assert store1.get()["1"] == "3"