Source code for qblox_instruments.pnp.main

# ----------------------------------------------------------------------------
# Description    : Plug & play/device discovery logic
# Git repository : https://gitlab.com/qblox/packages/software/qblox_instruments.git
# Copyright (C) Qblox BV (2021)
# ----------------------------------------------------------------------------


# -- include -----------------------------------------------------------------

import sys
import inspect
import pprint
import socket
import select
import time
import json
import ifaddr
import ipaddress
import uuid
import os
from typing import Tuple, Any, Union, Iterable, Dict


# -- definitions -------------------------------------------------------------

# The UDP port used for plug & play communication. PlugAndPlay binds one
# socket per local IPv4 address to this port, and broadcasts its commands to
# this same port on 255.255.255.255.
PNP_PORT = 20801


# -- class -------------------------------------------------------------------


class PlugAndPlay:
    """
    Class that provides device discovery and IP address (re)configuration
    functionality, for instance to convert customer-controlled device names or
    serial numbers to IP addresses we can connect to via the usual interfaces.
    """

    # ------------------------------------------------------------------------
    def __init__(self):
        """
        Creates a plug & play interface object. Use close() when you're done
        with the object, or a ``with`` clause::

            with PlugAndPlay() as p:
                # do stuff with p here
                pass

        One UDP socket is created for every IPv4 address of every network
        adapter in the system, bound to that address and the plug & play
        port. Addresses for which the socket cannot be created or bound (for
        instance those of disabled adapters on Windows) are silently skipped.

        Parameters
        ----------

        Raises
        ------
        """
        super().__init__()

        # Iterate over all network adapters in the system, and create a
        # broadcast socket for all of them.
        self._socks = []
        self._ips = []
        for adapter in ifaddr.get_adapters():
            # Iterate over each IP address associated with the adapter.
            for ip in adapter.ips:
                # Ignore IPv6 addresses, since Qblox PNP is IPv4-only.
                if not ip.is_IPv4:
                    continue

                sock = None
                try:
                    # Create the UDP socket.
                    sock = socket.socket(
                        socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP
                    )

                    # Enable the broadcast flag, so we can send and receive
                    # UDP broadcast packets.
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

                    # Bind to the IP address of this interface and to the plug
                    # and play port.
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    sock.bind((ip.ip, PNP_PORT))
                except OSError:
                    # This will catch exception thrown by disabled network
                    # adapters (on Windows). Close the half-configured socket
                    # so its file descriptor is not leaked.
                    if sock is not None:
                        sock.close()
                    continue

                # Store results.
                self._socks.append(sock)
                self._ips.append(f"{ip.ip}/{ip.network_prefix}")

        # Generate a sufficiently random number to use as initial sequence
        # number.
        self._seq = uuid.uuid4().int

    # ------------------------------------------------------------------------
    def close(self):
        """
        Closes the underlying sockets. The object must not be used anymore
        after this call. Calling this more than once is harmless.

        Parameters
        ----------

        Returns
        -------
        """
        # getattr() guards against __init__ having failed before the lists
        # were created; __del__ still calls close() in that case.
        socks = getattr(self, "_socks", [])
        ips = getattr(self, "_ips", [])
        while socks:
            socks.pop().close()
            ips.pop()

    # ------------------------------------------------------------------------
    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    # ------------------------------------------------------------------------
    def _send(self, cmd: str, serial_or_name: str = "") -> str:
        """
        Broadcasts a command over the network to all available interfaces.

        Parameters
        ----------
        cmd: str
            The command to send (second line onwards, i.e. excluding header).
        serial_or_name: str
            If specified, only the device with the given serial number or
            customer-given name should respond to the command. If not
            specified or an empty string, all devices will respond.

        Returns
        -------
        str
            The exact sequence number string that was used for the command.

        Raises
        ------
        OSError
            If configuring a socket fails. Transmission errors on individual
            sockets are ignored (see the comment in the send loop).
        """
        # Create a unique sequence number.
        self._seq += 1
        seq = str(self._seq)

        # Format the message content and sequence number, and encode it once
        # rather than once per socket.
        msg = f"QBLOXPNPCMD:{seq}:{serial_or_name}\n{cmd}".encode("utf-8")

        # Broadcast the message on all sockets.
        for sock in self._socks:
            # Disable socket timeout, so we block if the OS is for some
            # reason not immediately ready to send the UDP packet.
            sock.settimeout(None)
            try:
                # Send to everything.
                sock.sendto(msg, ("255.255.255.255", PNP_PORT))
            except OSError:
                # This will catch exception thrown by sendmsg after bind
                # on localhost network adapter (on MacOS).
                # FIXME: understand why this isn't allowed on MacOS
                # https://gitlab.com/qblox/packages/software/qblox_instruments/-/issues/170
                pass

        # Return the sequence number that we used.
        return seq

    # ------------------------------------------------------------------------
    @staticmethod
    def _parse_response(packet: bytes) -> Union[None, Tuple[str, str, str]]:
        """
        Parses a raw UDP packet as a plug & play response.

        Parameters
        ----------
        packet: bytes
            The raw packet contents.

        Returns
        -------
        Union[None, Tuple[str, str, str]]
            ``(sequence_number, serial, payload)`` if the packet looks like a
            valid PNP response, or None if it does not (we listen for any
            broadcast traffic, so foreign packets are expected).
        """
        # PNP packets are unicode-safe, so we can decode into a regular
        # string for convenience.
        try:
            text = packet.decode("utf-8")
        except UnicodeError:
            return None

        # Split header (first line) from payload (second line onwards).
        parts = text.split("\n", maxsplit=1)
        if len(parts) < 2:
            return None
        header, payload = parts

        # Split header into its components.
        fields = header.split(":")
        if len(fields) != 3:
            return None
        magic, seq, serial = fields

        # The first part of the header must match the magic number for PNP
        # responses.
        if magic != "QBLOXPNPREP":
            return None
        return seq, serial, payload

    # ------------------------------------------------------------------------
    def _recv(
        self, seqs: Iterable[str], single: bool = False, timeout: float = 1.0
    ) -> Dict[str, Dict[str, Tuple[str, int]]]:
        """
        Waits for the reception of one or more responses to one or more
        commands.

        Parameters
        ----------
        seqs: Iterable[str]
            The sequence number(s) of the commands that were sent, as returned
            by _send().
        single: bool
            When set, only one response is expected for each command. This
            allows the function to terminate once a response has been received
            for all commands. If not set or unspecified, this method will
            always wait for the complete timeout duration, since it cannot
            know how many responses will arrive and when.
        timeout: float
            Timeout in seconds to wait for responses.

        Returns
        -------
        Dict[str, Dict[str, Tuple[str, int]]]
            All responses received. The outer dict maps from command sequence
            number to the set of responses received for that command. The
            inner dict maps from device serial number to a tuple of the
            response received from that device and the index of the socket it
            was received on.

        Raises
        ------
        OSError
            If packet reception fails.
        """
        # Turn the iterable of sequence numbers into a set. If single is set,
        # we'll remove sequence numbers from the set when we receive the first
        # response for that sequence number. We stop when we run out of
        # sequence numbers or the timeout expires. Note that packets received
        # with unknown sequence numbers are ignored, and that if single is
        # set, a sequence number is ignored after the first response is
        # processed.
        seqs = set(seqs)

        # Create an empty response structure to fill.
        responses = {seq: {} for seq in seqs}

        # Without any sockets nothing can ever be received. Return right away
        # rather than calling select() with three empty lists, which is not
        # supported on Windows.
        if not self._socks:
            return responses

        # Determine when the timeout expires. A monotonic clock is used so
        # that wall-clock adjustments cannot stretch or cut the timeout.
        stop = time.monotonic() + timeout

        # While we have sequence numbers that we're sensitive to...
        while seqs:
            # Determine how many seconds remain until the timeout expires.
            remain = stop - time.monotonic()

            # Break if the timeout has already expired.
            if remain < 0:
                break

            # Listen for packets until the timeout expires. Note that since
            # we're listening for any broadcast packet, we could be receiving
            # packets from pretty much any protocol here. That's why we impose
            # a size limit, and why we silently ignore packets that don't look
            # like valid responses. We could also receive PNP responses to
            # commands that another host sent, hence the UUID-based sequence
            # number check.
            ready, *_ = select.select(self._socks, [], [], remain)
            if not ready:
                break
            for sock in ready:
                sock.settimeout(None)
                packet, _address = sock.recvfrom(16384)
                index = self._socks.index(sock)

                parsed = self._parse_response(packet)
                if parsed is None:
                    continue
                recv_seq, recv_serial, recv_result = parsed

                # The sequence number must be in our set.
                if recv_seq not in seqs:
                    continue

                # Looks like our packet is valid, so record it.
                responses[recv_seq][recv_serial] = (recv_result, index)

                # If we're only expecting one packet per sequence number,
                # remove the sequence number from the sensitivity set.
                if single:
                    seqs.remove(recv_seq)

        # Timeout expired or we received everything we were expecting to
        # receive.
        return responses

    # ------------------------------------------------------------------------
    def _broadcast(
        self, cmd: str, count: int = 3, timeout: float = 1.0
    ) -> Dict[str, Tuple[str, int]]:
        """
        Combines a broadcast transmission with waiting for the corresponding
        responses.

        Parameters
        ----------
        cmd: str
            The command to send.
        count: int
            Number of "retries": the command is sent out this number of times
            to reduce the odds of packet loss hindering device discovery.
        timeout: float
            Timeout in seconds to wait for responses.

        Returns
        -------
        Dict[str, Tuple[str, int]]
            The response received and socket index it was received on for
            each serial number. Note that if multiple *different* responses
            are received for a single serial number, the resulting response in
            this set is arbitrary. It is thus important that serial numbers
            are actually unique in practice.

        Raises
        ------
        OSError
            If transmission or reception fails.
        """
        # Send the commands, gathering their sequence numbers.
        seqs = [self._send(cmd) for _ in range(count)]

        # Gather the responses.
        responses = self._recv(seqs, False, timeout)

        # Combine the responses from the various sequence numbers.
        combined = {}
        for response in responses.values():
            combined.update(response)
        return combined

    # ------------------------------------------------------------------------
    def _query(
        self, serial_or_name: str, cmd: str, retries: int = 3, timeout: float = 1.0
    ) -> Tuple[str, int]:
        """
        Combines a "unicast" transmission (to a single serial number or
        customer-given name) with waiting for the corresponding response.

        Parameters
        ----------
        serial_or_name: str
            The serial number or customer-given name of the device we want to
            address.
        cmd: str
            The command to send.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------
        Tuple[str, int]
            The payload of the received response, and the index of the socket
            it was received on.

        Raises
        ------
        OSError
            If transmission or reception fails.
        TimeoutError
            If there was no response from the device.
        """
        # For each retry...
        for _ in range(retries):
            # Send the command.
            seq = self._send(cmd, serial_or_name)

            # Wait for and retrieve a single response.
            responses = self._recv([seq], True, timeout)[seq]

            # Return the response for the serial number we're interested in,
            # if any.
            if responses:
                return next(iter(responses.values()))

        # Ran out of retries.
        raise TimeoutError(
            f"No response from device with serial or name {serial_or_name}"
        )

    # ------------------------------------------------------------------------
    def _ask_confirmation(self, operation: str) -> None:
        """
        Queries the user whether they want to continue.

        Parameters
        ----------
        operation: str
            Operation that is broadcast to all clusters. The user must type
            ``Broadcast <operation>`` to confirm.

        Returns
        -------
        None

        Raises
        ------
        KeyboardInterrupt
            If the user cancelled the operation.
        """
        print(
            f"This will broadcast '{operation}' to ALL clusters accessible from your PC, including **VIA WIFI OR VPN**."
        )
        print(
            "The affected clusters include, BUT MAY NOT BE LIMITED TO, the following:"
        )
        self.print_devices()
        print(f"Do you want to continue? Type 'Broadcast {operation}' to continue.")
        response = input()
        if response != f"Broadcast {operation}":
            raise KeyboardInterrupt(
                f"Aborting, input does not match 'Broadcast {operation}'."
            )

    # ------------------------------------------------------------------------
    def list_devices(self, timeout: float = 1.0) -> Dict[str, dict]:
        """
        Lists all observable devices on the network.

        Parameters
        ----------
        timeout: float
            Timeout in seconds to wait for responses, used both for the
            discovery broadcast and for the subsequent describe queries.

        Returns
        -------
        Dict[str, dict]
            Mapping from serial number to device description record as
            returned by the device, extended with a ``connected_via`` key
            holding the local address/prefix the response arrived on. If a
            device returned an invalid structure, its dict contains only the
            ``connected_via`` key.

        Raises
        ------
        OSError
            If transmission or reception fails.
        """
        # Broadcast an echo request to see which serial numbers are visible.
        serials = set(self._broadcast("ECHO", timeout=timeout))

        # Send a describe command to each visible device in parallel.
        seqs = [self._send("DESCRIBE", serial) for serial in serials]
        responses = self._recv(seqs, True, timeout)

        # Combine the responses from the various sequence numbers (we can
        # distinguish by means of serial number) and parse the JSON structure
        # for each device.
        devices = {}
        for response in responses.values():
            for serial, (payload, sock_index) in response.items():
                try:
                    description = json.loads(payload)
                except json.JSONDecodeError:
                    description = {}
                # Valid JSON that is not an object cannot carry the expected
                # record; treat it like an invalid response.
                if not isinstance(description, dict):
                    description = {}
                description["connected_via"] = self._ips[sock_index]
                devices[serial] = description
        return devices

    # ------------------------------------------------------------------------
    def print_devices(self, timeout: float = 1.0) -> None:
        """
        Like list_devices(), but prints a user-friendly device list instead
        of returning a data structure.

        Parameters
        ----------
        timeout: float
            Timeout in seconds to wait for responses.

        Returns
        -------

        Raises
        ------
        OSError
            If transmission or reception fails.
        """
        devices = self.list_devices(timeout)
        if not devices:
            print("No devices found")
            return

        def sort_key(item):
            # Sort by reported IP address; devices without one go last. A
            # plain sort would raise TypeError when mixing None and strings.
            ip = item[1].get("identity", {}).get("ip", None)
            return (ip is None, "" if ip is None else str(ip))

        print("Devices:")
        for serial, data in sorted(devices.items(), key=sort_key):
            remote_ip = data.get("identity", {}).get("ip", None)
            local_ip_and_prefix = data.get("connected_via", None)
            if remote_ip is None:
                ip_info = "<unknown IP>"
            elif local_ip_and_prefix is None:
                ip_info = remote_ip
            else:
                try:
                    local_net = ipaddress.IPv4Network(
                        local_ip_and_prefix, strict=False
                    )
                    remote_net = ipaddress.IPv4Network(
                        f"{remote_ip}/32", strict=False
                    )
                except ValueError:
                    # The device reported something other than a plain IPv4
                    # address, so reachability can't be judged; show it as-is.
                    ip_info = remote_ip
                else:
                    if local_net.overlaps(remote_net):
                        ip_info = remote_ip
                    else:
                        ip_info = f"{remote_ip} via {local_ip_and_prefix} (reconfiguration needed!)"

            description = data.get("description", {})
            model = description.get("model", "<unknown model>")
            version = ".".join(
                map(str, description.get("sw", {}).get("version", ["?"] * 3))
            )
            name = description.get("name", "<unknown name>")
            print(
                f' - {ip_info}: {model} {version} with name "{name}" '
                f"and serial number {serial}"
            )

    # ------------------------------------------------------------------------
    def identify(
        self, serial_or_name: str, retries: int = 3, timeout: float = 1.0
    ) -> None:
        """
        Visually identifies the device with the given serial number or
        customer-given name by having it blink its LEDs for a while.

        Parameters
        ----------
        serial_or_name: str
            Serial number or customer-given name of the device that is to be
            identified.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------

        Raises
        ------
        TypeError
            If serial_or_name is invalid.
        ValueError
            If serial_or_name is invalid.
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        """
        serial_or_name = str(serial_or_name)
        if not serial_or_name:
            raise ValueError("Serial number or name must not be blank")
        response, _ = self._query(serial_or_name, "IDENTIFY", retries, timeout)
        if response != "OK":
            raise RuntimeError(f"Unexpected response: {response}")

    # ------------------------------------------------------------------------
    def identify_all(self, count: int = 3) -> None:
        """
        Instructs all devices visible on the network to blink their LEDs.

        Parameters
        ----------
        count: int
            Number of times to repeat the command packet, to reduce the odds
            of packet loss being a problem.

        Returns
        -------

        Raises
        ------
        OSError
            If transmission fails.
        """
        for _ in range(count):
            self._send("IDENTIFY")

    # ------------------------------------------------------------------------
    def describe(
        self, serial_or_name: str, retries: int = 3, timeout: float = 1.0
    ) -> dict:
        """
        Returns the device description structure corresponding to the device
        with the given serial number or customer-given name.

        Parameters
        ----------
        serial_or_name: str
            Serial number or customer-given name of the device that is to be
            queried.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------
        dict
            The device description structure, extended with a
            ``connected_via`` key holding the local address/prefix the
            response arrived on.

        Raises
        ------
        TypeError
            If serial_or_name is invalid.
        ValueError
            If serial_or_name is invalid (e.g. blank, which would otherwise
            address every device on the network).
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        """
        serial_or_name = str(serial_or_name)
        if not serial_or_name:
            raise ValueError("Serial number or name must not be blank")
        response, sock_index = self._query(
            serial_or_name, "DESCRIBE", retries, timeout
        )
        try:
            description = json.loads(response)
        except json.JSONDecodeError as err:
            raise RuntimeError(f"Unexpected response: {response}") from err
        if not isinstance(description, dict):
            raise RuntimeError(f"Unexpected response: {response}")
        description["connected_via"] = self._ips[sock_index]
        return description

    # ------------------------------------------------------------------------
    def get_serial(self, name: str, retries: int = 3, timeout: float = 1.0) -> str:
        """
        Returns the serial number of the device with the given customer-given
        name.

        Parameters
        ----------
        name: str
            Customer-given name of the device that is to be queried.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------
        str
            The serial number of the device.

        Raises
        ------
        TypeError
            If name is invalid.
        ValueError
            If name is invalid.
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        KeyError
            If the device response did not contain the requested information.
        """
        return self.describe(name, retries, timeout)["description"]["ser"]

    # ------------------------------------------------------------------------
    def get_name(self, serial: str, retries: int = 3, timeout: float = 1.0) -> str:
        """
        Returns the customer-given name of the device with the given serial
        number.

        Parameters
        ----------
        serial: str
            Serial number of the device that is to be queried.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------
        str
            The customer-given name of the device.

        Raises
        ------
        TypeError
            If serial is invalid.
        ValueError
            If serial is invalid.
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        KeyError
            If the device response did not contain the requested information.
        """
        return self.describe(serial, retries, timeout)["description"]["name"]

    # ------------------------------------------------------------------------
    def set_name(
        self, serial_or_name: str, new_name: str, retries: int = 3, timeout: float = 1.0
    ) -> None:
        """
        Renames the device with the given serial number or name.

        Parameters
        ----------
        serial_or_name: str
            Serial number or customer-given name of the device that is to be
            reconfigured.
        new_name: str
            The new customer-given name for the device. May not contain
            newlines, double quotes, or backslashes.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------

        Raises
        ------
        TypeError
            If serial_or_name or new_name are invalid.
        ValueError
            If serial_or_name or new_name are invalid.
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        """
        serial_or_name = str(serial_or_name)
        if not serial_or_name:
            raise ValueError("Serial number or name must not be blank")
        new_name = str(new_name)
        if not new_name:
            raise ValueError("New name must not be blank")
        if "\n" in new_name:
            raise ValueError("Device name may not include newlines")
        if "\\" in new_name:
            raise ValueError("Device name may not include backslashes")
        if '"' in new_name:
            raise ValueError("Device name may not include double quotes")
        response, _ = self._query(
            serial_or_name, f"SET_NAME {new_name}", retries, timeout
        )
        if response != "OK":
            raise RuntimeError(f"Unexpected response: {response}")

    # ------------------------------------------------------------------------
    def get_ip(
        self, serial_or_name: str, retries: int = 3, timeout: float = 1.0
    ) -> str:
        """
        Returns the IP address of the device with the given serial number or
        customer-given name.

        Parameters
        ----------
        serial_or_name: str
            Serial number or customer-given name of the device that is to be
            queried.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------
        str
            The IP address of the device.

        Raises
        ------
        TypeError
            If serial_or_name is invalid.
        ValueError
            If serial_or_name is invalid.
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        KeyError
            If the device response did not contain the requested information.
        """
        return self.describe(serial_or_name, retries, timeout)["identity"]["ip"]

    # ------------------------------------------------------------------------
    def set_ip(
        self,
        serial_or_name: str,
        ip_address: str,
        retries: int = 3,
        timeout: float = 1.0,
    ) -> None:
        """
        Adjusts the IP address configuration of the device with the given
        serial number or customer-given name. The device will reboot as a
        result of this.

        Parameters
        ----------
        serial_or_name: str
            Serial number or customer-given name of the device that is to be
            reconfigured.
        ip_address: str
            The new IP address configuration for the device. This may be an
            IPv4 address including prefix length (`x.x.x.x/x`), an IPv6
            address including prefix length (e.g. `x:x::x:x/x`), a
            combination thereof separated via a semicolon, or the string
            `dhcp` to have the device obtain an IPv4 address via DHCP. May
            not be blank or contain newlines.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------

        Raises
        ------
        TypeError
            If serial_or_name or ip_address is invalid.
        ValueError
            If serial_or_name or ip_address is invalid.
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        """
        serial_or_name = str(serial_or_name)
        if not serial_or_name:
            raise ValueError("Serial number or name must not be blank")
        ip_address = str(ip_address)
        if not ip_address:
            raise ValueError("IP address configuration must not be blank")
        # A newline would smuggle extra lines into the command payload.
        if "\n" in ip_address:
            raise ValueError("IP address configuration may not include newlines")
        response, _ = self._query(
            serial_or_name, f"SET_IP {ip_address}", retries, timeout
        )
        if response != "OK":
            raise RuntimeError(f"Unexpected response: {response}")

    # ------------------------------------------------------------------------
    def set_all_dhcp(self, count: int = 3) -> None:
        """
        Instructs all devices on the network to reboot and obtain an IP
        address via DHCP. The user is asked for confirmation first.

        Parameters
        ----------
        count: int
            Number of times to repeat the command packet, to reduce the odds
            of packet loss being a problem.

        Returns
        -------

        Raises
        ------
        KeyboardInterrupt
            If the user does not confirm the operation.
        OSError
            If transmission or reception fails.
        """
        self._ask_confirmation("set-dhcp")
        for _ in range(count):
            self._send("SET_IP dhcp")

    # ------------------------------------------------------------------------
    def reboot(
        self, serial_or_name: str, retries: int = 3, timeout: float = 1.0
    ) -> None:
        """
        Reboots the device with the given serial number or customer-given
        name.

        Parameters
        ----------
        serial_or_name: str
            Serial number or customer-given name of the device that is to be
            rebooted.
        retries: int
            Number of times to retry sending the command, if no response is
            received.
        timeout: float
            Timeout in seconds to wait for a response, per retry.

        Returns
        -------

        Raises
        ------
        TypeError
            If serial_or_name is invalid.
        ValueError
            If serial_or_name is invalid.
        OSError
            If transmission or reception fails.
        TimeoutError
            If no response is received from the device.
        RuntimeError
            If an unexpected response is received from the device.
        """
        serial_or_name = str(serial_or_name)
        if not serial_or_name:
            raise ValueError("Serial number or name must not be blank")
        response, _ = self._query(serial_or_name, "REBOOT", retries, timeout)
        if response != "OK":
            raise RuntimeError(f"Unexpected response: {response}")

    # ------------------------------------------------------------------------
    def reboot_all(self, count: int = 3) -> None:
        """
        Instructs all devices on the network to reboot. The user is asked for
        confirmation first.

        Parameters
        ----------
        count: int
            Number of times to repeat the command packet, to reduce the odds
            of packet loss being a problem.

        Returns
        -------

        Raises
        ------
        KeyboardInterrupt
            If the user does not confirm the operation.
        OSError
            If transmission or reception fails.
        """
        self._ask_confirmation("reboot")
        for _ in range(count):
            self._send("REBOOT")

    # ------------------------------------------------------------------------
    def recover_device(self) -> None:
        """
        Attempts to recover a device with a severely broken IP configuration,
        by instructing ALL devices on the network to revert back to
        192.168.0.2/24. ONLY RUN THIS COMMAND WHEN YOU ARE ONLY CONNECTED TO A
        SINGLE DEVICE, OR YOU WILL GET IP ADDRESS CONFLICTS. The user is asked
        for confirmation first.

        Parameters
        ----------

        Returns
        -------

        Raises
        ------
        KeyboardInterrupt
            If the user does not confirm the operation.
        OSError
            If recovery packet transmission fails.
        """
        self._ask_confirmation("reset-ip")
        # Send the packet repeatedly, spaced out, to survive packet loss.
        for _ in range(10):
            self._send("SET_IP 192.168.0.2/24")
            time.sleep(0.1)

    # ------------------------------------------------------------------------
    @staticmethod
    def cmd_line(*args: Iterable[str]) -> Any:
        """
        Runs the plug & play command-line tool with the given arguments.

        The arguments may be passed as separate strings, e.g.
        ``PlugAndPlay.cmd_line("get", "ip", "my-device")``, or as a single
        iterable of strings, e.g.
        ``PlugAndPlay.cmd_line(["get", "ip", "my-device"])``. When no
        arguments are given at all, the tool reads ``sys.argv`` as if it were
        run from the command line.

        Parameters
        ----------
        *args: Iterable[str]
            The command-line arguments.

        Returns
        -------
        Any
            If the given command logically returns something, it will be
            returned as a Python value in addition to being printed as a
            string. Otherwise, None will be returned.

        Raises
        ------
        RuntimeError
            If the command-line tool returns a nonzero exit status.
        """
        # Normalize to the single-iterable form _main() expects. No arguments
        # keeps the historical behavior of reading sys.argv.
        if not args:
            argv = None
        elif len(args) == 1 and not isinstance(args[0], str):
            argv = list(args[0])
        else:
            argv = list(args)

        # This version is intended to be called from within a script or
        # notebook, so catch sys.exit() calls, and enable tracebacks.
        try:
            return _main(argv)
        except SystemExit as e:
            # sys.exit() without an argument (code None) means success.
            code = 0 if e.code is None else e.code
            if code != 0:
                raise RuntimeError(f"exit with status {code}")
            return None
# -- command-line tool -------------------------------------------------------


def _main(args: Union[None, Iterable[str]] = None) -> Any:
    """
    Runs the Qblox plug & play tool.

    Parameters
    ----------
    args: Union[None, Iterable[str]]
        When None, this will run the plug & play tool as if called from the
        command line. Arguments are taken from ``sys.argv``, and
        ``sys.exit()`` is called when complete. When this is an iterable of
        strings, these strings are interpreted as ``sys.argv[1:]``, and
        exceptions will never be caught.

    Returns
    -------
    Any
        If run from a script and the command logically returns something, it
        will be returned as a Python value in addition to being printed as a
        string. If run from the command line, this always calls sys.exit().
    """
    with PlugAndPlay() as p:
        # Definitions.
        VERSION = "0.0.1"
        HELP_TEXT = f"Qblox plug & play version {VERSION}\n\nThis program allows you to scan your LAN for instruments from Qblox.\nA list of available commands follows. Run `{os.path.basename(sys.argv[0])} <command> help` for\nmore information about a command.\n"
        HELP_LIKE = ("help", "-h", "--help", "/?")

        # Define commands available on the command line. The code after this
        # is pretty much just boilerplate to look up the requested
        # (sub)command in this structure and to convert the remaining
        # arguments according to the type annotations of its function.
        cmds = {
            "list": p.print_devices,
            "describe": p.describe,
            "identify": p.identify,
            "reboot": p.reboot,
            "get": {
                "name": p.get_name,
                "ip": p.get_ip,
                "serial": p.get_serial,
                "json": p.describe,
            },
            "set": {
                "name": p.set_name,
                "ip": p.set_ip,
            },
            "all": {
                "describe": p.list_devices,
                "identify": p.identify_all,
                "reboot": p.reboot_all,
                "dhcp": p.set_all_dhcp,
            },
            "recover-device": p.recover_device,
        }

        def print_cmds_and_exit(arg_stack, cmds):
            """
            Prints the list of (sub)commands in the given cmds structure,
            assumed to be (sub)commands of the commands listed in arg_stack,
            then exits with status 2.
            """
            if arg_stack:
                print(
                    f"Available subcommands of {' '.join(arg_stack)}:",
                    file=sys.stderr,
                )
            else:
                print("Available commands:", file=sys.stderr)

            def recurse(cmds, prefix=" -"):
                for name, action in cmds.items():
                    if isinstance(action, dict):
                        recurse(action, f"{prefix} {name}")
                    else:
                        print(f"{prefix} {name}", file=sys.stderr)

            recurse(cmds)
            sys.exit(2)

        def print_help_and_exit(arg_stack, fn):
            """
            Prints the call signature and docstring of the given function,
            going by the command-line command name given by arg_stack, then
            exits with status 2.
            """
            # Copy, so the caller's arg_stack is not mutated.
            params = list(arg_stack)
            for param in inspect.signature(fn).parameters.values():
                if param.default is inspect.Parameter.empty:
                    params.append(f"<{param.name}>")
                else:
                    params.append(f"[{param.name}={param.default}]")
            print(
                f"Syntax: {sys.argv[0]} {' '.join(params)}\n\n{inspect.cleandoc(fn.__doc__ or '')}",
                file=sys.stderr,
            )
            sys.exit(2)

        # Normalize the arguments. A list is required because args[0] is
        # indexed below, which arbitrary iterables don't support.
        if args is None:
            args = sys.argv[1:]
            from_script = False
        else:
            args = list(args)
            from_script = True

        # Print help if no arguments are given or a help-like argument is
        # given.
        if not args or args[0] in HELP_LIKE:
            print(HELP_TEXT, file=sys.stderr)
            print_cmds_and_exit([], cmds)

        # Determine which command is being run.
        arg_stack = []
        while isinstance(cmds, dict):
            if not args:
                print_cmds_and_exit(arg_stack, cmds)
            arg, *args = args
            subcmds = cmds.get(arg, None)
            if subcmds is None:
                print(f"Error: unknown command '{arg}'\n", file=sys.stderr)
                print_cmds_and_exit(arg_stack, cmds)
            cmds = subcmds
            arg_stack.append(arg)
        fn = cmds

        # Print help for a specific command if requested.
        if args and args[0] in HELP_LIKE:
            print_help_and_exit(arg_stack, fn)

        # Convert the remaining command-line arguments to the types expected
        # by the command function.
        fn_args = []
        signature = inspect.signature(fn)
        for param in signature.parameters.values():
            if not args:
                if param.default is inspect.Parameter.empty:
                    print_help_and_exit(arg_stack, fn)
                break
            arg, *args = args
            try:
                if param.annotation is str:
                    fn_args.append(arg)
                elif param.annotation is int:
                    fn_args.append(int(arg))
                elif param.annotation is float:
                    fn_args.append(float(arg))
                else:
                    raise TypeError(f"unsupported annotation {param.annotation}")
            except (ValueError, TypeError):
                print(
                    f"Error: invalid value for '{param.name}'\n",
                    file=sys.stderr,
                )
                print_help_and_exit(arg_stack, fn)
        if args:
            print("Error: too many arguments\n", file=sys.stderr)
            print_help_and_exit(arg_stack, fn)

        # Actually run the command. KeyboardInterrupt is included because the
        # broadcast commands raise it when the user declines the confirmation
        # prompt; on the command line that should be a clean error message.
        try:
            result = fn(*fn_args)
        except (Exception, KeyboardInterrupt) as e:
            # Unsuccessful run.
            if from_script:
                raise
            print(f"{type(e).__name__}: {e}", file=sys.stderr)
            sys.exit(1)

        # If the command is supposed to return something, pretty-print it.
        if signature.return_annotation is not None:
            pprint.pprint(result)

        # Successful run.
        if from_script:
            return result
        sys.exit(0)


# - main ---------------------------------------------------------------------

if __name__ == "__main__":
    _main()