diff --git a/cmdeploy/src/cmdeploy/cmdeploy.py b/cmdeploy/src/cmdeploy/cmdeploy.py index 59187734..057974ed 100644 --- a/cmdeploy/src/cmdeploy/cmdeploy.py +++ b/cmdeploy/src/cmdeploy/cmdeploy.py @@ -19,7 +19,7 @@ from packaging import version from termcolor import colored from . import dns, remote -from .sshexec import SSHExec +from .sshexec import SSHExec, Local # # cmdeploy sub commands and options @@ -62,13 +62,18 @@ def run_cmd_options(parser): "--ssh-host", dest="ssh_host", help="specify an SSH host to deploy to; uses mail_domain from chatmail.ini by default", + default=None, ) def run_cmd(args, out): """Deploy chatmail services on the remote server.""" - sshexec = args.get_sshexec() + ssh_host = args.ssh_host if args.ssh_host else args.config.mail_domain + if ssh_host == "localhost": + sshexec = Local(ssh_host) + else: + sshexec = args.get_sshexec(ssh_host) require_iroh = args.config.enable_iroh_relay remote_data = dns.get_initial_remote_data(sshexec, args.config.mail_domain) if not dns.check_initial_remote_data(remote_data, print=out.red): @@ -80,7 +85,7 @@ def run_cmd(args, out): env["CHATMAIL_REQUIRE_IROH"] = "True" if require_iroh else "" deploy_path = importlib.resources.files(__package__).joinpath("deploy.py").resolve() pyinf = "pyinfra --dry" if args.dry_run else "pyinfra" - ssh_host = args.config.mail_domain if not args.ssh_host else args.ssh_host + ssh_host = "@local" if ssh_host == "localhost" else f"--ssh-host {ssh_host}" cmd = f"{pyinf} --ssh-user root {ssh_host} {deploy_path} -y" if version.parse(pyinfra.__version__) < version.parse("3"): out.red("Please re-run scripts/initenv.sh to update pyinfra to version 3.") @@ -330,9 +335,9 @@ def main(args=None): if not hasattr(args, "func"): return parser.parse_args(["-h"]) - def get_sshexec(): - print(f"[ssh] login to {args.config.mail_domain}") - return SSHExec(args.config.mail_domain, verbose=args.verbose) + def get_sshexec(host): + print(f"[ssh] login to {host}") + return SSHExec(host, verbose=args.verbose) args.get_sshexec = get_sshexec diff --git a/cmdeploy/src/cmdeploy/sshexec.py b/cmdeploy/src/cmdeploy/sshexec.py index 8a87e781..b5364028 100644 --- a/cmdeploy/src/cmdeploy/sshexec.py +++ b/cmdeploy/src/cmdeploy/sshexec.py @@ -1,5 +1,6 @@ import inspect import os +import subprocess import sys from queue import Queue @@ -44,30 +45,16 @@ def print_stderr(item="", end="\n"): print(item, file=sys.stderr, end=end) -class SSHExec: - RemoteError = execnet.RemoteError +class Exec: FuncError = FuncError - def __init__(self, host, verbose=False, python="python3", timeout=60): - self.gateway = execnet.makegateway(f"ssh=root@{host}//python={python}") - self._remote_cmdloop_channel = bootstrap_remote(self.gateway, remote) + def __init__(self, host, verbose, timeout): + self.host = host self.timeout = timeout self.verbose = verbose def __call__(self, call, kwargs=None, log_callback=None): - if kwargs is None: - kwargs = {} - assert call.__module__.startswith("cmdeploy.remote") - modname = call.__module__.replace("cmdeploy.", "") - self._remote_cmdloop_channel.send((modname, call.__name__, kwargs)) - while 1: - code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout) - if log_callback is not None and code == "log": - log_callback(data) - elif code == "finish": - return data - elif code == "error": - raise self.FuncError(data) + return subprocess.check_output(call) def logged(self, call, kwargs): def log_progress(data): @@ -85,3 +72,33 @@ class SSHExec: res = self(call, kwargs, log_callback=log_progress) print_stderr() return res + + +class Local(Exec): + + def __init__(self, host, verbose=False, timeout=60): + super().__init__(host, verbose, timeout) + + +class SSHExec(Exec): + RemoteError = execnet.RemoteError + + def __init__(self, host, verbose=False, timeout=60): + super().__init__(host, verbose, timeout) + self.gateway = execnet.makegateway(f"ssh=root@{host}//python=python3") + self._remote_cmdloop_channel = bootstrap_remote(self.gateway, remote) + + def __call__(self, call, kwargs=None, log_callback=None): + if kwargs is None: + kwargs = {} + assert call.__module__.startswith("cmdeploy.remote") + modname = call.__module__.replace("cmdeploy.", "") + self._remote_cmdloop_channel.send((modname, call.__name__, kwargs)) + while 1: + code, data = self._remote_cmdloop_channel.receive(timeout=self.timeout) + if log_callback is not None and code == "log": + log_callback(data) + elif code == "finish": + return data + elif code == "error": + raise self.FuncError(data)