from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.engine import Engine
import logging


class MultiDBEngineManager:
    """
    多数据库引擎管理,可以通过名字获取对应连接字符串的引擎
    """

    def __init__(self, sqlalchemy: SQLAlchemy):
        """
        初始化一个多数据库引擎管理
        :param sqlalchemy 一个SQLAlchemy对象
        """
        self.sqlalchemy = sqlalchemy
        self.app = sqlalchemy.app
        self.get_connect_string_func = None  # type:function

    def get_connect_string_event(self, func):
        """
        这是一个装饰器
        获取连接字符串事件装饰方法即可
        :param func:
        :return:
        """
        self.get_connect_string_func = func
        return func

    def __get_bind_engine(self, key: str) -> Engine:
        """
        根据名字获取连接字符串对应的SqlAlchemy的Engine,用户在Session里切换不同的数据库
        :param key:
        :return:
        """
        # 方式一,直接create_engin 就是bind,当然也可以用字典缓存
        # return create_engine(dict[name])
        # 方式二,使用flask-sqlalchemy的bind配置,利用flash自动缓存engin
        bind_keys = "SQLALCHEMY_BINDS"
        binds = self.app.config[bind_keys]
        if not binds:
            binds = dict()
            self.app.config[bind_keys] = binds
        if key in binds.keys():
            return self.sqlalchemy.get_engine(bind=key)
        connect_string = self.get_connect_string_func(key)
        if connect_string:
            self.app.config[bind_keys][key] = connect_string
            return self.sqlalchemy.get_engine(bind=key)
        raise Exception('找不对{}对应的数据库连接'.format(key))

    def change_database(self, key, session):
        """
        变更数据库
        :param key: 数据库的key
        :param session: 当前Session
        :return:
        """
        # 如果没有装饰过多数据库连接,则跳过
        if not key:
            return
        if not self.get_connect_string_func:
            logging.debug('由于没有监听多数据库连接事件,将不启用多数据库')
            return
        session.bind = self.__get_bind_engine(key)

        from sqlalchemy.orm import sessionmaker
        from sqlalchemy.util import ScopedRegistry
        from flask.globals import _app_ctx_stack

        session_factory = sessionmaker(bind=session.bind, autoflush=False)
        session.session_factory = session_factory
        session.registry = ScopedRegistry(session_factory, _app_ctx_stack.__ident_func__)
