# from huansi_utils.huansi_util_common import debug
import datetime
import decimal
from collections import OrderedDict

from flask_sqlalchemy import BaseQuery
from sqlalchemy.engine import ResultProxy
from flask import g, current_app

from huansi_utils.app.apploader import logger
from sqlalchemy import text
from huansi_utils.exception.exception import HSSQLError, HSMessage
from huansi_utils.rpc_tools.redis_rpc import redis_log, redis_log_end, redis_log_start


class HSDBSession(object):
    """Wrapper around a SQLAlchemy session.

    Adds nested transaction counting, raw-SQL helpers and SQL logging
    (per-statement redis log plus an optional per-request global SQL log).

    :param db_session: the underlying SQLAlchemy session object.
    :param total_rollback_count: how many rollbacks are tolerated before
        ``begin_trans``/``commit_trans`` refuse to run; ``None``/0 disables
        the check.
    """

    def __init__(self, db_session, total_rollback_count=None):
        super().__init__()
        self.db_session = db_session
        # Nesting depth of begin_trans()/commit_trans(); the real commit only
        # happens when it drops back to zero.
        self.trans_count = 0
        # Allowed number of rollbacks (0 = unlimited) and rollbacks done so far.
        self._total_rollback_count = total_rollback_count or 0
        self._rollback_count = 0

    def _check_rollback_limit(self):
        """Raise RuntimeError once more rollbacks happened than allowed."""
        if 0 < self._total_rollback_count < self._rollback_count:
            raise RuntimeError('事务已回滚，不能再启动事务')

    def flush(self):
        """Flush pending ORM changes to the database without committing."""
        self.db_session.flush()

    def begin_trans(self):
        """Enter a (possibly nested) transaction level.

        SQLAlchemy starts transactions implicitly, so this only bumps the
        nesting counter.

        :raises RuntimeError: if the rollback limit has been exceeded.
        """
        self._check_rollback_limit()
        # Recover from a counter driven negative by unmatched commit_trans() calls.
        self.trans_count = max(self.trans_count, 0) + 1

    def commit_trans(self, close=False):
        """Leave one transaction level; commit when the outermost level ends.

        The session is flushed on every call, including for inner levels.

        :param close: close the session afterwards.
        :raises RuntimeError: if the rollback limit has been exceeded.
        """
        self._check_rollback_limit()
        self.db_session.flush()
        self.trans_count -= 1
        if self.trans_count == 0:
            self.db_session.commit()
        if close:
            self.close()

    def rollback_trans(self, close=False):
        """Roll back the whole transaction regardless of nesting depth.

        :param close: close the session afterwards.
        """
        self.trans_count = 0
        self.db_session.rollback()
        self._rollback_count += 1
        if close:
            self.close()

    def close(self):
        """Close the session and hand the request's SQL to the global log.

        :raises RuntimeError: if a transaction was still open; it is rolled
            back and the session closed before raising (the global log is
            skipped in that case).
        """
        in_trans = self.trans_count > 0
        if in_trans:
            self.rollback_trans()
        self.db_session.close()
        if in_trans:
            raise RuntimeError('当前还在事务中，已自动回滚并关闭连接')

        self.log_all_sql_handle()

    def execute(self, sql: str, **kwargs) -> ResultProxy:
        """Execute raw SQL text with named ``:param`` bind values.

        :param sql: SQL text.
        :param kwargs: bind parameter values.
        :return: the SQLAlchemy result object.
        """
        return self.db_session.execute(text(sql), kwargs)

    @staticmethod
    def _inline_params(sql, params, placeholder):
        """Replace each ``placeholder.format(name)`` in ``sql`` with ``'value'``.

        For display/logging only (values are not escaped). Longer names are
        substituted first so that e.g. ``:id`` cannot clobber ``:id_list``.
        """
        for key in sorted(params, key=len, reverse=True):
            sql = sql.replace(placeholder.format(key), "'{}'".format(str(params[key])))
        return sql

    def get_sql_from_query(self, query: BaseQuery) -> str:
        """Return the SQL of an ORM query with its parameters inlined.

        :param query: the query; a falsy value yields ``''``.
        :return: SQL text for display.
        """
        if not query:
            return ''
        sql_state = query.statement.compile()
        return self._inline_params(str(sql_state.statement), sql_state.params, ':{}')

    def log_sql(self, query: (BaseQuery, ResultProxy)) -> None:
        """Hook for recording SQL; currently intentionally a no-op."""
        return

    @staticmethod
    def _current_tenant():
        """Tenant code of the current request, defaulting to 'huansi'."""
        user_info = getattr(g, 'user_info', None) or {}
        return user_info.get('tenant', 'huansi')

    def log_redis(self, sql, kwargs, func_name):
        """Open a redis log entry for a statement about to run.

        :param sql: SQL text.
        :param kwargs: bind parameter values (inlined into the logged SQL).
        :param func_name: name of the calling helper, used as the log source.
        :return: whatever ``redis_log_start`` returns; pass it to ``log_redis_end``.
        """
        real_sql = self.get_real_sql(e='', sql=sql, kwargs=kwargs)
        user = getattr(g, 'user', None) or {}
        self.key = user.get('user_no', 'unauthorized')
        self.tenant_code = self._current_tenant()
        app_code = current_app.config.get('APP_CODE')
        source = f'{app_code}.{func_name}' if app_code else func_name
        return redis_log_start(type_='LOG:SQL', key=self.key, data=real_sql, source=source,
                               tenant_code=self.tenant_code)

    def log_redis_end(self, res, status=1, result='success'):
        """Close a redis log entry opened by ``log_redis``.

        :param res: the value returned by ``log_redis``.
        :param status: 1 for success, 0 for failure.
        :param result: result text (the error message on failure).
        """
        redis_log_end(status=status, result=result, **res)

    def _handle_sql_error(self, e, sql, kwargs, res, func_name):
        """Log a failed statement, mark its redis entry failed, then re-raise
        through ``exception_handle`` (which ends by raising HSSQLError)."""
        real_sql = self.get_real_sql(e, kwargs, sql)
        logger.error('db_session.{}出错，sql={}'.format(func_name, real_sql))
        logger.error(e)
        self.log_redis_end(res, status=0, result=str(e))
        self.exception_handle(e, real_sql)

    def exec_sql(self, sql, **kwargs):
        """Execute a statement and return the number of affected rows.

        :param sql: SQL text.
        :param kwargs: bind parameter values.
        :return: affected row count, or -1 when ``sql`` is empty.
        :raises HSSQLError: on any execution error.
        """
        if not sql:
            return -1
        res = self.log_redis(sql, kwargs, func_name='exec_sql')
        try:
            # Append a T-SQL SELECT of @@RowCount so the affected row count is
            # read from the first returned row. If the statement itself
            # returns rows, that row lacks _rowCount and is reported as an error.
            sql_ex = '{}\n{}'.format(sql, 'select _rowCount = @@RowCount')
            data = self.execute(sql_ex, **kwargs).first()
            rowcount = getattr(data, '_rowCount', None)
            if rowcount is None:
                logger.error('EXEC过程中出现select语句并且报错了，{}'.format(sql))
                raise HSSQLError('EXEC过程中出现select语句并且报错了，{}'.format(sql))
            self.log_redis_end(res)
            return rowcount
        except Exception as e:
            self._handle_sql_error(e, sql, kwargs, res, 'exec_sql')

    def retrive_sql(self, sql, **kwargs):
        """Execute SQL and return its first row.

        :param sql: SQL text.
        :param kwargs: bind parameter values.
        :return: the first row, or None when ``sql`` is empty.
        :raises HSSQLError: on any execution error.
        """
        if not sql:
            return None
        res = self.log_redis(sql, kwargs, func_name='retrive_sql')
        try:
            data = self.execute(sql, **kwargs).first()
            self.log_redis_end(res)
            return data
        except Exception as e:
            self._handle_sql_error(e, sql, kwargs, res, 'retrive_sql')

    def query_sql(self, sql, **kwargs):
        """Execute SQL and return all rows.

        :param sql: SQL text.
        :param kwargs: bind parameter values.
        :return: list of rows, or None when ``sql`` is empty.
        :raises HSSQLError: on any execution error.
        """
        if not sql:
            return None
        res = self.log_redis(sql, kwargs, func_name='query_sql')
        try:
            data = self.execute(sql, **kwargs).fetchall()
            self.log_redis_end(res)
            return data
        except Exception as e:
            self._handle_sql_error(e, sql, kwargs, res, 'query_sql')

    def exception_handle(self, e, sql):
        """Translate a database exception into an HSSQLError and raise it.

        :param e: the exception raised while executing ``sql``.
        :param sql: SQL text (parameters inlined) to include in the message.
        :raises HSSQLError: always.
        """
        # Replace SQLAlchemy's "does not return rows" error with a Chinese message.
        if e.args and e.args[0] == 'This result object does not return rows. It has been closed automatically.':
            raise HSSQLError('执行SQL必须要有结果返回,SQL语句==>{}'.format(sql))

        # A missing driver must not mask the original error: fall back to the
        # generic formatter when pymssql is unavailable.
        try:
            import pymssql
            operational_error = pymssql.OperationalError
        except ImportError:
            operational_error = ()

        if isinstance(e, operational_error):
            # Driver-level error: join the textual args into the title
            # (ints are skipped, bytes decoded, anything else str()-ed).
            error_message_dict = {'error_title': '', 'error_detail': ''}
            for error_item in e.args:
                if isinstance(error_item, int):
                    continue
                if isinstance(error_item, bytes):
                    error_item = error_item.decode('utf8', errors='replace')
                else:
                    error_item = str(error_item)
                error_message_dict['error_title'] += error_item
        else:
            error_message_dict = HSMessage(e).format_message()[0]

        error_title = error_message_dict.get('error_title', '') or ''
        error_detail = error_message_dict.get('error_detail', '') or ''
        if error_detail:
            from huansi_utils.common.string import after_string
            # Keep only the text after the '[ERROR]' marker. If after_string
            # changed nothing (presumably: no marker), drop the detail.
            trimmed = after_string(error_detail, '[ERROR]')
            error_detail = f"[ERROR]{trimmed}" if trimmed != error_detail else ''

        # Move the generic trailer "DB-Lib error message 20018, severity 16:
        # General SQL Server error: Check messages from the SQL Server" out of
        # the title into the detail.
        parts = error_title.split('DB-Lib error message')
        error_title = parts[0]
        if len(parts) == 2:
            error_detail += '\nDB-Lib error message{}'.format(parts[1])

        raise HSSQLError('{}\n,SQL语句==>{}\n{}'.format(error_title, sql, error_detail))

    def get_real_sql(self, e, kwargs, sql):
        """Return the SQL with bind values inlined, for error/log messages.

        Prefers the statement recorded on the exception (``e.statement``),
        whose placeholders are replaced as ``%(name)s``. Otherwise ``sql``
        is used with ``:name`` placeholders.

        :param e: the exception (may be ``''`` when there is none).
        :param kwargs: bind parameter values.
        :param sql: the original SQL text.
        :return: SQL text with values quoted inline.
        """
        statement = getattr(e, 'statement', None)
        if statement:
            return self._inline_params(statement, kwargs, '%({})s')
        return self._inline_params(sql, kwargs, ':{}')

    def query_sql_to_many_set(self, sql, **kwargs):
        """Execute SQL returning several result sets and collect them all.

        Result sets are keyed ``set1``, ``set2``, ... unless a row carries a
        truthy ``_TableName`` column, which then names the set. Truthy
        ``ROWSTAT``/``_TableName`` columns are removed from the rows. Empty
        sets are skipped except the last one. ``set_count`` holds the number
        of the last stored set.

        :param sql: SQL text.
        :param kwargs: bind parameter values.
        :return: dict of name -> list of OrderedDict rows, plus ``set_count``.
        :raises HSSQLError: on any execution error.
        """
        res = self.log_redis(sql, kwargs, func_name='query_sql_to_many_set')
        try:
            query_data = self.execute(sql, **kwargs)
            # Work on the raw DB-API cursor to walk multiple result sets.
            cursor = query_data.cursor
            result = {}
            count = 1
            while True:
                name = 'set{}'.format(count)
                rows = cursor.fetchall()
                fields = [column[0] for column in cursor.description]

                records = []
                for row in rows:
                    record = OrderedDict(zip(fields, [self.format_data(item) for item in row]))
                    if record.get('ROWSTAT'):
                        record.pop('ROWSTAT')
                    if record.get('_TableName'):
                        name = record.pop('_TableName')
                    records.append(record)

                has_more = cursor.nextset() is not None
                # Store non-empty sets, and always the final one.
                if records or not has_more:
                    result[name] = records
                    result["set_count"] = count
                if not has_more:
                    break
                count += 1
            self.log_redis_end(res)
            return result
        except Exception as e:
            self._handle_sql_error(e, sql, kwargs, res, 'query_sql_to_many_set')

    @staticmethod
    def format_data(data):
        """Convert a raw column value into a JSON-friendly value.

        bool is kept, datetime -> 'YYYY-MM-DD HH:MM:SS.mmm' (millisecond
        precision, as SQL Server stores it), date -> 'YYYY-MM-DD',
        Decimal -> float, int -> str, bytes -> str(bytes); anything else is
        returned unchanged.
        """
        # bool must be tested before int (bool subclasses int), datetime
        # before date (datetime subclasses date).
        if isinstance(data, bool):
            return data
        if isinstance(data, datetime.datetime):
            return data.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        if isinstance(data, datetime.date):
            return data.strftime("%Y-%m-%d")
        if isinstance(data, decimal.Decimal):
            return float(data)
        if isinstance(data, (int, bytes)):
            return str(data)
        return data

    def log_all_sql(self, data):
        """Ship the SQL collected during one request to the global SQL log.

        Called on a background thread by log_all_sql_handle. Failures after
        the initial field extraction are logged and swallowed.

        :param data: snapshot dict built by log_all_sql_handle.
        """
        sql_list = data.get('sql_list')
        if not sql_list:
            return
        user = data.get('user') or {}
        request_url = data.get('request_url')
        request_info = data.get('request_info')
        app_code = data.get('app_code')
        proxy_url = data.get('proxy_url')
        global_sql_error = data.get('global_sql_error') or {}
        tenant_code = data.get('tenant_code')

        try:
            # Double single quotes and escape colons; presumably the text is
            # later stored via SQL where ':name' would read as a bind parameter.
            sql_str = '\n'.join(sql_list).replace("'", "''").replace(":", "\\:")
            # Do not log the exception-logging procedure itself.
            if "EXEC dbo.spsmException_Log @url=" in sql_str:
                return
            if ':' not in proxy_url:
                proxy_url += ':80'
            ip, port = proxy_url.split(':')
            if not self._test_connect_port(ip=ip, port=port):
                logger.error('nginx代理地址连不上，请检查配置')
                return
            request_user = user.get('user_no', 'unauthorized')

            data_dict = {
                "delay_type": 0,
                "db_name": "",
                "data": {
                    "type": "global_sql_log",
                    "param": "",
                    "data": {
                        "pbGlobalSqlLog": [
                            {
                                "_action": 1,
                                "_primary_key": "",
                                "app": app_code,
                                "sql": sql_str,
                                "request_url": request_url,
                                "request_info": request_info,
                                "request_user": request_user,
                                'tenant_code': tenant_code,
                                'error': global_sql_error.get('error'),
                                'exception': global_sql_error.get('exception'),
                                'error_sql': global_sql_error.get('error_sql'),
                                'status': 0 if global_sql_error else 1
                            }
                        ]
                    }
                }}
            redis_log(type_='LOG:GLOBALSQL', key=request_user, data=data_dict, source=app_code,
                      proxy_url=proxy_url, tenant_code=tenant_code)
        except Exception:
            import traceback
            logger.error('全局日志报错:' + traceback.format_exc())

    def _test_connect_port(self, ip, port, timeout=3):
        """Return True if a TCP connection to ``ip:port`` can be opened.

        :param ip: host address (IPv4).
        :param port: port number or numeric string; falsy means 80.
        :param timeout: connect timeout in seconds, so an unreachable host
            cannot block the caller indefinitely.
        :return: True when reachable, False on any connection/parsing error.
        """
        import socket
        if not port:
            port = 80
        sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sk.settimeout(timeout)
        try:
            sk.connect((ip, int(port)))
            sk.shutdown(socket.SHUT_RDWR)
            return True
        except (OSError, ValueError, TypeError, OverflowError):
            return False
        finally:
            sk.close()

    def log_all_sql_handle(self):
        """Snapshot the request's collected SQL and log it on a background thread.

        Only runs when the ``HSRPC_STATUS`` and ``GLOBAL_SQL_LOG`` config
        switches are on and ``g.is_logger`` (default True) is truthy. Resets
        ``g.sql`` afterwards.
        """
        import threading
        from flask import request

        if not current_app.config.get('HSRPC_STATUS', False):
            return
        # g.is_logger toggles logging per request; GLOBAL_SQL_LOG toggles it
        # per deployment.
        is_logger = getattr(g, 'is_logger', True)
        global_sql_log = current_app.config.get('GLOBAL_SQL_LOG', False)
        if not is_logger or not global_sql_log:
            return

        sql_list = getattr(g, 'sql', [])
        global_sql_error = getattr(g, 'global_sql_error', {})
        user = getattr(g, 'user', {})
        try:
            request_url = request.url
            request_info = str(request.get_json()).replace("'", "''").replace(":", "\\:")
        except Exception:
            # No request context or no JSON body.
            request_url = ''
            request_info = ''

        app_code = current_app.config.get('APP_CODE', '未定义')
        proxy_url = current_app.config.get('RPC_PROXY_URL', '47.110.145.204:59169')
        tenant_code = self._current_tenant()

        data = {
            "is_logger": is_logger,
            "sql_list": sql_list,
            "user": user,
            "request_url": request_url,
            "request_info": request_info,
            "app_code": app_code,
            "proxy_url": proxy_url,
            "global_sql_log": global_sql_log,
            "global_sql_error": global_sql_error,
            "tenant_code": tenant_code
        }
        threading.Thread(target=self.log_all_sql, args=(data,)).start()

        # Start a fresh list; the thread keeps its own reference to the old one.
        g.sql = []
