diff --git a/thingsboard_gateway/gateway/tb_client.py b/thingsboard_gateway/gateway/tb_client.py index 5364fab8..1866c33c 100644 --- a/thingsboard_gateway/gateway/tb_client.py +++ b/thingsboard_gateway/gateway/tb_client.py @@ -14,16 +14,22 @@ import logging import threading +import random +import string from time import sleep, time +from os.path import exists from ssl import CERT_REQUIRED, PROTOCOL_TLSv1_2 + +from simplejson import dumps, load + from thingsboard_gateway.tb_utility.tb_utility import TBUtility try: - from tb_gateway_mqtt import TBGatewayMqttClient + from tb_gateway_mqtt import TBGatewayMqttClient, TBDeviceMqttClient except ImportError: print("tb-mqtt-client library not found - installing...") TBUtility.install_package('tb-mqtt-client') - from tb_gateway_mqtt import TBGatewayMqttClient + from tb_gateway_mqtt import TBGatewayMqttClient, TBDeviceMqttClient log = logging.getLogger("tb_connection") @@ -38,9 +44,7 @@ class TBClient(threading.Thread): self.__host = config["host"] self.__port = config.get("port", 1883) self.__default_quality_of_service = config.get("qos", 1) - credentials = config["security"] self.__min_reconnect_delay = 1 - self.__tls = bool(credentials.get('tls', False) or credentials.get('caCert', False)) self.__ca_cert = None self.__private_key = None self.__cert = None @@ -51,6 +55,72 @@ class TBClient(threading.Thread): self.__stopped = False self.__paused = False self._last_cert_check_time = 0 + + # check if provided creds or provisioning strategy + if config.get('security'): + self._create_mqtt_client(config['security']) + elif config.get('provisioning'): + if exists(self.__config_folder_path + 'credentials.json'): + with open(self.__config_folder_path + 'credentials.json', 'r') as file: + credentials = load(file) + creds = self._get_provisioned_creds(credentials) + else: + credentials = config['provisioning'] + log.info('Starting provisioning gateway...') + + credentials_type = credentials.pop('type', 'ACCESS_TOKEN') + if credentials_type.upper() == 'ACCESS_TOKEN': + credentials['access_token'] = ''.join(random.choice(string.ascii_lowercase) for _ in range(15)) + elif credentials_type.upper() == 'MQTT_BASIC': + credentials['client_id'] = ''.join(random.choice(string.ascii_lowercase) for _ in range(15)) + credentials['username'] = ''.join(random.choice(string.ascii_lowercase) for _ in range(15)) + credentials['password'] = ''.join(random.choice(string.ascii_lowercase) for _ in range(15)) + elif credentials_type.upper() == 'X509_CERTIFICATE': + self._ca_cert_name = credentials.pop('caCert') + new_cert_path = self.__config_folder_path + 'cert.pem' + new_private_key_path = self.__config_folder_path + 'key.pem' + gen_hash = TBUtility.generate_certificate(new_cert_path, new_private_key_path).decode('utf-8') + credentials['hash'] = gen_hash + else: + raise RuntimeError( + 'Unknown provisioning type (Available options: AUTO, ACCESS_TOKEN, MQTT_BASIC, X509_CERTIFICATE)') + + gateway_name = 'Gateway ' + ''.join(random.choice(string.ascii_lowercase) for _ in range(5)) + prov_gateway_key = credentials.pop('provisionDeviceKey') + prov_gateway_secret = credentials.pop('provisionDeviceSecret') + creds = TBDeviceMqttClient.provision(host=self.__host, + port=1883, + device_name=gateway_name, + provision_device_key=prov_gateway_key, + provision_device_secret=prov_gateway_secret, + **credentials) + + with open(self.__config_folder_path + 'credentials.json', 'w') as file: + creds['caCert'] = self._ca_cert_name + file.writelines(dumps(creds)) + log.info('Gateway provisioned') + + creds = self._get_provisioned_creds(creds) + + self._create_mqtt_client(creds) + else: + raise RuntimeError('Security section not provided') + + # pylint: disable=protected-access + # Adding callbacks + self.client._client._on_connect = self._on_connect + self.client._client._on_disconnect = self._on_disconnect + # self.client._client._on_log = self._on_log + self.start() + + # def _on_log(self, *args): + # if "exception" in args[-1]: + # log.exception(args) + # else: + # log.debug(args) + + def _create_mqtt_client(self, credentials): + self.__tls = bool(credentials.get('tls', False) or credentials.get('caCert', False)) if credentials.get("accessToken") is not None: self.__username = str(credentials["accessToken"]) if credentials.get("username") is not None: @@ -59,11 +129,17 @@ class TBClient(threading.Thread): self.__password = str(credentials["password"]) if credentials.get("clientId") is not None: self.__client_id = str(credentials["clientId"]) - self.client = TBGatewayMqttClient(self.__host, self.__port, self.__username, self.__password, self, quality_of_service=self.__default_quality_of_service, client_id=self.__client_id) + + self.client = TBGatewayMqttClient(self.__host, self.__port, self.__username, self.__password, self, + quality_of_service=self.__default_quality_of_service, + client_id=self.__client_id) if self.__tls: - self.__ca_cert = self.__config_folder_path + credentials.get("caCert") if credentials.get("caCert") is not None else None - self.__private_key = self.__config_folder_path + credentials.get("privateKey") if credentials.get("privateKey") is not None else None - self.__cert = self.__config_folder_path + credentials.get("cert") if credentials.get("cert") is not None else None + self.__ca_cert = self.__config_folder_path + credentials.get("caCert") if credentials.get( + "caCert") is not None else None + self.__private_key = self.__config_folder_path + credentials.get("privateKey") if credentials.get( + "privateKey") is not None else None + self.__cert = self.__config_folder_path + credentials.get("cert") if credentials.get( + "cert") is not None else None self.__check_cert_period = credentials.get('checkCertPeriod', 86400) self.__certificate_days_left = credentials.get('certificateDaysLeft', 3) @@ -80,18 +156,23 @@ class TBClient(threading.Thread): ciphers=None) if credentials.get("insecure", False): self.client._client.tls_insecure_set(True) - # pylint: disable=protected-access - # Adding callbacks - self.client._client._on_connect = self._on_connect - self.client._client._on_disconnect = self._on_disconnect - # self.client._client._on_log = self._on_log - self.start() - # def _on_log(self, *args): - # if "exception" in args[-1]: - # log.exception(args) - # else: - # log.debug(args) + @staticmethod + def _get_provisioned_creds(credentials): + creds = {} + if credentials.get('credentialsType') == 'ACCESS_TOKEN': + creds['accessToken'] = credentials['credentialsValue'] + elif credentials.get('credentialsType') == 'MQTT_BASIC': + creds['clientId'] = credentials['credentialsValue']['clientId'] + creds['username'] = credentials['credentialsValue']['userName'] + creds['password'] = credentials['credentialsValue']['password'] + elif credentials.get('credentialsType') == 'X509_CERTIFICATE': + creds['tls'] = True + creds['caCert'] = credentials['caCert'] + creds['privateKey'] = 'key.pem' + creds['cert'] = 'cert.pem' + + return creds def _check_certificates(self): while not self.__stopped and not self.__paused: