#!/usr/bin/python3

import time
import serial
import regex 
import subprocess
import logging
import requests
import zipfile
import shutil
import os
import sys
from enum import Enum, auto

# TODO: remove unused imports
from math import log
from pathlib import Path
from queue import Queue, Empty
from threading import Thread

# TODO: install missing dependencies in advance

# Newest firmware revision for each supported module variant, compared
# against the string returned by AT+QGMR. NEW_FIRMWARE_LIST,
# NEW_FIRMWARES_DOWNLOAD_LINK and NEW_FIRMWARE_MD5_LIST are parallel lists:
# entry i of each describes the same firmware package.
NEW_FIRMWARE_LIST = [
    "EC20CEFILGR06A13M1G",
    "EC20CEHDLGR06A14M1G",
    "EC25AFXGAR07A04M1G_01.010.01.010",
    "EC25EUXGAR08A19M1G_A0.200.A0.200",
    "EG25GGBR07A08M2G_A0.204.A0.204",
]

# Common prefix shared by every download link below.
_PAN_BASE_URL = "https://vip.123pan.cn/1826505135/"

# QFirehose binaries: index 0 is used on armv7l hosts, index 1 on aarch64.
QFIREHOSE_DOWNLOAD_LINK = [
    _PAN_BASE_URL + "15612107",
    _PAN_BASE_URL + "15602557",
]

# Firmware .zip download links, parallel to NEW_FIRMWARE_LIST.
NEW_FIRMWARES_DOWNLOAD_LINK = [
    _PAN_BASE_URL + file_id
    for file_id in ("15612088", "15612089", "15612092", "15612093", "15612094")
]

# Expected MD5 of each downloaded firmware .zip, parallel to NEW_FIRMWARE_LIST.
NEW_FIRMWARE_MD5_LIST = [
    "1647b0be0abde2411077372ff5e21984",
    "16fd8269d7a6332cd4fd982e470b7ce1",
    "c4af0a5f517006a03959bdbf4500e7ed",
    "2d448c920f89d579a7b9c61908bcac76",
    "e9b5d5726df057a20d1ac0a7a34f3cd3",
]

# Working directory for the QFirehose tool and the firmware packages.
DOWNLOAD_DIR = "/tmp/quectel"
# ModemManager state seen before this script ran ("active" / "not active");
# filled in by check_modemm_and_qcserial().
MODEM_STATUS = None
# Serial settings for the AT-command port.
BAUD = 9600
PORT = "/dev/ttyUSB3"
# Line that terminates a successful AT-command response.
EXPECTED_PROMPT = b"OK\r\n"


# ANSI/VT100 control sequences introduced by either the single-byte CSI
# (0x9B) or ESC "[" followed by parameter/intermediate/final bytes.
ANSI_REGEX = regex.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
# Two hex digits not preceded by "0x" (not referenced elsewhere in this file).
MAC_REGEX = regex.compile(r"(?<!0x)[0-9a-fA-F]{2}")


def escape_ansi(line):
    """Return *line* with every ANSI escape sequence removed."""
    cleaned = ANSI_REGEX.sub("", line)
    return cleaned

class ReadTimeoutError(Exception):
    """Raised when the modem stops sending data before the expected prompt.

    Attributes:
        content: the lines (``bytes`` or ``str``) received before the
            timeout; an empty list when none were supplied.
    """

    def __init__(self, content=None):
        # Pass a real message to Exception (the original passed ``self``,
        # which made str() recurse into this __str__).
        super().__init__("Timed out waiting for modem prompt")
        self.content = content if content is not None else []

    def __str__(self):
        message = super().__str__()
        if not self.content:
            return message
        # Content usually holds raw serial bytes such as b"ERROR\r\n".
        lines = [
            (
                line.decode(errors="replace")
                if isinstance(line, (bytes, bytearray))
                else str(line)
            ).rstrip("\r\n")
            for line in self.content
        ]
        return message + "\n" + "\n".join(lines)

class PrintMessage:
    """Small facade over a CLI object for tagged console output and logging.

    *cli* must provide ``print(str)`` and a ``logger`` with ``info``/``error``.
    """

    def __init__(self, cli):
        self.cli = cli

    def _emit(self, tag, msg):
        # A blank separator line, then "[TAG]: msg", both via cli.print.
        show = self.cli.print
        show("")
        show(f"[{tag}]: {msg}")

    def log_info(self, msg):
        """Write *msg* to the log at INFO level only (no console output)."""
        logger = self.cli.logger
        logger.info(msg)

    def log_error(self, msg):
        """Write *msg* to the log at ERROR level only (no console output)."""
        logger = self.cli.logger
        logger.error(msg)

    def print_info(self, msg):
        """Show *msg* on the console prefixed with "[INFO]: "."""
        self._emit("INFO", msg)

    def print_error(self, msg):
        """Show *msg* on the console prefixed with "[ERROR]: "."""
        self._emit("ERROR", msg)

class LoggingOperationError(Exception):
    """Raised by MfgLogging when a log-file operation is invalid (e.g. closing with no active file)."""

class SubprocessError(Exception):
    """Raised by UserCLI.subprocess when the child process exits with a non-zero code."""

class ReadReturn(Enum):
    """Classification of one line read from the modem by Quectel._read_line."""

    TIMEOUT = 1  # readline() returned no data before the serial timeout
    PROMPT = 2   # the line equals the expected prompt
    TEXT = 3     # any other line


class Quectel:
    """Minimal AT-command client for a Quectel modem's serial port.

    The port is configured in ``__init__`` but only opened by :meth:`open`.
    Every line written or read is logged through ``cli.logger``.
    """

    # Seconds to pause after each write before the caller starts reading.
    COMMAND_LATENCY = 0.1

    def __init__(self, cli, baud=BAUD, port=PORT, initial_timeout=3):
        self._ser = serial.Serial()
        self._ser.baudrate = baud
        self._ser.port = port
        self._ser.timeout = initial_timeout
        self.cli = cli

    def open(self):
        """Open the configured serial port."""
        self._ser.open()

    def is_open(self):
        """Return True if the serial port object exists and is open."""
        return self._ser.is_open if self._ser else False

    def close(self):
        """Close the serial port (no-op if it was never created)."""
        if self._ser is not None:
            self._ser.close()

    def __del__(self):
        # __init__ may have raised before _ser was assigned; avoid an
        # AttributeError during garbage collection.
        if getattr(self, "_ser", None) is not None:
            self.close()

    def send_command(self, command):
        """Write *command* (whitespace-stripped, CRLF-terminated) and pause briefly."""
        at_cmd = (command.strip() + "\r\n").encode()
        self.cli.logger.info(f"SEND: {at_cmd}".strip("\n"))
        self._ser.write(at_cmd)
        time.sleep(self.COMMAND_LATENCY)

    def send_break(self):
        """Send a Ctrl-C (0x03) line to the modem, then discard pending input."""
        self._ser.reset_output_buffer()
        time.sleep(self.COMMAND_LATENCY)
        self.send_command("\x03\n")
        self._ser.reset_input_buffer()

    def read_to_prompt(self, timeout=3, prompt=None):
        """Yield raw lines from the modem until *prompt* is read.

        Args:
            timeout: serial read timeout in seconds; it stays set on the port
                after this call.
            prompt: bytes line that ends the response; defaults to
                EXPECTED_PROMPT.

        Yields:
            Each raw ``bytes`` line, including the final prompt line.

        Raises:
            ReadTimeoutError: if a read times out before the prompt arrives.
                A break is sent to the modem once, and the lines read so far
                (excluding the prompt) are attached as ``content``.
        """
        self._ser.timeout = timeout
        lines = []
        while True:
            rc, line = self._read_line(prompt=prompt)
            if rc == ReadReturn.TIMEOUT:
                self.send_break()
                raise ReadTimeoutError(lines)
            yield line
            if rc == ReadReturn.PROMPT:
                return
            lines.append(line)

    def _read_line(self, prompt=None):
        """Read one line and classify it as TIMEOUT (no data), PROMPT or TEXT.

        Returns:
            ``(ReadReturn, bytes)`` tuple of the classification and the raw line.
        """
        if prompt is None:
            prompt = EXPECTED_PROMPT
        line = self._ser.readline()
        self.cli.logger.info(f"READ: {line}".strip("\n"))
        rc = ReadReturn.TEXT

        if not line:
            rc = ReadReturn.TIMEOUT
        elif line == prompt:
            rc = ReadReturn.PROMPT
        self.cli.logger.info(f"rc: {rc}\tline: {line}")
        return rc, line

class MfgLogging:
    """Manage one file handler on the root logger for a single update run.

    While a run is active, records go to ``LOG_DIR/active.log``. When the run
    is closed, that file is renamed to
    ``update_4g_firmware_<serial_number>.log`` in the same directory.
    """

    # Directory holding the log files (the directory of this script).
    LOG_DIR = Path(__file__).parent

    def __init__(self, debug=False):
        self.handler = None
        # "" selects the root logger, so records from every named logger
        # (e.g. UserCLI's "log" logger) propagate to the file handler.
        self.logger = logging.getLogger("")
        self.logger.setLevel(logging.DEBUG if debug else logging.INFO)

    def add_file_handler(self):
        """Start logging to ``LOG_DIR/active.log``, truncating any old file.

        Raises:
            LoggingOperationError: if a previous file handler is still open.
        """
        if self.handler is not None:
            raise LoggingOperationError("Previous file not close")

        formatter = logging.Formatter(
            fmt="{asctime}.{msecs:<4.0f} | {name:<12} | {levelname:<8} | {message}",
            datefmt="%m/%d %H:%M:%S",
            style="{",
        )
        handler = logging.FileHandler(
            Path(self.LOG_DIR) / "active.log",
            mode="w",
            encoding="utf-8",  # messages may contain non-ASCII text
        )
        handler.setFormatter(formatter)
        self.handler = handler
        self.logger.addHandler(handler)

    def close_file_handler(self, serial_number, suppress_error=False):
        """Close the active log file and rename it after *serial_number*.

        Args:
            serial_number: value inserted into the final file name.
            suppress_error: if True, do nothing when no file is active
                instead of raising.

        Raises:
            LoggingOperationError: if no file is active and *suppress_error*
                is False.
        """
        if self.handler is None:
            if not suppress_error:
                raise LoggingOperationError("No active log file")
            return

        handler, self.handler = self.handler, None
        # Detach before closing so later records are not routed to a closed
        # handler (which could reopen and truncate the file).
        self.logger.removeHandler(handler)
        handler.close()

        log_dir = Path(self.LOG_DIR)
        log_dir.mkdir(parents=True, exist_ok=True)
        log_name = f"update_4g_firmware_{serial_number}.log"
        shutil.move(str(log_dir / "active.log"), str(log_dir / log_name))

class UserCLI:
    """Console I/O helper that mirrors user-facing output into the "log" logger."""

    # Subprocess output lines whose stripped text starts with one of these
    # prefixes are treated as progress updates: they are redrawn in place
    # (leading carriage return, no newline) instead of scrolling.
    _PROGRESS_PREFIXES = ("Verifying:", "Writing:", "Download")

    def __init__(self):
        self.logger = logging.getLogger("log")

    def clear(self):
        """Emit the ANSI sequence that clears the screen and homes the cursor."""
        print("\033[2J\033[0;0H")

    def print(self, message):
        """Print *message*; also log it unless it is non-empty whitespace only."""
        if not message.isspace() or message == "":
            self.logger.info(message)
        print(message)

    def error(self, message):
        """Log *message* at ERROR level (continuation lines tab-indented) and print it."""
        lines = message.split("\n")
        if len(lines) > 1:
            self.logger.error(lines[0])
            for line in lines[1:]:
                self.logger.error(f"\t{line}")
        else:
            self.logger.error(message)
        print(message)

    def input(self, prompt):
        """Log *prompt*, then read and return one line from stdin."""
        self.logger.info(prompt)
        return input(prompt)

    def _display_line(self, line):
        """Print one subprocess output line, redrawing progress lines in place."""
        text = line.strip("\n")
        if text.lstrip().startswith(self._PROGRESS_PREFIXES):
            # flush: without a trailing newline the update would sit in the
            # stdout buffer instead of appearing on screen.
            print(f"\r{text.strip()}", end="", flush=True)
        else:
            print(text)

    def subprocess(self, commands):
        """Run *commands*, streaming its output to the console and the log.

        stdout lines are logged at INFO, stderr lines at ERROR (ANSI codes
        removed). Both are shown on the console as they arrive.

        Args:
            commands: argument list passed to subprocess.Popen (no shell).

        Returns:
            The child's return code (always 0; non-zero raises).

        Raises:
            SubprocessError: if the child exits with a non-zero code; the
                message includes the captured stderr text.
        """
        popen = subprocess.Popen(
            commands,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        print_queue = Queue()
        stderr_lines = []

        def log_stderr(text):
            # Keep stderr text for the failure report as well as logging it.
            self.logger.error(text)
            stderr_lines.append(text)

        readers = [
            Thread(
                target=self.threaded_log,
                args=(popen.stdout, print_queue, self.logger.info),
                daemon=True,
            ),
            Thread(
                target=self.threaded_log,
                args=(popen.stderr, print_queue, log_stderr),
                daemon=True,
            ),
        ]
        for reader in readers:
            reader.start()

        # Block on the queue with a short timeout instead of spinning on
        # Queue.empty(), which pinned a CPU core for the whole flash.
        while popen.poll() is None:
            try:
                self._display_line(print_queue.get(timeout=0.1))
            except Empty:
                pass

        # The readers may still be draining pipe buffers after the child
        # exits; join them first so no trailing output is lost.
        for reader in readers:
            reader.join()
        while True:
            try:
                self._display_line(print_queue.get_nowait())
            except Empty:
                break

        rc = popen.returncode
        self.logger.info(f"Returncode: {rc}")
        print("")
        print(rc)
        if rc != 0:
            stderr_text = "\n".join(stderr_lines)
            print(f"stderr:{stderr_text}")
            raise SubprocessError(f"Error running subprocess (rc={rc}):{stderr_text}")
        return rc

    @staticmethod
    def threaded_log(pipe, queue, log_fn):
        """Read *pipe* until EOF, logging and queueing each line, then close it.

        Each line is passed to *log_fn* with ANSI codes and the trailing
        newline removed; the raw line is put on *queue*.
        """
        for line in iter(pipe.readline, ""):
            log_fn(escape_ansi(line).strip("\n"))
            queue.put(line)
        pipe.close()

    def verify(self, prompt):
        """Ask a y/n question until answered; return TestResult.from_expected("y", answer)."""
        # NOTE(review): TestResult is not defined or imported anywhere in this
        # file, so this method raises NameError as written.
        while (verification := input(f"{prompt} y/n?: ")) not in set(["y", "n"]):
            self.logger.warning(f"bad input: {verification}")
        self.logger.info(verification)
        return TestResult.from_expected("y", verification)

    def result(self, test_name, results, value=None):
        """Print and log a green/red status line for *results* and return its name."""
        # NOTE(review): the *value* argument is ignored; it is overwritten by
        # results.value on the next lines.
        result = results.name
        value = results.value

        if value is None or result == "PASS":
            message = f"\033[32m{test_name:<32}: {result}\033[0m"
        else:
            message = f"\033[31m{test_name:<32}: {result}\033[0m"
        self.logger.info(message)
        print(message)
        return result

# first phase

# Note: on some devices the qcserial file does not exist.
def check_modemm_and_qcserial(pri):
    """Stop ModemManager if it is running and record its previous state.

    The previous state ("active" / "not active") is stored in the module-level
    MODEM_STATUS so restore_modem_and_qcserial() can restart the service
    later. Presumably the service is stopped so it does not interfere with
    the modem's serial ports during the update. Despite its name, this
    function does not touch qcserial.

    Args:
        pri: PrintMessage-like object used for logging/console output.

    Raises:
        Exception: if systemctl cannot be run or ``systemctl stop`` fails;
            the underlying error is chained as ``__cause__``.
    """
    global MODEM_STATUS
    try:
        result = subprocess.run(['systemctl', 'is-active', 'ModemManager', '--quiet'])
        if result.returncode == 0:
            # Record the state before stopping, so restore will try to restart
            # the service even if the stop below fails part-way.
            MODEM_STATUS = "active"
            pri.log_info("ModemManager is running. Stopping it...")
            # check=True: a failed stop used to be silently ignored.
            subprocess.run(['sudo', 'systemctl', 'stop', 'ModemManager'], check=True)
        else:
            MODEM_STATUS = "not active"
            pri.log_info("ModemManager is not running.")

    except Exception as e:
        pri.print_error(f"An error occurred while handling ModemManager :{str(e)}")
        raise Exception(f"An error occurred while handling ModemManager :{str(e)}") from e

def restore_modem_and_qcserial(pri, modem_status):
    """Restart ModemManager if it was running before this script stopped it.

    Args:
        pri: PrintMessage-like object used for logging/console output.
        modem_status: "active" to start ModemManager again; any other value
            (including None) leaves it alone.

    Raises:
        Exception: if ``systemctl start`` fails; the underlying error is
            chained as ``__cause__``.
    """
    try:
        if modem_status == "active":
            pri.log_info("ModemManager is not running. Starting it...")
            subprocess.run(['sudo', 'systemctl', 'start', 'ModemManager'], check=True)
        else:
            pri.log_info("ModemManager is already not running.")

    except Exception as e:
        pri.print_error(f"An error occurred while restoring ModemManager: {str(e)}")
        # Chain the cause so the original traceback is not lost.
        raise Exception(f"An error occurred while restoring ModemManager: {str(e)}") from e

# second phase
def check_lsusb_first(pri):
    """Confirm via ``lsusb`` that a Quectel LTE modem is attached.

    A line matches when it contains both "lte" and "quectel"
    (case-insensitive).

    Args:
        pri: PrintMessage-like object used for console output.

    Returns:
        The matching ``lsusb`` line.

    Raises:
        Exception: if lsusb cannot be run, only non-Quectel LTE devices are
            listed, or no LTE device is listed at all.
    """
    try:
        output = subprocess.check_output(["lsusb"], text=True)
    except (OSError, subprocess.CalledProcessError) as e:
        pri.print_error(f"lsusb Fail: {str(e)}")
        raise Exception(f"lsusb Fail: {str(e)}") from e

    # Errors are raised outside any try block so they are reported once,
    # not caught and re-wrapped. Scan all LTE lines instead of giving up on
    # the first non-Quectel one.
    lte_lines = [line for line in output.splitlines() if "lte" in line.lower()]
    for line in lte_lines:
        if "quectel" in line.lower():
            pri.print_info(f"Find: {line}")
            return line

    if lte_lines:
        pri.print_error(f"No find Quectel: {lte_lines[0]}")
        raise Exception(f"No find Quectel: {lte_lines[0]}")
    pri.print_error("No LTE or Quectel device found in lsusb output.")
    raise Exception("No LTE or Quectel device found in lsusb output.")

# First check that exactly 4 ttyUSB devices exist; if not, print a hint to remove any extra USB serial devices.
def check_ttyusb3_first(pri):
    """Check that exactly four /dev/ttyUSB* nodes exist, including /dev/ttyUSB3.

    Args:
        pri: PrintMessage-like object used for logging/console output.

    Raises:
        Exception: if the number of ttyUSB devices is not 4 or /dev/ttyUSB3
            is missing.
    """
    # Local import: glob is not imported at file level.
    import glob

    dev_path = "/dev/ttyUSB3"

    # glob replaces `ls` through a shell: same device list, no shell, and no
    # unreachable CalledProcessError handler.
    ttyusb_devices = glob.glob("/dev/ttyUSB*")
    if len(ttyusb_devices) != 4:
        pri.print_error(f"Expected 4 ttyUSB devices, but found {len(ttyusb_devices)}.")
        raise Exception(f"Error: Expected 4 ttyUSB devices, but found {len(ttyusb_devices)}.")

    if not os.path.exists(dev_path):
        pri.print_error(f"No find {dev_path}!")
        raise Exception(f"No find {dev_path}!")
    pri.log_info(f"[INFO] {dev_path} exists.")

def conver_at_res_list(lines):
    """Decode and clean a raw AT-command response.

    Args:
        lines: iterable of ``bytes`` or ``str`` lines, for example the
            generator returned by Quectel.read_to_prompt(). It is fully
            consumed.

    Returns:
        list[str]: the lines decoded (undecodable bytes dropped), stripped,
        with blank lines removed.

    Raises:
        Exception: if *lines* is None or yields nothing at all.
    """
    if lines is None:
        raise Exception("Input lines cannot be empty.")
    # Materialise first: a generator is always truthy, so the emptiness check
    # only works on the realised list.
    lines = list(lines)
    if not lines:
        raise Exception("Input lines cannot be empty.")

    lines_list = []
    for line in lines:
        if isinstance(line, (bytes, bytearray)):
            line_str = line.decode(errors="ignore").strip()
        else:
            line_str = line.strip()
        if line_str:
            lines_list.append(line_str)
    return lines_list

def get_res_with_keyword(lines_list, command):
    """Return the response line that follows the echoed *command*.

    The first line containing *command* is used. The line after it is
    returned unless it is empty or a final result code (OK, ERROR,
    +CME ERROR, +CMS ERROR); in that case no payload was sent.

    Args:
        lines_list: cleaned response lines (see conver_at_res_list).
        command: the AT command text to search for, e.g. "at+gsn".

    Returns:
        The payload line, or None if the command is absent, nothing follows
        it, or the next line is a result code.
    """
    final_codes = ("ERROR", "+CME ERROR", "+CMS ERROR")
    for i, line in enumerate(lines_list):
        if command in line and i + 1 < len(lines_list):
            candidate = lines_list[i + 1].strip()
            # Previously "ERROR"/"OK" would be returned as if it were the value.
            if not candidate or candidate == "OK" or candidate.startswith(final_codes):
                return None
            return candidate
    return None

def get_quectel_gsn(term, pri):
    """Query the modem's serial number with AT+GSN.

    The value is taken from the line after the command in the response, so
    this assumes the modem echoes commands (TODO confirm echo setting).

    Args:
        term: Quectel instance with an open port.
        pri: PrintMessage-like object used for console output.

    Returns:
        The serial-number string reported by the modem.

    Raises:
        Exception: if no value could be extracted from the response.
    """
    command = "at+gsn"
    term.send_command(command)

    lines_list = conver_at_res_list(term.read_to_prompt())
    sn = get_res_with_keyword(lines_list, command)
    if not sn:
        pri.print_error("GSN Return Value is None!")
        raise Exception("GSN Return Value is None!")
    return sn

# third phase
def get_firmware_version(term, pri):
    """Return the firmware revision reported by AT+QGMR.

    Args:
        term: Quectel instance with an open port.
        pri: PrintMessage-like object used for console output.

    Raises:
        Exception: if the response contains no revision line.
    """
    cmd = "at+qgmr"
    term.send_command(cmd)

    response = conver_at_res_list(term.read_to_prompt())
    revision = get_res_with_keyword(response, cmd)
    if revision:
        return revision
    pri.print_error(f"QGMR Return Value is None!")
    raise Exception("QGMR Return Value is None!")

def download_file(url, destination, pri, timeout=60):
    """Download *url* to *destination*, streaming it to disk in chunks.

    Args:
        url: URL to fetch (requests follows redirects by default).
        destination: local file path; overwritten if it already exists.
        pri: PrintMessage-like object used to report success.
        timeout: seconds passed to requests as the connect/read timeout.

    Raises:
        requests.HTTPError: on a 4xx/5xx response (nothing is written).
        requests.RequestException: on connection errors or timeouts.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTTP error page would be saved as the file.
        response.raise_for_status()
        with open(destination, 'wb') as file:
            # Stream in chunks instead of holding the whole firmware in memory.
            for chunk in response.iter_content(chunk_size=64 * 1024):
                if chunk:
                    file.write(chunk)
    pri.print_info(f"Downloaded file to {destination}")

# TODO: the downloaded file needs execute permission
def download_qfirehose(pri):
    """Download the QFirehose flashing tool for this host's CPU architecture.

    The binary is saved as ``DOWNLOAD_DIR/QFirehose`` and made executable.
    Only armv7l (QFIREHOSE_DOWNLOAD_LINK[0]) and aarch64
    (QFIREHOSE_DOWNLOAD_LINK[1]) hosts are supported.

    Args:
        pri: PrintMessage-like object used for console output.

    Raises:
        Exception: on an unsupported architecture, or if the file is missing
            after the download.
        OSError: if the execute bit cannot be set.
    """
    # Same value as `uname -m`, without spawning a shell.
    arch = os.uname().machine
    url_by_arch = {
        'armv7l': QFIREHOSE_DOWNLOAD_LINK[0],
        'aarch64': QFIREHOSE_DOWNLOAD_LINK[1],
    }
    download_url = url_by_arch.get(arch)
    if download_url is None:
        pri.print_error(f"uname -m error: unsupported architecture {arch!r}")
        raise Exception(f"uname -m error: unsupported architecture {arch!r}")

    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
    destination = os.path.join(DOWNLOAD_DIR, "QFirehose")
    download_file(download_url, destination, pri)

    if not os.path.exists(destination):
        pri.print_error("Download QFirehose error! File does not exist.Please try again.")
        raise Exception("No QFIrehose tool")
    # Add execute permission for user/group/other. The file was just written
    # by this process, so no sudo is needed, and failures now raise.
    os.chmod(destination, os.stat(destination).st_mode | 0o111)

# TODO: compare the MD5 after downloading
def calculate_md5(file_path, pri):
    """Return the lowercase hex MD5 digest of *file_path*.

    The file is read in 1 MiB chunks, so large firmware archives are not
    loaded into memory at once.

    Args:
        file_path: path of the file to hash.
        pri: PrintMessage-like object used for console output.

    Raises:
        Exception: if the file cannot be read; the OSError is chained.
    """
    # Local import: hashlib is not imported at file level.
    import hashlib

    digest = hashlib.md5()
    try:
        with open(file_path, "rb") as fh:
            while chunk := fh.read(1 << 20):
                digest.update(chunk)
    except OSError as e:
        pri.print_error(f"Failed to calculate MD5 for {file_path}")
        raise Exception(f"Failed to calculate MD5 for {file_path}") from e

    md5 = digest.hexdigest()
    pri.print_info(f"MD5: {file_path} -> {md5}")
    return md5

def download_firmware(qgmr, pri):
    """Download and MD5-verify the newest firmware package matching *qgmr*.

    Args:
        qgmr: current firmware revision string (from AT+QGMR).
        pri: PrintMessage-like object used for logging/console output.

    Returns:
        False if *qgmr* is None or empty; True if *qgmr* is already in
        NEW_FIRMWARE_LIST; otherwise the file name ``"<firmware>.zip"`` of
        the package downloaded into DOWNLOAD_DIR.

    Raises:
        Exception: if no known firmware shares *qgmr*'s prefix, or the
            downloaded file's MD5 differs from NEW_FIRMWARE_MD5_LIST.
    """
    if not qgmr:
        pri.print_error(f"qgmr is None!")
        return False

    if qgmr in NEW_FIRMWARE_LIST:
        pri.print_info("Firmware is the Newest!")
        return True

    # Revisions shorter than 25 characters are matched on their first 14
    # characters, longer ones on their first 13.
    width = 14 if len(qgmr) < 25 else 13
    index = next(
        (i for i, name in enumerate(NEW_FIRMWARE_LIST) if name[:width] == qgmr[:width]),
        None,
    )
    if index is None:
        pri.print_error(f"No expect firmware version!")
        raise Exception("No expect firmware version!")
    pri.log_info(f"Found prefix match. Index: {index}")

    firmware = NEW_FIRMWARE_LIST[index]
    url = NEW_FIRMWARES_DOWNLOAD_LINK[index]
    pri.print_info(f"Expect firmware version:{firmware} url:{url}")

    destination = f"{DOWNLOAD_DIR}/{firmware}.zip"
    download_file(url, destination, pri)

    downloaded_md5 = calculate_md5(destination, pri)
    expected_md5 = NEW_FIRMWARE_MD5_LIST[index]
    if downloaded_md5 == expected_md5:
        pri.print_info(f"MD5 check passed for {firmware}")
        return f"{firmware}.zip"

    pri.print_error(f"MD5 mismatch! Expected: {expected_md5}, Found: {downloaded_md5}. Please check!")
    raise Exception(f"MD5 mismatch! Expected: {expected_md5}, Found: {downloaded_md5}")

def unzip_file(zip_file_name, pri):
    """Extract ``DOWNLOAD_DIR/zip_file_name`` into its own directory with ``unzip -o``.

    Args:
        zip_file_name: file name (relative to DOWNLOAD_DIR) of the archive.
        pri: PrintMessage-like object used for console output.

    Returns:
        The base name of the directory the archive was extracted into.

    Raises:
        Exception: if the archive is missing, unzip cannot be run, or unzip
            exits with an error; the underlying error is chained.
    """
    zip_file_path = os.path.join(DOWNLOAD_DIR, zip_file_name)
    if not os.path.exists(zip_file_path):
        pri.print_error(f"{zip_file_name} does not exist! Please try again!")
        raise Exception(f"{zip_file_name} does not exist.")

    directory = os.path.dirname(zip_file_path)
    try:
        subprocess.run(
            ['unzip', '-o', zip_file_path, '-d', directory],
            check=True, text=True, capture_output=True,
        )
    except subprocess.CalledProcessError as e:
        # With check=True a non-zero exit arrives here. The old returncode
        # branch was unreachable, and stderr was lost in re-wrapping.
        pri.print_error(f"Error during extraction: {e.stderr}. Please check {DOWNLOAD_DIR}!")
        raise Exception(f"Error during extraction: {e.stderr}") from e
    except OSError as e:
        pri.print_error(f"Error: An exception occurred while unzipping the file: {str(e)}")
        raise Exception(f"Error: An exception occurred while unzipping the file: {str(e)}") from e

    pri.print_info(f"File {zip_file_path} extracted to {directory}.")
    return os.path.basename(directory)

def qfirehose(que, cli, new_firmware):
    """Flash the firmware directory DOWNLOAD_DIR/<new_firmware> with QFirehose.

    The AT serial port held by *que* is closed first. The tool runs under
    sudo through ``cli.subprocess``, whose return code is returned.
    """
    if que.is_open():
        que.close()
    tool = f"{DOWNLOAD_DIR}/QFirehose"
    firmware_dir = f"{DOWNLOAD_DIR}/{new_firmware}"
    return cli.subprocess(["sudo", tool, "-f", firmware_dir])

# fourth phase
# check_lsusb: if the upgrade failed, the lsusb output shows: (example output was not recorded)
def check_lsusb_second(pri, max_retries=40, interval=1):
    """Poll ``lsusb`` until a Quectel LTE modem reappears after flashing.

    Args:
        pri: PrintMessage-like object used for console output.
        max_retries: number of lsusb polls before giving up.
        interval: seconds to wait between polls.

    Returns:
        True when a line containing both "lte" and "quectel" is found;
        False if lsusb itself fails.

    Raises:
        Exception: if the modem is not found within *max_retries* polls.
    """
    for attempt in range(max_retries):
        try:
            output = subprocess.check_output(["lsusb"], text=True)
            for line in output.splitlines():
                lower_line = line.lower()
                if "lte" in lower_line and "quectel" in lower_line:
                    pri.print_info(f"Find: {line}")
                    return True
        # print_error already adds "[ERROR]: ", so the messages no longer
        # carry their own "[ERROR]" tag.
        except subprocess.CalledProcessError as e:
            pri.print_error(f"lsusb command failed: {e}")
            return False
        except Exception as e:
            pri.print_error(f"lsusb Fail: {e}")
            return False

        # No point sleeping after the final attempt.
        if attempt + 1 < max_retries:
            time.sleep(interval)

    pri.print_error(f"Device not found after {max_retries} attempts. Please check lsusb or reboot device.")
    raise Exception("Time out wait lsusb")

# check ttyUSB
def check_ttyusb3_second(pri):
    """Report whether /dev/ttyUSB3 exists after the modem re-enumerated.

    Returns:
        True (after printing a notice) if the device node exists, else False.

    Raises:
        Exception: wrapping any unexpected error from the check.
    """
    try:
        device = "/dev/ttyUSB3"
        present = os.path.exists(device)
        if present:
            pri.print_info(f"{device} has been found.")
        return present
    except Exception as err:
        pri.print_error(f"check_ttyUSB3_second error: {err}")
        raise Exception(f"check_ttyUSB3_second error: {err}")

def check_update(qgmr, qgmr_second, qgmr_list, pri):
    """Decide whether the modem now runs one of the target firmware revisions.

    The new revision must:
    - have the same length as the old one;
    - share its leading 14 characters (revisions shorter than 20 chars) or
      13 characters (longer revisions);
    - match an entry of *qgmr_list*: exactly for short revisions, or on the
      first 24 characters for long ones.

    Args:
        qgmr: revision read before flashing.
        qgmr_second: revision read after flashing.
        qgmr_list: known target revisions.
        pri: PrintMessage-like object used for console output.

    Returns:
        True if the update is confirmed, otherwise False.
    """
    if len(qgmr_second) != len(qgmr):
        return False

    prefix_len = 14 if len(qgmr) < 20 else 13
    if qgmr_second[:prefix_len] != qgmr[:prefix_len]:
        return False

    if len(qgmr_second) < 20:
        matched = qgmr_second if qgmr_second in qgmr_list else None
    else:
        matched = next(
            (item for item in qgmr_list if item[:24] == qgmr_second[:24]),
            None,
        )

    if matched is not None:
        pri.print_info(f"Current firmware version: {matched}")
        return True

    pri.print_error(f"No matching update found.update failed")
    return False

def checke_first_fwver(qgmr, pri, firmware_list=None):
    """Return True if *qgmr* is already one of the newest firmware revisions.

    Revisions shorter than 20 characters must appear exactly in the list.
    Longer ones are compared on their first 24 characters.

    Args:
        qgmr: firmware revision string from AT+QGMR.
        pri: PrintMessage-like object used for console output.
        firmware_list: revisions to compare against; defaults to
            NEW_FIRMWARE_LIST.

    Returns:
        bool. The short-revision path used to return None instead of False.
    """
    if firmware_list is None:
        firmware_list = NEW_FIRMWARE_LIST

    if len(qgmr) < 20:
        is_newest = qgmr in firmware_list
    else:
        is_newest = any(item[:24] == qgmr[:24] for item in firmware_list)

    if is_newest:
        pri.print_info(f"Firmware is the Newest.Pass")
    return is_newest


if __name__ == "__main__":
    # Firmware update flow for a Quectel LTE modem:
    #   first phase:  stop ModemManager (restored in `finally`)
    #   second phase: locate the modem (lsusb, /dev/ttyUSB*) and read AT+GSN
    #   third phase:  read AT+QGMR; if outdated, download QFirehose and the
    #                 matching firmware, verify its MD5, unzip and flash it
    #   fourth phase: wait for the modem to come back and re-read AT+QGMR
    # Exit status: 0 if already up to date or the update is confirmed,
    # 1 otherwise. Uncaught exceptions propagate with their original type.
    c_log_control = MfgLogging(debug=False)
    c_log_control.add_file_handler()
    cli = UserCLI()
    que = Quectel(cli=cli)
    pri = PrintMessage(cli=cli)
    gsn = None
    update_succeeded = False
    try:
        que.open()
        que.send_command("ate")
        que.send_break()
        time.sleep(1)

        # first phase
        check_modemm_and_qcserial(pri)

        # second phase
        check_lsusb_first(pri)
        check_ttyusb3_first(pri)
        gsn = get_quectel_gsn(que, pri)
        pri.print_info(f"Has get gsn:{gsn}")

        # third phase
        qgmr = get_firmware_version(que, pri)
        if checke_first_fwver(qgmr, pri):
            sys.exit(0)

        pri.print_info(f"Has get qgmr:{qgmr}")
        download_qfirehose(pri)
        firmware_zip_name = download_firmware(qgmr, pri)
        if firmware_zip_name is True:
            pri.print_info("The firmware is the newest!")
            sys.exit(0)
        elif firmware_zip_name is False:
            pri.print_info("ERROR, Program exit.")
            sys.exit(1)
        unzip_firmware_dir = unzip_file(firmware_zip_name, pri)
        pri.print_info(f"Firmware to be used:{unzip_firmware_dir}")
        firmware_name = firmware_zip_name[:-4]  # strip ".zip"
        pri.print_info(f"{firmware_name}:firmware dir")
        qf_rc = qfirehose(que, cli, firmware_name)

        pri.print_info(f"Update result:{qf_rc}")

        if qf_rc == 0:
            pri.print_info(f"QFirehose exec success!")
        else:
            pri.print_info(f"QFirehose exec fail!")

        # fourth phase
        check_lsusb_second(pri)
        check_ttyusb3_second(pri)
        if not que.is_open():
            que.open()
            que.send_command("ate")
            que.send_break()
            time.sleep(1)
        time.sleep(5)
        que.send_break()
        qgmr_second = get_firmware_version(que, pri)
        pri.log_info(f"Primal qgmr: {qgmr}")
        pri.log_info(f"Lase qgmr: {qgmr_second}")
        if check_update(qgmr, qgmr_second, NEW_FIRMWARE_LIST, pri):
            pri.print_info(f"Update Success")
            update_succeeded = True
        else:
            pri.print_error(f"Update Fail")
    finally:
        # Every cleanup step runs even if an earlier one raises, so the log
        # file is always closed and renamed.
        try:
            restore_modem_and_qcserial(pri, modem_status=MODEM_STATUS)
        finally:
            try:
                if que.is_open():
                    que.close()
            finally:
                c_log_control.close_file_handler(serial_number=gsn)

    # A failed verification used to exit with status 0.
    if not update_succeeded:
        sys.exit(1)
