mirror of
https://github.com/chatmail/relay.git
synced 2026-05-20 12:58:04 +00:00
remove neccessity for FileLock on set_password
This commit is contained in:
@@ -3,7 +3,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from threading import RLock
|
|
||||||
|
|
||||||
from .config import Config, read_config
|
from .config import Config, read_config
|
||||||
from .dictproxy import DictProxy
|
from .dictproxy import DictProxy
|
||||||
@@ -86,10 +85,6 @@ class AuthDictProxy(DictProxy):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
# We serialize all password-writes in the single doveauth process
|
|
||||||
# so that threads can not mangle the password when writing.
|
|
||||||
# Setting a password is a quite rare event anyway.
|
|
||||||
self._password_write_lock = RLock()
|
|
||||||
|
|
||||||
def handle_lookup(self, parts):
|
def handle_lookup(self, parts):
|
||||||
# Dovecot <2.3.17 has only one part,
|
# Dovecot <2.3.17 has only one part,
|
||||||
@@ -145,8 +140,7 @@ class AuthDictProxy(DictProxy):
|
|||||||
if not is_allowed_to_create(self.config, addr, cleartext_password):
|
if not is_allowed_to_create(self.config, addr, cleartext_password):
|
||||||
return
|
return
|
||||||
|
|
||||||
with self._password_write_lock:
|
user.set_password(encrypt_password(cleartext_password))
|
||||||
user.set_password(encrypt_password(cleartext_password))
|
|
||||||
print(f"Created address: {user}", file=sys.stderr)
|
print(f"Created address: {user}", file=sys.stderr)
|
||||||
return user.get_userdb_dict()
|
return user.get_userdb_dict()
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from random import randint
|
||||||
|
|
||||||
import filelock
|
import filelock
|
||||||
|
|
||||||
@@ -34,3 +35,10 @@ class FileDict:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logging.warning(f"corrupt serialization state at: {self.path!r}")
|
logging.warning(f"corrupt serialization state at: {self.path!r}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def write_bytes_atomic(path, content):
|
||||||
|
rint = randint(0, 10000000)
|
||||||
|
tmp = path.with_name(path.name + f".tmp-{rint}")
|
||||||
|
tmp.write_bytes(content)
|
||||||
|
os.rename(tmp, path)
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from chatmaild.filedict import FileDict
|
import threading
|
||||||
|
|
||||||
|
from chatmaild.filedict import FileDict, write_bytes_atomic
|
||||||
|
|
||||||
|
|
||||||
def test_basic(tmp_path):
|
def test_basic(tmp_path):
|
||||||
@@ -17,3 +19,21 @@ def test_bad_marshal_file(tmp_path, caplog):
|
|||||||
fdict1.path.write_bytes(b"l12k3l12k3l")
|
fdict1.path.write_bytes(b"l12k3l12k3l")
|
||||||
assert fdict1.read() == {}
|
assert fdict1.read() == {}
|
||||||
assert "corrupt" in caplog.records[0].msg
|
assert "corrupt" in caplog.records[0].msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_bytes_atomic_concurrent(tmp_path):
|
||||||
|
p = tmp_path.joinpath("somefile.ext")
|
||||||
|
write_bytes_atomic(p, b"hello")
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(30):
|
||||||
|
content = f"hello{i}".encode("ascii")
|
||||||
|
t = threading.Thread(target=lambda: write_bytes_atomic(p, content))
|
||||||
|
t.start()
|
||||||
|
threads.append(t)
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert p.read_text().strip() != "hello"
|
||||||
|
assert len(list(p.parent.iterdir())) == 1
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from chatmaild.filedict import write_bytes_atomic
|
||||||
|
|
||||||
|
|
||||||
def get_daytimestamp(timestamp) -> int:
|
def get_daytimestamp(timestamp) -> int:
|
||||||
return int(timestamp) // 86400 * 86400
|
return int(timestamp) // 86400 * 86400
|
||||||
@@ -37,15 +39,14 @@ class User:
|
|||||||
def set_password(self, enc_password):
|
def set_password(self, enc_password):
|
||||||
"""Set the specified password for this user.
|
"""Set the specified password for this user.
|
||||||
|
|
||||||
NOTE that this method is not multi-thread/process safe.
|
This method can be called concurrently
|
||||||
The caller has to ensure only a single thread writes to the same
|
but there is no guarantee which of the password-set calls will win.
|
||||||
user's password file.
|
|
||||||
"""
|
"""
|
||||||
self.maildir.mkdir(exist_ok=True, parents=True)
|
self.maildir.mkdir(exist_ok=True, parents=True)
|
||||||
password = enc_password.encode("ascii")
|
password = enc_password.encode("ascii")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.password_path.write_bytes(password)
|
write_bytes_atomic(self.password_path, password)
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
if not self.addr.startswith("echo@"):
|
if not self.addr.startswith("echo@"):
|
||||||
logging.error(f"could not write password for: {self.addr}")
|
logging.error(f"could not write password for: {self.addr}")
|
||||||
|
|||||||
Reference in New Issue
Block a user