# -*- coding:utf-8 -*-
import datetime
import logging
import os
import time
from logging import handlers

from flask import current_app
from flask.helpers import find_package


class Logger(object):
    def __init__(self, filename, level='info', when='D', backCount=3, fmt='[%(asctime)s] - %(levelname)s: %(message)s'):
        self.logger = logging.getLogger(filename)
        format_str = logging.Formatter(fmt)  # 设置日志格式
        self.logger.setLevel(level)  # 设置日志级别
        # sh = logging.StreamHandler()  # 往屏幕上输出
        # sh.setFormatter(format_str)  # 设置屏幕上显示的格式
        th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
                                               encoding='utf-8')  # 往文件里写入#指定间隔时间自动生成文件的处理器
        # 实例化TimedRotatingFileHandler
        # interval是时间间隔，backupCount是备份文件的个数，如果超过这个个数，就会自动删除，when是间隔的时间单位，单位有以下几种：
        # S 秒
        # M 分
        # H 小时、
        # D 天、
        # W 每星期（interval==0时代表星期一）
        # midnight 每天凌晨
        th.setFormatter(format_str)  # 设置文件里写入的格式
        # self.logger.addHandler(sh)  # 把对象加到logger里
        self.logger.addHandler(th)

    def error(self, msg, *args, **kwargs):
        return self.logger.error(msg, *args, **kwargs)

    def info(self, msg, *args, **kwargs):
        return self.logger.info(msg, *args, **kwargs)

    def debug(self, msg, *args, **kwargs):
        return self.logger.debug(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        return self.logger.warning(msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        return self.logger.critical(msg, *args, **kwargs)


prefix, package_path = find_package(current_app.import_name)
log_path = os.path.join(package_path, 'DB', 'upgrade_log.log')

log = Logger(log_path, level=logging.DEBUG)


def get_upgrade_item(upgrade_sql) -> list:
    '''
    获取升级脚本的子项
    :param upgrade_sql:
    :return:
    '''
    sql = upgrade_sql[upgrade_sql.index('--BATCHUPGRADE'):upgrade_sql.index('--ENDBATCHUPGRADE')]
    sql_list = sql.split('--UPGRADEITEM:')
    result_list = []

    for sql_item in sql_list:
        _dict = {}
        if not sql_item:
            continue
        _dict['title'] = sql_item.split('\n')[0]
        _dict['script'] = '--' + sql_item
        result_list.append(_dict)

    return result_list


def exec_upgrade_sql(sql, message):
    '''
    执行升级的脚本
    :param sql:
    :param message:
    :return:
    '''
    from huansi_utils.db.db import new_session
    session = new_session(begin=True)
    try:
        # 每个事务增加等待
        session.exec_sql("EXEC dbo.sppbWaitFor @sKey='db_upgrade'")
        sql_item_list = get_upgrade_item(sql)
        count = len(sql_item_list)
        for i, sql_item in enumerate(sql_item_list):
            i += 1
            # print(message)
            log.info(message)
            # print(f"     子项:({i}/{count}):{sql_item['title']}")
            log.info(f"     子项:({i}/{count}):{sql_item['title']}")
            start = time.clock()
            # 会解析错误，暂时用最原生的方法解决
            exec_sql = sql_item['script'].replace(':', "\:")
            session.exec_sql(exec_sql)
            elapsed = time.clock() - start
            # print(f"     执行时长:{elapsed}秒")
            log.info(f"     执行时长:{elapsed}秒")
        session.commit_trans()
    except Exception as e:
        try:
            session.rollback_trans()
            from huansi_utils.exception.exception import HSMessage
            message = HSMessage(e).format_message()[0]
            # print(f"出错子项执行脚本:\n{str(e)[:200]}...")
            log.error(f"出错子项执行脚本:\n{str(message)[:200]}...")
        except Exception as ex:
            # print(f"回滚或打印出现错误:{ex}")
            log.error(f"回滚或打印出现错误:{ex}")
    finally:
        session.close()


def remove_transaction(exec_sql):
    '''
    去掉sql语句中的begin tran,commit tran,rollback tran
    :param exec_sql:
    :return:
    '''
    exec_sql = exec_sql.replace('''BEGIN TRAN
BEGIN TRY''', '''--BEGIN TRAN
BEGIN TRY''')

    exec_sql = exec_sql.replace('''COMMIT TRAN
PRINT '执行成功，事务提交''', '''--COMMIT TRAN
PRINT '执行成功，事务提交''')

    exec_sql = exec_sql.replace('''ROLLBACK
RAISERROR('执行失败，事务回滚',16,1)''', '''--ROLLBACK
RAISERROR('执行失败，事务回滚',16,1)''')

    return exec_sql


def upgrade_db(file_path: str):
    '''
    升级DB脚本
    :param file_names:
    :return:
    '''
    with open(file_path, 'r', encoding='utf8') as f:
        file_names = f.read()
    file_name_list = file_names.split(',')

    log.info("**********开始升级脚本**********")

    from flask import g
    # 防止报错
    g.language = 'cn'
    # 初始化变量，同步sql也要写入日志
    g.sql = []
    prefix, package_path = find_package(current_app.import_name)
    for i, file_name in enumerate(file_name_list):
        # i从0开始计数
        i += 1
        file_path = os.path.join(package_path, 'DB', file_name)
        # if not os.path.exists(file_path):
        #     file_name = os.path.join(package_path, 'DB', file_name)

        # print(f"正在执行({i}/{len(file_name_list)}): {file_name}")
        log.info(f"正在执行({i}/{len(file_name_list)}): {file_name}")

        dtl = datetime.datetime.now()

        if os.path.exists(file_path):
            # 读取文件内容
            with open(file_path, 'r', encoding='UTF-8-sig') as f:
                sql = f.read()
            # 判断内容是否为空
            if sql:
                # 脚本升级
                if "--BATCHUPGRADE" in sql and "--ENDBATCHUPGRADE" in sql:
                    exec_upgrade_sql(sql, f"正在执行({i}/{len(file_name_list)}): {file_name}")
                else:
                    # 会解析错误，暂时用最原生的方法解决
                    exec_sql = sql.replace(':', "\:")
                    from huansi_utils.db.db import new_session
                    session = new_session(begin=True)
                    try:
                        # 每个事务增加等待
                        session.exec_sql("EXEC dbo.sppbWaitFor @sKey='db_upgrade'")
                        # 去掉sql语句中的begin tran,commit tran,rollback tran
                        exec_sql = remove_transaction(exec_sql)
                        session.exec_sql(exec_sql)
                        session.commit_trans()
                    except Exception as e:
                        from huansi_utils.exception.exception import HSMessage
                        message = HSMessage(e).format_message()[0]
                        session.rollback_trans()
                        log.error(f'执行报错：{message}')
                    finally:
                        session.close()
            else:
                # print(f"文件:{file_name} 内容为空。")
                log.error(f"文件:{file_name} 内容为空。")
        else:
            # print(f"文件:{file_name} 不存在。")
            log.error(f"文件:{file_name} 不存在。")
        times = (datetime.datetime.now() - dtl).total_seconds()
        # print(f"文件{file_name}执行完成，耗时{times}秒。")
        log.info(f"文件{file_name}执行完成，耗时{times}秒。")
    log.info("**********脚本升级结束**********")
