fix: use separate transaction storage for each DictProxy handler

DictProxy can have transactions with the same name
(most frequently `1`) processed in parallel.
Dovecot expects that transaction names on each connection
are independent.
This commit is contained in:
link2xt
2024-07-31 17:18:22 +00:00
parent dcab097e00
commit 1f1d1fdf59
6 changed files with 51 additions and 40 deletions

View File

@@ -4,21 +4,24 @@ from socketserver import StreamRequestHandler, ThreadingUnixStreamServer
class DictProxy: class DictProxy:
def __init__(self):
self.transactions = {}
def loop_forever(self, rfile, wfile): def loop_forever(self, rfile, wfile):
# Transaction storage is local to each handler loop.
# Dovecot reuses transaction IDs across connections,
# starting transaction with the name `1`
# on two different connections to the same proxy sometimes.
transactions = {}
while True: while True:
msg = rfile.readline().strip().decode() msg = rfile.readline().strip().decode()
if not msg: if not msg:
break break
res = self.handle_dovecot_request(msg) res = self.handle_dovecot_request(msg, transactions)
if res: if res:
wfile.write(res.encode("ascii")) wfile.write(res.encode("ascii"))
wfile.flush() wfile.flush()
def handle_dovecot_request(self, msg): def handle_dovecot_request(self, msg, transactions):
# see https://doc.dovecot.org/developer_manual/design/dict_protocol/#dovecot-dict-protocol # see https://doc.dovecot.org/developer_manual/design/dict_protocol/#dovecot-dict-protocol
short_command = msg[0] short_command = msg[0]
parts = msg[1:].split("\t") parts = msg[1:].split("\t")
@@ -37,11 +40,11 @@ class DictProxy:
transaction_id = parts[0] transaction_id = parts[0]
if short_command == "B": if short_command == "B":
return self.handle_begin_transaction(transaction_id, parts) return self.handle_begin_transaction(transaction_id, parts, transactions)
elif short_command == "C": elif short_command == "C":
return self.handle_commit_transaction(transaction_id, parts) return self.handle_commit_transaction(transaction_id, parts, transactions)
elif short_command == "S": elif short_command == "S":
return self.handle_set(transaction_id, parts) return self.handle_set(transaction_id, parts, transactions)
def handle_lookup(self, parts): def handle_lookup(self, parts):
logging.warning(f"lookup ignored: {parts!r}") logging.warning(f"lookup ignored: {parts!r}")
@@ -52,19 +55,19 @@ class DictProxy:
# If we don't return empty line Dovecot will timeout. # If we don't return empty line Dovecot will timeout.
return "\n" return "\n"
def handle_begin_transaction(self, transaction_id, parts): def handle_begin_transaction(self, transaction_id, parts, transactions):
addr = parts[1] addr = parts[1]
self.transactions[transaction_id] = dict(addr=addr, res="O\n") transactions[transaction_id] = dict(addr=addr, res="O\n")
def handle_set(self, transaction_id, parts): def handle_set(self, transaction_id, parts, transactions):
# For documentation on key structure see # For documentation on key structure see
# https://github.com/dovecot/core/blob/main/src/lib-storage/mailbox-attribute.h # https://github.com/dovecot/core/blob/main/src/lib-storage/mailbox-attribute.h
self.transactions[transaction_id]["res"] = "F\n" transactions[transaction_id]["res"] = "F\n"
def handle_commit_transaction(self, transaction_id, parts): def handle_commit_transaction(self, transaction_id, parts, transactions):
# return whatever "set" command(s) set as result. # return whatever "set" command(s) set as result.
return self.transactions.pop(transaction_id)["res"] return transactions.pop(transaction_id)["res"]
def serve_forever_from_socket(self, socket): def serve_forever_from_socket(self, socket):
dictproxy = self dictproxy = self

View File

@@ -9,10 +9,10 @@ class LastLoginDictProxy(DictProxy):
super().__init__() super().__init__()
self.config = config self.config = config
def handle_set(self, transaction_id, parts): def handle_set(self, transaction_id, parts, transactions):
keyname = parts[1].split("/") keyname = parts[1].split("/")
value = parts[2] if len(parts) > 2 else "" value = parts[2] if len(parts) > 2 else ""
addr = self.transactions[transaction_id]["addr"] addr = transactions[transaction_id]["addr"]
if keyname[0] == "shared" and keyname[1] == "last-login": if keyname[0] == "shared" and keyname[1] == "last-login":
if addr.startswith("echo@"): if addr.startswith("echo@"):
return return
@@ -22,7 +22,7 @@ class LastLoginDictProxy(DictProxy):
user.set_last_login_timestamp(timestamp) user.set_last_login_timestamp(timestamp)
else: else:
# Transaction failed. # Transaction failed.
self.transactions[transaction_id]["res"] = "F\n" transactions[transaction_id]["res"] = "F\n"
def main(): def main():

View File

@@ -62,12 +62,12 @@ class MetadataDictProxy(DictProxy):
logging.warning(f"lookup ignored: {parts!r}") logging.warning(f"lookup ignored: {parts!r}")
return "N\n" return "N\n"
def handle_set(self, transaction_id, parts): def handle_set(self, transaction_id, parts, transactions):
# For documentation on key structure see # For documentation on key structure see
# https://github.com/dovecot/core/blob/main/src/lib-storage/mailbox-attribute.h # https://github.com/dovecot/core/blob/main/src/lib-storage/mailbox-attribute.h
keyname = parts[1].split("/") keyname = parts[1].split("/")
value = parts[2] if len(parts) > 2 else "" value = parts[2] if len(parts) > 2 else ""
addr = self.transactions[transaction_id]["addr"] addr = transactions[transaction_id]["addr"]
if keyname[0] == "priv" and keyname[2] == self.metadata.DEVICETOKEN_KEY: if keyname[0] == "priv" and keyname[2] == self.metadata.DEVICETOKEN_KEY:
self.metadata.add_token_to_addr(addr, value) self.metadata.add_token_to_addr(addr, value)
elif keyname[0] == "priv" and keyname[2] == "messagenew": elif keyname[0] == "priv" and keyname[2] == "messagenew":
@@ -75,10 +75,10 @@ class MetadataDictProxy(DictProxy):
else: else:
# Transaction failed. # Transaction failed.
try: try:
self.transactions[transaction_id]["res"] = "F\n" transactions[transaction_id]["res"] = "F\n"
except KeyError: except KeyError:
logging.error( logging.error(
f"could not mark tx as failed: {transaction_id} {self.transactions}" f"could not mark tx as failed: {transaction_id} {transactions}"
) )

View File

@@ -72,12 +72,13 @@ def test_nocreate_file(monkeypatch, tmpdir, dictproxy):
def test_handle_dovecot_request(dictproxy): def test_handle_dovecot_request(dictproxy):
transactions = {}
# Test that password can contain ", ', \ and / # Test that password can contain ", ', \ and /
msg = ( msg = (
'Lshared/passdb/laksjdlaksjdlak\\\\sjdlk\\"12j\\\'3l1/k2j3123"' 'Lshared/passdb/laksjdlaksjdlak\\\\sjdlk\\"12j\\\'3l1/k2j3123"'
"some42123@chat.example.org\tsome42123@chat.example.org" "some42123@chat.example.org\tsome42123@chat.example.org"
) )
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, transactions)
assert res assert res
assert res[0] == "O" and res.endswith("\n") assert res[0] == "O" and res.endswith("\n")
userdata = json.loads(res[1:].strip()) userdata = json.loads(res[1:].strip())

View File

@@ -12,28 +12,30 @@ def test_handle_dovecot_request_last_login(testaddr, example_config):
authproxy = AuthDictProxy(config=example_config) authproxy = AuthDictProxy(config=example_config)
authproxy.lookup_passdb(testaddr, "1l2k3j1l2k3jl123") authproxy.lookup_passdb(testaddr, "1l2k3j1l2k3jl123")
dictproxy_transactions = {}
# Begin transaction # Begin transaction
tx = "1111" tx = "1111"
msg = f"B{tx}\t{testaddr}" msg = f"B{tx}\t{testaddr}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, dictproxy_transactions)
assert not res assert not res
assert dictproxy.transactions == {tx: dict(addr=testaddr, res="O\n")} assert dictproxy_transactions == {tx: dict(addr=testaddr, res="O\n")}
# set last-login info for user # set last-login info for user
user = dictproxy.config.get_user(testaddr) user = dictproxy.config.get_user(testaddr)
timestamp = int(time.time()) timestamp = int(time.time())
msg = f"S{tx}\tshared/last-login/{testaddr}\t{timestamp}" msg = f"S{tx}\tshared/last-login/{testaddr}\t{timestamp}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, dictproxy_transactions)
assert not res assert not res
assert len(dictproxy.transactions) == 1 assert len(dictproxy_transactions) == 1
read_timestamp = user.get_last_login_timestamp() read_timestamp = user.get_last_login_timestamp()
assert read_timestamp == timestamp // 86400 * 86400 assert read_timestamp == timestamp // 86400 * 86400
# finish transaction # finish transaction
msg = f"C{tx}" msg = f"C{tx}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, dictproxy_transactions)
assert res == "O\n" assert res == "O\n"
assert len(dictproxy.transactions) == 0 assert len(dictproxy_transactions) == 0
def test_handle_dovecot_request_last_login_echobot(example_config): def test_handle_dovecot_request_last_login_echobot(example_config):
@@ -44,17 +46,19 @@ def test_handle_dovecot_request_last_login_echobot(example_config):
authproxy.lookup_passdb(testaddr, "ignore") authproxy.lookup_passdb(testaddr, "ignore")
user = dictproxy.config.get_user(testaddr) user = dictproxy.config.get_user(testaddr)
transactions = {}
# set last-login info for user # set last-login info for user
tx = "1111" tx = "1111"
msg = f"B{tx}\t{testaddr}" msg = f"B{tx}\t{testaddr}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, transactions)
assert not res assert not res
assert dictproxy.transactions == {tx: dict(addr=testaddr, res="O\n")} assert transactions == {tx: dict(addr=testaddr, res="O\n")}
timestamp = int(time.time()) timestamp = int(time.time())
msg = f"S{tx}\tshared/last-login/{testaddr}\t{timestamp}" msg = f"S{tx}\tshared/last-login/{testaddr}\t{timestamp}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, transactions)
assert not res assert not res
assert len(dictproxy.transactions) == 1 assert len(transactions) == 1
read_timestamp = user.get_last_login_timestamp() read_timestamp = user.get_last_login_timestamp()
assert read_timestamp is None assert read_timestamp is None

View File

@@ -88,42 +88,45 @@ def test_notifier_remove_without_set(metadata, testaddr):
def test_handle_dovecot_request_lookup_fails(dictproxy, testaddr): def test_handle_dovecot_request_lookup_fails(dictproxy, testaddr):
res = dictproxy.handle_dovecot_request(f"Lpriv/123/chatmail\t{testaddr}") transactions = {}
res = dictproxy.handle_dovecot_request(
f"Lpriv/123/chatmail\t{testaddr}", transactions
)
assert res == "N\n" assert res == "N\n"
def test_handle_dovecot_request_happy_path(dictproxy, testaddr, token): def test_handle_dovecot_request_happy_path(dictproxy, testaddr, token):
metadata = dictproxy.metadata metadata = dictproxy.metadata
transactions = dictproxy.transactions transactions = {}
notifier = dictproxy.notifier notifier = dictproxy.notifier
# set device token in a transaction # set device token in a transaction
tx = "1111" tx = "1111"
msg = f"B{tx}\t{testaddr}" msg = f"B{tx}\t{testaddr}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, transactions)
assert not res and not metadata.get_tokens_for_addr(testaddr) assert not res and not metadata.get_tokens_for_addr(testaddr)
assert transactions == {tx: dict(addr=testaddr, res="O\n")} assert transactions == {tx: dict(addr=testaddr, res="O\n")}
msg = f"S{tx}\tpriv/guid00/devicetoken\t{token}" msg = f"S{tx}\tpriv/guid00/devicetoken\t{token}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, transactions)
assert not res assert not res
assert len(transactions) == 1 assert len(transactions) == 1
assert metadata.get_tokens_for_addr(testaddr) == [token] assert metadata.get_tokens_for_addr(testaddr) == [token]
msg = f"C{tx}" msg = f"C{tx}"
res = dictproxy.handle_dovecot_request(msg) res = dictproxy.handle_dovecot_request(msg, transactions)
assert res == "O\n" assert res == "O\n"
assert len(transactions) == 0 assert len(transactions) == 0
assert metadata.get_tokens_for_addr(testaddr) == [token] assert metadata.get_tokens_for_addr(testaddr) == [token]
# trigger notification for incoming message # trigger notification for incoming message
tx2 = "2222" tx2 = "2222"
assert dictproxy.handle_dovecot_request(f"B{tx2}\t{testaddr}") is None assert dictproxy.handle_dovecot_request(f"B{tx2}\t{testaddr}", transactions) is None
msg = f"S{tx2}\tpriv/guid00/messagenew" msg = f"S{tx2}\tpriv/guid00/messagenew"
assert dictproxy.handle_dovecot_request(msg) is None assert dictproxy.handle_dovecot_request(msg, transactions) is None
queue_item = notifier.retry_queues[0].get()[1] queue_item = notifier.retry_queues[0].get()[1]
assert queue_item.token == token assert queue_item.token == token
assert dictproxy.handle_dovecot_request(f"C{tx2}") == "O\n" assert dictproxy.handle_dovecot_request(f"C{tx2}", transactions) == "O\n"
assert not transactions assert not transactions
assert queue_item.path.exists() assert queue_item.path.exists()