secure-chat/server/main.py
2018-12-02 19:12:37 +01:00

448 lines
15 KiB
Python

#!/usr/bin/env python3
# coding: utf8
import json
import logging
import logging.config
import os
import socket
import threading
try:
# noinspection PyUnresolvedReferences
from Crypto.PublicKey import RSA as RSA
# noinspection PyUnresolvedReferences
from Crypto.Cipher import PKCS1_OAEP as PKCS1_OAEP
from Crypto.Cipher import AES as AES
# noinspection PyUnresolvedReferences,PyProtectedMember
from Crypto.Random._UserFriendlyRNG import get_random_bytes as get_random_bytes
pycryptodome = False
except ModuleNotFoundError: # Pycryptodomex
from Cryptodome.PublicKey import RSA as RSA
from Cryptodome.Cipher import PKCS1_OAEP as PKCS1_OAEP
from Cryptodome.Cipher import AES as AES
from Cryptodome.Random import get_random_bytes as get_random_bytes
pycryptodome = True
def setup_logging(default_path='log_config.json', default_level=logging.INFO, env_key='LOG_CFG'):
    """Configure logging from a JSON ``dictConfig`` file, or fall back to ``basicConfig``.

    When the environment variable named by ``env_key`` is set to a non-empty
    value, that path is used instead of ``default_path``.

    :param default_path: JSON configuration file used when the variable is unset or empty
    :param default_level: level handed to ``logging.basicConfig`` when the file does not exist
    :param env_key: name of the environment variable that may override the path
    :return: Nothing
    :rtype: NoneType"""
    config_path = os.getenv(env_key, None) or default_path
    if not os.path.exists(config_path):
        logging.basicConfig(level=default_level)
        return
    with open(config_path, 'rt') as config_file:
        loaded = json.load(config_file)
    logging.config.dictConfig(loaded)
setup_logging()
log_server = logging.getLogger('server')
# Module-wide shorthands bound to the "server" logger's methods.
debug, info, warning, error, critical = (
    log_server.debug,
    log_server.info,
    log_server.warning,
    log_server.error,
    log_server.critical,
)
#### Variables ####
# Address the server binds to ('' means every interface) and its port.
HOST = ''
PORT = 8888
# Size in bytes of every frame written by send(); also the base RSA key size in bits.
BUFFER_SIZE = 4096
# Plaintext slice length for RSA-OAEP encryption (512 bytes).
CHUNK_SIZE = BUFFER_SIZE // 8
# Marker frames that open and close a framed message, padded with ';' to one frame.
BEGIN_MESSAGE = "debut".ljust(BUFFER_SIZE, ";").encode("ascii")
END_MESSAGE = "fin".ljust(BUFFER_SIZE, ";").encode("ascii")
# Protocol identifier expected on the first header line.
VERSION = b"EICP2P2 V1"
# Request types accepted by ClientThread.gen_header.
REQUEST_TYPE = [
    b'ping', b'pingACK',
    b'updateAsk', b'updateBack',
    b'transfer',
    b'register_client', b'registerACK',
    b'send', b'sendACK',
    b'exit',
    b'RSASend', b'init',
    b'getUsers', b'getUsersACK',
]
# Result codes.
DONE, ERROR = 0, 1
# Peer kinds: not yet identified, another node, an end-user client.
T_NONE = 0b00
T_NODE = 0b01
T_CLIENT = 0b10
#### Socket ####
# Listening TCP socket, created and bound at import time; listen()/accept()
# are done by the __main__ section.
main_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if os.name == "posix":
    # Let a restarted server rebind PORT while connections from the previous
    # process are still in TIME_WAIT; otherwise bind() fails with EADDRINUSE.
    # Not enabled on Windows, where SO_REUSEADDR lets another process share the port.
    main_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
main_socket.bind((HOST, PORT))
#### Threads ####
class RsaGenThread(threading.Thread):
    """Load this node's RSA key pair from disk, or generate and store a new one.

    ``run`` may be called directly (it returns the key) or the thread may be
    started and joined; in both cases the key is also kept in ``self.rsa_key``
    because ``Thread.start`` discards ``run``'s return value.
    """

    def __init__(self, difficulty=2, private_path="private.pem", public_path="public.pem"):
        """Create the key loader.

        :param difficulty: a newly generated key is BUFFER_SIZE + 256 * difficulty bits long
        :param private_path: file holding the exported private key
        :param public_path: file the exported public key is written to on generation
        """
        threading.Thread.__init__(self)
        self.difficulty = difficulty
        self.private_path = private_path
        self.public_path = public_path
        self.rsa_key = None  # filled in by run()

    def run(self):
        """Load the private key file, regenerating it if missing, corrupted or public-only.

        :return: the RSA key object (also stored in ``self.rsa_key``)"""
        rsa = None
        if os.path.isfile(self.private_path):
            try:
                with open(self.private_path, "rb") as keyfile:
                    rsa = RSA.importKey(keyfile.read())
                if not rsa.has_private():
                    warning("Le fichier clef ne contient pas de clef privée")
                    raise ValueError
            except (IndexError, ValueError):
                warning("Fichier clef corrompu")
                debug("Suppression du fichier clef corromu")
                os.remove(self.private_path)
        # Deliberately not an ``else``: the block above may just have deleted a corrupted file.
        if not os.path.isfile(self.private_path):
            rsa = RSA.generate(BUFFER_SIZE + 256 * self.difficulty)
            # Create the private key file readable/writable by its owner only.
            fd = os.open(self.private_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "wb") as keyfile:
                keyfile.write(rsa.exportKey())
            with open(self.public_path, "wb") as keyfile:
                keyfile.write(rsa.publickey().exportKey())
        self.rsa_key = rsa
        return rsa
class ServerThread(threading.Thread):
    """Node-wide state shared by every ClientThread: this node's RSA key,
    the table of registered clients and the table of known nodes."""

    def __init__(self, socket, ip, port, difficulty=2):
        """Create the server object and load (or generate) this node's RSA key.

        :param socket: listening socket (stored only)
        :param ip: address the server is bound to
        :param port: port the server is bound to
        :param difficulty: key-size step used if a new RSA key must be generated
        """
        threading.Thread.__init__(self)  # initialise the Thread base class
        self.socket = socket
        self.ip = ip
        self.port = port
        self.difficulty = difficulty
        # Run the key loader synchronously: the key is needed before any client is served.
        # It reuses private.pem, or regenerates it when missing or corrupted.
        self.rsa_key = RsaGenThread(difficulty).run()
        debug("RSA key loaded")
        # client id (client's RSA public key bytes) -> (node key, ClientThread)
        self.clients = {}
        # node id (exported public key) -> node handle; this node's own entry maps to None
        self.nodes = {
            self.rsa_key.publickey().exportKey(): None,
        }

    def register_client(self, rsa_client, client_thread):
        """Record a client connected directly to this node.

        :param rsa_client: client's RSA public key bytes, used as the client id
        :param client_thread: ClientThread serving that client
        :return: (this node's exported public key, rsa_client)"""
        self.clients[rsa_client] = (self.rsa_key, client_thread)
        return self.rsa_key.publickey().exportKey(), rsa_client

    def send_to(self, id_dest, to_send):
        """Deliver ``to_send`` to the client registered under ``id_dest``.

        Entries whose node key is this node's key object (as stored by
        register_client) or None are delivered through the client's
        ``send_to_me``.  Entries recorded under another node key are forwarded
        through that node's ``transfer``.

        :return: the delivery call's result, or an error message as bytes"""
        entry = self.clients.get(id_dest)
        if entry is None:
            return b"Erreur client inconnu"
        node_key, client_thread = entry
        if node_key is None or node_key is self.rsa_key:
            return client_thread.send_to_me(to_send)
        node = self.nodes.get(node_key)
        if node is None:
            return b"Erreur noeud inconnu"
        return node.transfer(id_dest, to_send)
class ClientThread(threading.Thread):
    """Thread serving one connected peer (an end-user client or another node).

    Framing used by ``send``/``receive``: a message is one BEGIN_MESSAGE frame,
    zero or more data frames and one END_MESSAGE frame, every frame being
    exactly BUFFER_SIZE bytes.  A data frame is a 2-byte big-endian payload
    length followed by the payload, zero-padded to the frame size.
    """

    def __init__(self, clientsocket, ip_client, port, server):
        """Create ClientThread object

        :param clientsocket: Client's socket
        :param ip_client: Client's ip address
        :param port: Client's connection PORT
        :param server: Server thread
        :type clientsocket: socket.socket
        :type ip_client: str
        :type port: int
        :type server: ServerThread
        :return: Nothing
        :rtype: NoneType"""
        debug("Creation du thread pour %s" % ip_client)
        threading.Thread.__init__(self)  # initialise the Thread base class
        self.client = clientsocket
        self.ip = ip_client
        self.port = port
        self.running = True  # cleared by an ``exit`` request to leave the main loop
        self.status = None  # None before the handshake, then DONE or ERROR
        self.rsa_client = None  # peer's RSA public key bytes, received with RSASend
        self.aes_key = get_random_bytes(32)  # 32-byte AES session key for this connection
        self.type = T_NONE  # set to T_CLIENT or T_NODE during the handshake
        self.server = server
        debug("Creation du thread pour %s reussie" % ip_client)

    def _reject(self, reply):
        """Mark the handshake as failed and send ``reply`` unencrypted."""
        self.status = ERROR
        self.send(reply)

    def initialize(self):
        """Perform the RSASend handshake with the peer.

        Reads one unencrypted message whose header must carry ``type: RSASend``
        and a ``from`` field, followed by the peer's RSA public key.  On success
        the peer kind is recorded, the AES session key is sent back
        RSA-encrypted in an ``init`` message and ``self.status`` becomes DONE.
        On failure an error reply is sent and ``self.status`` becomes ERROR.

        :rtype: NoneType
        :return: Nothing"""
        message = self.receive()
        try:
            header = self.extract_header(message)
        except ValueError as exc:
            debug("Invalid header during initialization: %s", exc)
            self._reject(b"Error")
            return
        content = message[BUFFER_SIZE:]
        if header.get(b"type", None) is None:
            debug("The type field is not in the header")
            self._reject(b"Error")
            return
        if header.get(b"from", None) is None:
            debug("The from field is not in the header")
            self._reject(b"error")
            return
        if self.status is None and header[b"type"] != b"RSASend":
            debug("Requête différente de RSASend avec une connection non initialisée")
            self._reject(b"Error")
            return
        if header[b"type"] == b"RSASend":
            self.type = T_CLIENT if header[b"from"] == b"client" else T_NODE
            debug("Réception de la clef RSA de %s", self.ip)
            self.rsa_client = content
            try:
                # The AES session key travels encrypted with the peer's RSA key.
                self.send_rsa(self.gen_header(b"init") + self.aes_key)
            except (IndexError, ValueError) as exc:
                # Unusable RSA key (cannot be imported, or too small for CHUNK_SIZE blocks).
                debug("Cannot encrypt for %s: %s", self.ip, exc)
                self._reject(b"Error")
                return
            self.status = DONE

    @staticmethod
    def extract_header(data):
        """Extract header from data

        Only the first BUFFER_SIZE bytes are parsed.  The first line must equal
        VERSION.  Every following line must be ``key: value``, and the ';'
        padding is stripped from values.

        :param data: Data to extract header
        :type data: bytes
        :raise ValueError: wrong version line, or a header line without ": "
        :return: Dictionary with header datas
        :rtype: dict{bytes: bytes}"""
        if len(data) > BUFFER_SIZE:
            # Anything past the first frame is the message content, not header.
            debug("Header too long")
            data = data[:BUFFER_SIZE]
        data_lines = data.split(b'\n')
        if data_lines[0] != VERSION:
            raise ValueError("Version is incorrect.")
        header = {}
        for line in data_lines[1:]:
            key, separator, value = line.partition(b": ")
            if not separator:
                raise ValueError("Malformed header line: %r" % line[:40])
            header[key] = value.rstrip(b";")
        return header

    @staticmethod
    def gen_header(type_, to_=None, from_=None):
        """Generate header

        :param type_: Request type
        :param to_: `to` field in header, cf ../RFC8497.md
        :param from_: `from` field in header, cf ../RFC8497.md
        :type type_: bytes
        :type to_: bytes
        :type from_: bytes
        :raise ValueError: `type_` is not a valid request type
        :return: header, padded with ';' to BUFFER_SIZE bytes
        :rtype: bytes"""
        if type_ not in REQUEST_TYPE:
            raise ValueError("Unknown request type")
        header = VERSION + b"\ntype: " + type_
        if to_:
            header += b"\nto: " + to_
        if from_:
            header += b"\nfrom: " + from_
        return header.ljust(BUFFER_SIZE, b';')

    # ------------------------------ COMMUNICATION WITH AES ------------------------------
    # NOTE(security): AES-ECB with zero padding reveals repeated 16-byte blocks and gives no
    # integrity protection; presumably kept for wire compatibility with the peer. Zero padding
    # also means payloads that end in b"\x00" lose those bytes in receive_aes (rstrip).

    def send_aes(self, to_send, key=None):
        """Send message with AES-ECB encryption, zero-padding each 32-byte slice.

        :param to_send: Message to send
        :type to_send: bytes
        :param key: key to replace self.aes_key
        :type key: bytes
        :rtype: NoneType
        :return: Nothing"""
        debug("Send with AES encryption: %d bytes to %s", len(to_send), self.ip)
        if key is None:
            key = self.aes_key
        if key is None:
            info("AES key not generated, connection failure.")
            self.client.sendall(b"Error")
            return None
        aes_object = AES.new(key, AES.MODE_ECB)
        encrypted = b"".join(
            aes_object.encrypt(to_send[start:start + 32].ljust(32, b"\x00"))
            for start in range(0, len(to_send), 32)
        )
        self.send(encrypted)
        return None

    def receive_aes(self, key=None):
        """Receive message with AES-ECB encryption and strip trailing zero padding.

        :param key: key to replace self.aes_key
        :type key: bytes
        :raise ValueError: the ciphertext length is not a multiple of the AES block size
        :return: decrypted message, or None when no key is available"""
        to_decrypt = self.receive()
        if key is None:
            key = self.aes_key
        if key is None:
            info("AES key not generated, connection failure.")
            self.client.sendall(b"Error")
            return None
        aes_object = AES.new(key, AES.MODE_ECB)
        decrypted = b"".join(
            aes_object.decrypt(to_decrypt[start:start + 32])
            for start in range(0, len(to_decrypt), 32)
        )
        return decrypted.rstrip(b"\x00")

    # ------------------------------ COMMUNICATION WITH RSA ------------------------------

    def send_rsa(self, to_send, key=None):
        """Send message with RSA-OAEP encryption, CHUNK_SIZE plaintext bytes at a time.

        :param to_send: Message to send
        :type to_send: bytes
        :param key: public key bytes (as accepted by RSA.importKey) to use instead of self.rsa_client
        :type key: bytes
        :raise ValueError: the key cannot be imported or is too small for CHUNK_SIZE-byte slices
        :rtype: NoneType
        :return: Nothing"""
        # Only the size is logged: the payload may contain the AES session key.
        debug("Send with RSA encryption: %d bytes to %s", len(to_send), self.ip)
        if key is None:
            key = self.rsa_client
        if key is None:
            info("RSA key not received, connection failure.")
            self.client.sendall(b"Error")
            return None
        cipher_rsa = PKCS1_OAEP.new(RSA.importKey(key))
        # Encrypt everything before sending so a failure never leaves a half-sent message.
        encrypted = b"".join(
            cipher_rsa.encrypt(to_send[start:start + CHUNK_SIZE])
            for start in range(0, len(to_send), CHUNK_SIZE)
        )
        self.send(encrypted)
        return None

    # --------------------------- COMMUNICATION WITHOUT ENCRYPTION ---------------------------

    def _recv_frame(self):
        """Read exactly one BUFFER_SIZE-byte frame (recv() may return less per call).

        :raise ConnectionError: the peer closed the connection"""
        frame = bytearray()
        while len(frame) < BUFFER_SIZE:
            part = self.client.recv(BUFFER_SIZE - len(frame))
            if not part:  # b"" means the peer closed the connection
                raise ConnectionError("Connection closed by %s" % self.ip)
            frame += part
        return bytes(frame)

    def receive(self):
        """Receive message from connection

        Frames before BEGIN_MESSAGE are discarded; data frames are collected
        until END_MESSAGE.

        :raise ConnectionError: the peer closed the connection
        :rtype: bytes
        :return: Message's content"""
        while self._recv_frame() != BEGIN_MESSAGE:
            pass
        parts = []
        while True:
            chunk = self._recv_frame()
            if chunk == END_MESSAGE:
                break
            # The first 2 bytes give how many of the following bytes are payload.
            size = int.from_bytes(chunk[:2], byteorder='big')
            parts.append(chunk[2:2 + size])
        content = b"".join(parts)
        debug("Received %d bytes from %s", len(content), self.ip)
        return content

    def send(self, to_send):
        """Send message to connection

        :param to_send: message to send
        :type to_send: bytes
        :return: Nothing
        :rtype: NoneType"""
        debug("Send %d bytes to %s", len(to_send), self.ip)
        payload_size = BUFFER_SIZE - 2
        self.client.sendall(BEGIN_MESSAGE)
        for start in range(0, len(to_send), payload_size):
            piece = to_send[start:start + payload_size]
            self.client.sendall(
                len(piece).to_bytes(2, byteorder='big')  # payload size of this frame
                + piece.ljust(payload_size, b"\x00")  # payload, zero-padded
            )
        self.client.sendall(END_MESSAGE)
        return None

    def send_users(self):
        """Reply with the ids of all registered clients, separated by b"%!!%"."""
        self.send_aes(self.gen_header(type_=b"getUsersACK") + b"%!!%".join(list(self.server.clients)))

    def register_client(self):
        """Register client with the server and acknowledge with node id and client id

        :rtype: NoneType
        :return: Nothing"""
        id_noeud, id_client = self.server.register_client(self.rsa_client, self)
        self.send_aes(self.gen_header(type_=b"registerACK") + id_noeud + b"{%=&%&=%}" + id_client)

    def send_to_me(self, to_send):
        """Deliver a message from another peer to this thread's client

        :param to_send: Message to send to client
        :type to_send: bytes
        :rtype: NoneType
        :return: Nothing"""
        # NOTE(review): the `to` field holds the raw key bytes; if they contain newlines
        # (e.g. PEM), extract_header on the receiving side will not parse this header.
        self.send(self.gen_header(type_=b"send", to_=self.rsa_client) + to_send)

    def send_to_other(self, id_dest, to_send):
        """Send message to other client and acknowledge with the server's response

        :param id_dest: id of receiver
        :param to_send: Message to send
        :return: Nothing
        :rtype: NoneType"""
        server_response = self.server.send_to(id_dest, to_send)
        # Local delivery returns None; send an empty body in that case.
        self.send(self.gen_header(type_=b"sendACK") + (server_response or b""))

    def _handle_request(self, data):
        """Dispatch one decrypted request according to its header ``type``.

        :raise ValueError: malformed header
        :raise KeyError: a field required by the request is missing"""
        header = self.extract_header(data)
        debug("Request from %s: %r", self.ip, header)
        request_type = header.get(b"type")
        if request_type == b"register_client":
            self.register_client()
        elif request_type == b"getUsers":
            self.send_users()
        elif request_type == b"send":
            self.send_to_other(header[b"to"], data[BUFFER_SIZE:])
        elif request_type == b"exit":
            self.running = False
        else:
            debug("Unsupported request type %r from %s", request_type, self.ip)

    def _unregister(self):
        """Remove this thread's entry from the server's client table, if it is still ours."""
        if self.rsa_client is None:
            return
        entry = self.server.clients.get(self.rsa_client)
        if entry is not None and entry[1] is self:
            del self.server.clients[self.rsa_client]

    def run(self):  # main loop of the client connection
        """Run thread mainloop: handshake, then serve AES-encrypted requests until
        the peer sends ``exit`` or disconnects.  The socket is always closed.

        :return: Nothing
        :rtype: NoneType"""
        info("%s connected, initialize connection...", self.ip)
        try:
            self.initialize()
            if self.status != DONE:
                info("%s connection initialization failed.", self.ip)
                return
            info("%s connection initialized.", self.ip)
            while self.running:
                try:
                    data = self.receive_aes()
                    if data is None:
                        break
                    self._handle_request(data)
                except (ValueError, KeyError) as exc:
                    # Malformed request: bad ciphertext length, bad header or missing field.
                    debug("Invalid request from %s: %s", self.ip, exc)
                    self.send(b"Error")
        except OSError as exc:  # includes the ConnectionError raised when the peer hangs up
            info("%s disconnected: %s", self.ip, exc)
        finally:
            self._unregister()
            self.client.close()
if __name__ == "__main__":
    clients = []
    server = ServerThread(main_socket, ip=HOST, port=PORT)
    # Start listening once, before the accept loop.
    main_socket.listen(1)
    try:
        while True:
            # Local names: rebinding PORT here would clobber the module constant.
            client_socket, (client_ip, client_port) = main_socket.accept()
            new_client = ClientThread(client_socket, client_ip, client_port, server)
            new_client.start()
            # Forget finished threads so the list does not grow without bound.
            clients = [thread for thread in clients if thread.is_alive()]
            clients.append(new_client)
    finally:
        main_socket.close()