summaryrefslogtreecommitdiff
path: root/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'protocol.py')
-rw-r--r--protocol.py413
1 files changed, 413 insertions, 0 deletions
diff --git a/protocol.py b/protocol.py
new file mode 100644
index 0000000..55b8f72
--- /dev/null
+++ b/protocol.py
@@ -0,0 +1,413 @@
+# Modified Source From: https://git.alvinhavel.com/PY-NET/plain/shared.py
+
+import ipaddress
+import platform
+import base64
+import socket
+# import uuid
+import zlib
+import json
+import ssl
+import re
+
+from typing import Callable
+from typing import Optional
+from typing import Union
+from typing import Tuple
+from typing import Dict
+from typing import Any
+
+DEFAULT_HOSTNAME = 'localhost'
+DEFAULT_PORT = 38567 # 38568
+DEFAULT_ENCODING = 'utf-8'
+DEFAULT_LANGUAGE_CODE = 65001
+STRICT_ENCODING_ERRORS = 'strict'
+LIBERAL_ENCODING_ERRORS = 'replace'
+
+class Platform:
+
+ PLATFORM = platform.system()
+ WINDOWS = PLATFORM == 'Windows'
+ LINUX = PLATFORM == 'Linux'
+ MAC = PLATFORM == 'Mac'
+ UNIX = LINUX or MAC
+
+class ExtendedEncoder(json.JSONEncoder):
+
+ def default(self, obj: Any) -> Any:
+ if isinstance(obj, bytes):
+ return base64.b64encode(obj).decode(DEFAULT_ENCODING,
+ errors=STRICT_ENCODING_ERRORS)
+
+ return json.JSONEncoder.default(self, obj)
+
+class ExtendedDecoder(json.JSONDecoder):
+
+ def default(self, obj: Any) -> Any:
+ if isinstance(obj, bytes):
+ return base64.b64decode(obj.encode(DEFAULT_ENCODING,
+ errors=STRICT_ENCODING_ERRORS))
+
+ return json.JSONDecoder.default(self, obj)
+
+class Bytes:
+
+ LABELS = ('B', 'kB', 'MB', 'GB')
+ LAST_LABEL = LABELS[-1]
+ UNIT_STEP = 1024
+ UNIT_STEP_THRESH = UNIT_STEP - 0.005
+
+ @staticmethod
+ def format(num: Union[int, float]) -> str:
+ assert isinstance(num, (int, float)), f'Wrong type: {num=}'
+ assert num >= 0, f'Wrong value: {num=}'
+
+ for unit in Bytes.LABELS:
+ if num < Bytes.UNIT_STEP_THRESH:
+ break
+
+ if unit != Bytes.LAST_LABEL:
+ num /= Bytes.UNIT_STEP
+
+ return f'{num:.2f} {unit}'
+
+class Socket:
+
+ def __init__(
+ self,
+ hostname: str,
+ port: int,
+ *,
+ conn: Optional[Union[socket.socket, ssl.SSLSocket]]=None,
+ server_side: bool=False,
+ is_host: bool=False
+ ) -> None:
+ assert isinstance(hostname, str), f'Wrong type: {hostname=}'
+ assert isinstance(port, int), f'Wrong type: {port=}'
+ assert port >= 1024 and port <= 65535, f'Wrong value: {port=}'
+ assert isinstance(conn, (socket.socket, ssl.SSLSocket)) or conn is None, f'Wrong type: {conn=}'
+ assert isinstance(server_side, bool), f'Wrong type: {server_side=}'
+ assert isinstance(is_host, bool), f'Wrong type: {is_host=}'
+
+ dns_resolved_ip = socket.gethostbyname(hostname)
+ self.ip = ipaddress.ip_address(dns_resolved_ip)
+ self.port = port
+ self.conn = conn
+ self.server_side = server_side
+ self.is_host = is_host
+
+ if self.server_side:
+ self.hostname = hostname if hostname != dns_resolved_ip else None
+ self.data_wrap = None
+ self.data_wrap_notes = None
+
+ if self.is_host:
+ self.in_session = None
+ self.ENCODING = None
+ else:
+ self.in_session = False
+ else:
+ assert not self.is_host, f'Wrong value: {self.is_host=}'
+
+ if not self.is_host:
+ UUID_REGEX = r'^[\da-f]{8}-([\da-f]{4}-){3}[\da-f]{12}$'
+ self.UUID_PATTERN = re.compile(UUID_REGEX, re.IGNORECASE)
+ self.ENCODING = DEFAULT_ENCODING
+ self.DEFAULT_BUFFER_SIZE = self.BUFFER_SIZE = 1024
+ self.MAX_BUFFER_SIZE = 65536
+ self.SEND_HEADER_SIZE = 10
+ self.RECV_HEADER_SIZE = 10 # 46
+
+ self.header_send_callback = None
+ self.header_recv_callback = None
+ self.send_callback = None
+ self.recv_callback = None
+ # self.uuid_token = None
+
+ def __enter__(self) -> Union[socket.socket, ssl.SSLSocket]:
+ assert self.conn, f'Missing attribute: {self.conn=}'
+ return self.conn
+
+ def __exit__(self, *_) -> None:
+ self.close()
+
+ def __str__(self) -> str:
+ if self.server_side:
+ return f'tcp://{self.ip}:{self.port}'
+ else:
+ return super().__str__()
+
+ def address(self) -> Tuple[str, ...]:
+ assert self.server_side, f'Wrong value: {self.server_side=}'
+ return (self.data_wrap, self.in_session, str(self), self.hostname, self.ENCODING, self.data_wrap_notes)
+
+ def address_headers(self) -> Tuple[str, ...]:
+ assert self.server_side, f'Wrong value: {self.server_side=}'
+ return ('ID', 'Type', 'In Session', 'Address', 'Hostname', 'Encoding', 'Type Notes')
+
+ def detailed_address(self) -> Tuple[Tuple[Tuple[str, Any], ...], Tuple[str, str]]:
+ assert self.server_side, f'Wrong value: {self.server_side=}'
+ return ((('Type', self.data_wrap),
+ ('In Session', self.in_session),
+ ('Address', str(self)),
+ ('Hostname', self.hostname),
+ ('Encoding', self.ENCODING),
+ ('Type Notes', self.data_wrap_notes),
+ ('Address Type', f'{self.ip.max_prefixlen}-bit IPv{self.ip.version}'),
+ ('Reverse Pointer', self.ip.reverse_pointer),
+ ('Global Address', self.ip.is_global),
+ ('Link Local Address', self.ip.is_link_local),
+ ('Loopback Address', self.ip.is_loopback),
+ ('Multicast Address', self.ip.is_multicast),
+ ('Private Address', self.ip.is_private),
+ ('Reserved Address', self.ip.is_reserved),
+ ('Unspecified Address', self.ip.is_unspecified)),
+ ('Key', 'Value'))
+
+ def close(self) -> None:
+ assert self.conn, f'Missing attribute: {self.conn=}'
+
+ try:
+ self.conn.shutdown(socket.SHUT_RDWR)
+ except Exception:
+ pass
+ finally:
+ self.conn.close()
+
+ def set_conn(self, **kwargs) -> socket.socket:
+ if self.server_side:
+ assert self.is_host, f'Wrong value: {self.is_host=}'
+
+ self.conn = socket.create_server((str(self.ip), self.port), **kwargs)
+ else:
+ self.conn = socket.create_connection((str(self.ip), self.port), **kwargs)
+
+ def set_context(self) -> None:
+ assert self.server_side, f'Wrong value: {self.server_side=}'
+
+ if self.is_host:
+ self.data_wrap = 'HOSTING :: COMP'
+ else:
+ self.data_wrap = 'CONNECTING :: COMP'
+
+ def set_middleware(self) -> None:
+ assert not self.is_host, f'Wrong value: {self.is_host=}'
+
+ self.send_callback = lambda body: zlib.compress(body)
+ self.recv_callback = lambda body: zlib.decompress(body)
+
+ def send(self, obj: Dict[str, Any], callback: Callable[[str, str], None]=None) -> None:
+ assert not self.is_host, f'Wrong value: {self.is_host=}'
+ assert self.conn, f'Missing attribute: {self.conn=}'
+ assert isinstance(obj, dict), f'Wrong type: {obj=}'
+ assert callable(callback) or callback is None, f'Wrong type: {callback=}'
+
+ body = json.dumps(obj, cls=ExtendedEncoder)
+ body = body.encode(self.ENCODING, errors=LIBERAL_ENCODING_ERRORS)
+
+ if self.send_callback:
+ body = self.send_callback(body)
+
+ # if self.server_side:
+ # self.uuid_token = str(uuid.uuid4())
+
+ body_size = len(body)
+ header = str(body_size).ljust(self.SEND_HEADER_SIZE)
+ header = header.encode(DEFAULT_ENCODING, errors=STRICT_ENCODING_ERRORS)
+ # header = (header + self.uuid_token).encode(DEFAULT_ENCODING,
+ # errors=STRICT_ENCODING_ERRORS)
+
+ if self.header_send_callback:
+ header = self.header_send_callback(header)
+
+ if callback:
+ header_size = len(header)
+ callback(Bytes.format(header_size),
+ Bytes.format(body_size),
+ Bytes.format(header_size + body_size))
+
+ self.conn.sendall(header + body)
+
+ def recv(self, callback: Callable[[str, str], None]=None) -> Dict[str, Any]:
+ assert not self.is_host, f'Wrong value: {self.is_host=}'
+ assert self.conn, f'Missing attribute: {self.conn=}'
+ assert callable(callback) or callback is None, f'Wrong type: {callback=}'
+
+ position = 0
+
+ if callback:
+ history = []
+
+ if self.BUFFER_SIZE != self.DEFAULT_BUFFER_SIZE:
+ self.BUFFER_SIZE = self.DEFAULT_BUFFER_SIZE
+
+ while True:
+ buffer = self.conn.recv(self.BUFFER_SIZE)
+
+ assert buffer, f'Missing attribute: {buffer=}'
+
+ buffer_size = len(buffer)
+
+ if position > 0:
+ body[position:position + buffer_size] = buffer
+ position += buffer_size
+ else:
+ header = buffer[:self.RECV_HEADER_SIZE]
+
+ if self.header_recv_callback:
+ header = self.header_recv_callback(header)
+
+ assert isinstance(header, bytes), f'Wrong type: {header=}'
+
+ header = header.decode(DEFAULT_ENCODING, errors=STRICT_ENCODING_ERRORS)
+ # uuid_token = header[self.SEND_HEADER_SIZE:self.RECV_HEADER_SIZE]
+
+ # assert self.UUID_PATTERN.match(uuid_token), f'Wrong value: {uuid_token=}'
+
+ # if self.server_side:
+ # assert uuid_token == self.uuid_token, f'Wrong value: {uuid_token} != {self.uuid_token}'
+ # else:
+ # self.uuid_token = uuid_token
+
+ body_size = int(header[:self.SEND_HEADER_SIZE])
+
+ if body_size >= self.MAX_BUFFER_SIZE:
+ self.BUFFER_SIZE = self.MAX_BUFFER_SIZE
+
+ if callback:
+ bytes_to_recieve = Bytes.format(body_size + self.RECV_HEADER_SIZE)
+
+ body_buffer_size = buffer_size - self.RECV_HEADER_SIZE
+ body = bytearray(body_size)
+ body[:buffer_size] = buffer[self.RECV_HEADER_SIZE:]
+ position += body_buffer_size
+
+ assert position <= body_size, f'Wrong value: {position} > {body_size}'
+
+ if callback:
+ history.append((Bytes.format(buffer_size),
+ Bytes.format(position + self.RECV_HEADER_SIZE),
+ bytes_to_recieve))
+
+ if position == body_size:
+ body = bytes(body)
+
+ if self.recv_callback:
+ body = self.recv_callback(body)
+
+ assert isinstance(body, bytes), f'Wrong type: {body=}'
+
+ body = body.decode(self.ENCODING, errors=LIBERAL_ENCODING_ERRORS)
+ body = json.loads(body, cls=ExtendedDecoder)
+
+ assert isinstance(body, dict), f'Wrong type: {body=}'
+
+ if callback:
+ callback(history)
+
+ return body
+
+class SymmetricSocket(Socket):
+
+ def __init__(
+ self,
+ *args,
+ password: str=None,
+ salt: str=None,
+ **kwargs
+ ) -> None:
+ super().__init__(*args, **kwargs)
+
+ assert isinstance(password, str) or password is None, f'Wrong type: {password=}'
+ assert isinstance(salt, str) or salt is None, f'Wrong type: {salt=}'
+
+ self.password = password
+ self.salt = salt
+
+ def set_context(self) -> None:
+ assert self.server_side, f'Wrong value: {self.server_side=}'
+
+ if self.is_host:
+ self.data_wrap = 'HOSTING :: AES'
+ self.data_wrap_notes = f'{self.password} :: {self.salt}'
+ else:
+ self.data_wrap = 'CONNECTING :: AES'
+ self.data_wrap_notes = 'AES_128_CBC_PKCS7_HMAC_SHA256'
+
+ def set_middleware(self) -> None:
+ assert not self.is_host, f'Wrong value: {self.is_host=}'
+ assert self.password, f'Missing attribute: {self.password=}'
+ assert self.salt, f'Missing attribute: {self.salt=}'
+
+ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.fernet import Fernet
+
+ secret = PBKDF2HMAC(algorithm=hashes.SHA256(),
+ length=32,
+ salt=self.salt.encode(DEFAULT_ENCODING,
+ errors=STRICT_ENCODING_ERRORS),
+ iterations=320000,
+ backend=default_backend())
+ secret = secret.derive(self.password.encode(DEFAULT_ENCODING,
+ errors=STRICT_ENCODING_ERRORS))
+ secret = Fernet(base64.urlsafe_b64encode(secret))
+
+ recv_header_size = ''.ljust(self.RECV_HEADER_SIZE)
+ recv_header_size = recv_header_size.encode(DEFAULT_ENCODING,
+ errors=STRICT_ENCODING_ERRORS)
+ self.RECV_HEADER_SIZE = len(secret.encrypt(recv_header_size))
+ self.header_send_callback = lambda header: secret.encrypt(header)
+ self.header_recv_callback = lambda header: secret.decrypt(header)
+ self.send_callback = lambda body: secret.encrypt(zlib.compress(body))
+ self.recv_callback = lambda body: zlib.decompress(secret.decrypt(body))
+
+class AsymmetricSocket(Socket):
+
+ def __init__(
+ self,
+ *args,
+ public_key: str=None,
+ private_key: str=None,
+ public_key_data: str=None,
+ **kwargs
+ ) -> None:
+ super().__init__(*args, **kwargs)
+
+ assert isinstance(public_key, str) or public_key is None, f'Wrong type: {public_key=}'
+ assert isinstance(private_key, str) or private_key is None, f'Wrong type: {private_key=}'
+ assert isinstance(public_key_data, str) or public_key_data is None, f'Wrong type: {public_key_data=}'
+
+ self.public_key = public_key
+ self.private_key = private_key
+ self.public_key_data = public_key_data
+
+ def set_context(self) -> None:
+ assert self.server_side, f'Wrong value: {self.server_side=}'
+
+ if self.is_host:
+ self.data_wrap = 'HOSTING :: TLS'
+ self.data_wrap_notes = f'{self.public_key} :: {self.private_key}'
+ else:
+ self.data_wrap = 'CONNECTING :: TLS'
+ cipher = self.conn.cipher()
+
+ if cipher:
+ self.data_wrap_notes = ' :: '.join(cipher[:2])
+
+ def set_middleware(self) -> None:
+ if self.server_side:
+ assert self.is_host, f'Wrong value: {self.is_host=}'
+ assert self.public_key, f'Missing attribute: {self.public_key=}'
+ assert self.private_key, f'Missing attribute: {self.private_key=}'
+
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+ context.load_cert_chain(self.public_key, self.private_key)
+ self.conn = context.wrap_socket(self.conn, server_side=True)
+ else:
+ assert self.public_key_data, f'Missing attribute: {self.public_key_data=}'
+
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+ context.load_verify_locations(cadata=self.public_key_data)
+ self.conn = context.wrap_socket(self.conn)