From 2c7f989f29dbbbeb469b60a2b4c0f4357e1d8b14 Mon Sep 17 00:00:00 2001 From: AL-LCL Date: Fri, 19 May 2023 11:18:33 +0200 Subject: PY-CHAT --- protocol.py | 413 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 413 insertions(+) create mode 100644 protocol.py (limited to 'protocol.py') diff --git a/protocol.py b/protocol.py new file mode 100644 index 0000000..55b8f72 --- /dev/null +++ b/protocol.py @@ -0,0 +1,413 @@ +# Modified Source From: https://git.alvinhavel.com/PY-NET/plain/shared.py + +import ipaddress +import platform +import base64 +import socket +# import uuid +import zlib +import json +import ssl +import re + +from typing import Callable +from typing import Optional +from typing import Union +from typing import Tuple +from typing import Dict +from typing import Any + +DEFAULT_HOSTNAME = 'localhost' +DEFAULT_PORT = 38567 # 38568 +DEFAULT_ENCODING = 'utf-8' +DEFAULT_LANGUAGE_CODE = 65001 +STRICT_ENCODING_ERRORS = 'strict' +LIBERAL_ENCODING_ERRORS = 'replace' + +class Platform: + + PLATFORM = platform.system() + WINDOWS = PLATFORM == 'Windows' + LINUX = PLATFORM == 'Linux' + MAC = PLATFORM == 'Mac' + UNIX = LINUX or MAC + +class ExtendedEncoder(json.JSONEncoder): + + def default(self, obj: Any) -> Any: + if isinstance(obj, bytes): + return base64.b64encode(obj).decode(DEFAULT_ENCODING, + errors=STRICT_ENCODING_ERRORS) + + return json.JSONEncoder.default(self, obj) + +class ExtendedDecoder(json.JSONDecoder): + + def default(self, obj: Any) -> Any: + if isinstance(obj, bytes): + return base64.b64decode(obj.encode(DEFAULT_ENCODING, + errors=STRICT_ENCODING_ERRORS)) + + return json.JSONDecoder.default(self, obj) + +class Bytes: + + LABELS = ('B', 'kB', 'MB', 'GB') + LAST_LABEL = LABELS[-1] + UNIT_STEP = 1024 + UNIT_STEP_THRESH = UNIT_STEP - 0.005 + + @staticmethod + def format(num: Union[int, float]) -> str: + assert isinstance(num, (int, float)), f'Wrong type: {num=}' + assert num >= 0, f'Wrong value: {num=}' + + for unit in Bytes.LABELS: + if num < Bytes.UNIT_STEP_THRESH: + break + + if unit != Bytes.LAST_LABEL: + num /= Bytes.UNIT_STEP + + return f'{num:.2f} {unit}' + +class Socket: + + def __init__( + self, + hostname: str, + port: int, + *, + conn: Optional[Union[socket.socket, ssl.SSLSocket]]=None, + server_side: bool=False, + is_host: bool=False + ) -> None: + assert isinstance(hostname, str), f'Wrong type: {hostname=}' + assert isinstance(port, int), f'Wrong type: {port=}' + assert port >= 1024 and port <= 65535, f'Wrong value: {port=}' + assert isinstance(conn, (socket.socket, ssl.SSLSocket)) or conn is None, f'Wrong type: {conn=}' + assert isinstance(server_side, bool), f'Wrong type: {server_side=}' + assert isinstance(is_host, bool), f'Wrong type: {is_host=}' + + dns_resolved_ip = socket.gethostbyname(hostname) + self.ip = ipaddress.ip_address(dns_resolved_ip) + self.port = port + self.conn = conn + self.server_side = server_side + self.is_host = is_host + + if self.server_side: + self.hostname = hostname if hostname != dns_resolved_ip else None + self.data_wrap = None + self.data_wrap_notes = None + + if self.is_host: + self.in_session = None + self.ENCODING = None + else: + self.in_session = False + else: + assert not self.is_host, f'Wrong value: {self.is_host=}' + + if not self.is_host: + UUID_REGEX = r'^[\da-f]{8}-([\da-f]{4}-){3}[\da-f]{12}$' + self.UUID_PATTERN = re.compile(UUID_REGEX, re.IGNORECASE) + self.ENCODING = DEFAULT_ENCODING + self.DEFAULT_BUFFER_SIZE = self.BUFFER_SIZE = 1024 + self.MAX_BUFFER_SIZE = 65536 + self.SEND_HEADER_SIZE = 10 + self.RECV_HEADER_SIZE = 10 # 46 + + self.header_send_callback = None + self.header_recv_callback = None + self.send_callback = None + self.recv_callback = None + # self.uuid_token = None + + def __enter__(self) -> Union[socket.socket, ssl.SSLSocket]: + assert self.conn, f'Missing attribute: {self.conn=}' + return self.conn + + def __exit__(self, *_) -> None: + self.close() + + def __str__(self) -> str: + if self.server_side: + return f'tcp://{self.ip}:{self.port}' + else: + return super().__str__() + + def address(self) -> Tuple[str, ...]: + assert self.server_side, f'Wrong value: {self.server_side=}' + return (self.data_wrap, self.in_session, str(self), self.hostname, self.ENCODING, self.data_wrap_notes) + + def address_headers(self) -> Tuple[str, ...]: + assert self.server_side, f'Wrong value: {self.server_side=}' + return ('ID', 'Type', 'In Session', 'Address', 'Hostname', 'Encoding', 'Type Notes') + + def detailed_address(self) -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[str, str]]: + assert self.server_side, f'Wrong value: {self.server_side=}' + return ((('Type', self.data_wrap), + ('In Session', self.in_session), + ('Address', str(self)), + ('Hostname', self.hostname), + ('Encoding', self.ENCODING), + ('Type Notes', self.data_wrap_notes), + ('Address Type', f'{self.ip.max_prefixlen}-bit IPv{self.ip.version}'), + ('Reverse Pointer', self.ip.reverse_pointer), + ('Global Address', self.ip.is_global), + ('Link Local Address', self.ip.is_link_local), + ('Loopback Address', self.ip.is_loopback), + ('Multicast Address', self.ip.is_multicast), + ('Private Address', self.ip.is_private), + ('Reserved Address', self.ip.is_reserved), + ('Unspecified Address', self.ip.is_unspecified)), + ('Key', 'Value')) + + def close(self) -> None: + assert self.conn, f'Missing attribute: {self.conn=}' + + try: + self.conn.shutdown(socket.SHUT_RDWR) + except Exception: + pass + finally: + self.conn.close() + + def set_conn(self, **kwargs) -> socket.socket: + if self.server_side: + assert self.is_host, f'Wrong value: {self.is_host=}' + + self.conn = socket.create_server((str(self.ip), self.port), **kwargs) + else: + self.conn = socket.create_connection((str(self.ip), self.port), **kwargs) + + def set_context(self) -> None: + assert self.server_side, f'Wrong value: {self.server_side=}' + + if self.is_host: + self.data_wrap = 'HOSTING :: COMP' + else: + self.data_wrap = 'CONNECTING :: COMP' + + def set_middleware(self) -> None: + assert not self.is_host, f'Wrong value: {self.is_host=}' + + self.send_callback = lambda body: zlib.compress(body) + self.recv_callback = lambda body: zlib.decompress(body) + + def send(self, obj: Dict[str, Any], callback: Callable[[str, str], None]=None) -> None: + assert not self.is_host, f'Wrong value: {self.is_host=}' + assert self.conn, f'Missing attribute: {self.conn=}' + assert isinstance(obj, dict), f'Wrong type: {obj=}' + assert callable(callback) or callback is None, f'Wrong type: {callback=}' + + body = json.dumps(obj, cls=ExtendedEncoder) + body = body.encode(self.ENCODING, errors=LIBERAL_ENCODING_ERRORS) + + if self.send_callback: + body = self.send_callback(body) + + # if self.server_side: + # self.uuid_token = str(uuid.uuid4()) + + body_size = len(body) + header = str(body_size).ljust(self.SEND_HEADER_SIZE) + header = header.encode(DEFAULT_ENCODING, errors=STRICT_ENCODING_ERRORS) + # header = (header + self.uuid_token).encode(DEFAULT_ENCODING, + # errors=STRICT_ENCODING_ERRORS) + + if self.header_send_callback: + header = self.header_send_callback(header) + + if callback: + header_size = len(header) + callback(Bytes.format(header_size), + Bytes.format(body_size), + Bytes.format(header_size + body_size)) + + self.conn.sendall(header + body) + + def recv(self, callback: Callable[[str, str], None]=None) -> Dict[str, Any]: + assert not self.is_host, f'Wrong value: {self.is_host=}' + assert self.conn, f'Missing attribute: {self.conn=}' + assert callable(callback) or callback is None, f'Wrong type: {callback=}' + + position = 0 + + if callback: + history = [] + + if self.BUFFER_SIZE != self.DEFAULT_BUFFER_SIZE: + self.BUFFER_SIZE = self.DEFAULT_BUFFER_SIZE + + while True: + buffer = self.conn.recv(self.BUFFER_SIZE) + + assert buffer, f'Missing attribute: {buffer=}' + + buffer_size = len(buffer) + + if position > 0: + body[position:position + buffer_size] = buffer + position += buffer_size + else: + header = buffer[:self.RECV_HEADER_SIZE] + + if self.header_recv_callback: + header = self.header_recv_callback(header) + + assert isinstance(header, bytes), f'Wrong type: {header=}' + + header = header.decode(DEFAULT_ENCODING, errors=STRICT_ENCODING_ERRORS) + # uuid_token = header[self.SEND_HEADER_SIZE:self.RECV_HEADER_SIZE] + + # assert self.UUID_PATTERN.match(uuid_token), f'Wrong value: {uuid_token=}' + + # if self.server_side: + # assert uuid_token == self.uuid_token, f'Wrong value: {uuid_token} != {self.uuid_token}' + # else: + # self.uuid_token = uuid_token + + body_size = int(header[:self.SEND_HEADER_SIZE]) + + if body_size >= self.MAX_BUFFER_SIZE: + self.BUFFER_SIZE = self.MAX_BUFFER_SIZE + + if callback: + bytes_to_recieve = Bytes.format(body_size + self.RECV_HEADER_SIZE) + + body_buffer_size = buffer_size - self.RECV_HEADER_SIZE + body = bytearray(body_size) + body[:buffer_size] = buffer[self.RECV_HEADER_SIZE:] + position += body_buffer_size + + assert position <= body_size, f'Wrong value: {position} > {body_size}' + + if callback: + history.append((Bytes.format(buffer_size), + Bytes.format(position + self.RECV_HEADER_SIZE), + bytes_to_recieve)) + + if position == body_size: + body = bytes(body) + + if self.recv_callback: + body = self.recv_callback(body) + + assert isinstance(body, bytes), f'Wrong type: {body=}' + + body = body.decode(self.ENCODING, errors=LIBERAL_ENCODING_ERRORS) + body = json.loads(body, cls=ExtendedDecoder) + + assert isinstance(body, dict), f'Wrong type: {body=}' + + if callback: + callback(history) + + return body + +class SymmetricSocket(Socket): + + def __init__( + self, + *args, + password: str=None, + salt: str=None, + **kwargs + ) -> None: + super().__init__(*args, **kwargs) + + assert isinstance(password, str) or password is None, f'Wrong type: {password=}' + assert isinstance(salt, str) or salt is None, f'Wrong type: {salt=}' + + self.password = password + self.salt = salt + + def set_context(self) -> None: + assert self.server_side, f'Wrong value: {self.server_side=}' + + if self.is_host: + self.data_wrap = 'HOSTING :: AES' + self.data_wrap_notes = f'{self.password} :: {self.salt}' + else: + self.data_wrap = 'CONNECTING :: AES' + self.data_wrap_notes = 'AES_128_CBC_PKCS7_HMAC_SHA256' + + def set_middleware(self) -> None: + assert not self.is_host, f'Wrong value: {self.is_host=}' + assert self.password, f'Missing attribute: {self.password=}' + assert self.salt, f'Missing attribute: {self.salt=}' + + from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.fernet import Fernet + + secret = PBKDF2HMAC(algorithm=hashes.SHA256(), + length=32, + salt=self.salt.encode(DEFAULT_ENCODING, + errors=STRICT_ENCODING_ERRORS), + iterations=320000, + backend=default_backend()) + secret = secret.derive(self.password.encode(DEFAULT_ENCODING, + errors=STRICT_ENCODING_ERRORS)) + secret = Fernet(base64.urlsafe_b64encode(secret)) + + recv_header_size = ''.ljust(self.RECV_HEADER_SIZE) + recv_header_size = recv_header_size.encode(DEFAULT_ENCODING, + errors=STRICT_ENCODING_ERRORS) + self.RECV_HEADER_SIZE = len(secret.encrypt(recv_header_size)) + self.header_send_callback = lambda header: secret.encrypt(header) + self.header_recv_callback = lambda header: secret.decrypt(header) + self.send_callback = lambda body: secret.encrypt(zlib.compress(body)) + self.recv_callback = lambda body: zlib.decompress(secret.decrypt(body)) + +class AsymmetricSocket(Socket): + + def __init__( + self, + *args, + public_key: str=None, + private_key: str=None, + public_key_data: str=None, + **kwargs + ) -> None: + super().__init__(*args, **kwargs) + + assert isinstance(public_key, str) or public_key is None, f'Wrong type: {public_key=}' + assert isinstance(private_key, str) or private_key is None, f'Wrong type: {private_key=}' + assert isinstance(public_key_data, str) or public_key_data is None, f'Wrong type: {public_key_data=}' + + self.public_key = public_key + self.private_key = private_key + self.public_key_data = public_key_data + + def set_context(self) -> None: + assert self.server_side, f'Wrong value: {self.server_side=}' + + if self.is_host: + self.data_wrap = 'HOSTING :: TLS' + self.data_wrap_notes = f'{self.public_key} :: {self.private_key}' + else: + self.data_wrap = 'CONNECTING :: TLS' + cipher = self.conn.cipher() + + if cipher: + self.data_wrap_notes = ' :: '.join(cipher[:2]) + + def set_middleware(self) -> None: + if self.server_side: + assert self.is_host, f'Wrong value: {self.is_host=}' + assert self.public_key, f'Missing attribute: {self.public_key=}' + assert self.private_key, f'Missing attribute: {self.private_key=}' + + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + context.load_cert_chain(self.public_key, self.private_key) + self.conn = context.wrap_socket(self.conn, server_side=True) + else: + assert self.public_key_data, f'Missing attribute: {self.public_key_data=}' + + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + context.load_verify_locations(cadata=self.public_key_data) + self.conn = context.wrap_socket(self.conn) -- cgit v1.2.3