#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Project      : MeUtils.
# @File         : ann_v1
# @Time         : 2021/4/17 12:22 下午
# @Author       : yuanjie
# @WeChat       : 313303303
# @Software     : PyCharm
# @Description  :
import copy
from meutils.pipe import *
from milvus import Milvus, DataType, MetricType, IndexType


class ANN(object):
    """Thin helper around a ``milvus.Milvus`` client.

    Accessing an unknown, non-dunder attribute (``ann.my_collection``) returns a
    ``Collection`` bound to this client, so per-collection calls can be written
    as ``ann.my_collection.search(...)``.
    """

    def __init__(self, host='10.46.242.23', port='19530', pool="SingletonThread", show_info=False):
        """Connect to a Milvus server and log the server and client versions.

        :param host: server address.
        :param port: server port; 19530 selects the GRPC handler, any other port
            the HTTP handler (19121 is the HTTP port noted below).
        :param pool: connection-pool type passed through to ``Milvus``.
        :param show_info: accepted for backward compatibility; currently unused.
        """
        self.host = host
        self.client = Milvus(
            host, port,
            # Compare as strings so an int port (19530) also selects GRPC.
            handler='GRPC' if str(port) == '19530' else 'HTTP',  # 19530, 19121
            pool=pool,  # thread pool
        )
        logger.info(f"ServerVersion {self.client.server_version()}")
        logger.info(f"ClientVersion {self.client.client_version()}")

    def __getattr__(self, collection_name) -> "Collection":
        """Return a ``Collection`` named ``collection_name`` bound to this client.

        Only called for names that normal attribute lookup did not find.
        Dunder names raise ``AttributeError`` so that protocol probes
        (copy/pickle/``hasattr``) are not mistaken for collections. A missing
        ``client`` (e.g. the instance was built without ``__init__``) also
        raises ``AttributeError``, instead of recursing through ``self.client``.
        """
        if collection_name.startswith('__') and collection_name.endswith('__'):
            raise AttributeError(collection_name)
        client = self.__dict__.get('client')
        if client is None:
            raise AttributeError(collection_name)
        return Collection(collection_name, client)

    def create_collection(self, collection_name='TEST',
                          dim=768,
                          index_file_size=1024,
                          metric_type='IP',
                          index_type='IVF_FLAT',
                          nlist=2048,
                          partitions=None,
                          overwrite=True):
        """Create a collection, build its index and optionally its partitions.

        :param collection_name: name of the collection to create.
        :param dim: vector dimension.
        :param index_file_size: ``index_file_size`` collection parameter.
        :param metric_type: member name of ``MetricType`` (e.g. 'IP').
        :param index_type: member name of ``IndexType`` (e.g. 'IVF_FLAT').
        :param nlist: ``nlist`` index parameter.
        :param partitions: iterable of partition tags to create, or None.
        :param overwrite: if the collection exists, drop it first; when False,
            return a message instead of creating anything.
        :return: ``"<name> already exists!"`` when it exists and ``overwrite``
            is False, otherwise None.
        """
        # has_collection returns a (status, exists) pair. The tuple itself is
        # always truthy, so the flag has to be unpacked before testing it.
        _, exists = self.client.has_collection(collection_name)
        if exists:
            if overwrite:
                logger.warning(f"{collection_name} already exists! to drop.")
                self.client.drop_collection(collection_name, timeout=300)
            else:
                return f"{collection_name} already exists!"

        # Build the collection. Enum members are looked up by name with [];
        # calling MetricType.__getattr__ relied on an EnumMeta internal that
        # newer Python versions no longer provide.
        params = {
            'collection_name': collection_name,
            'dimension': dim,
            'metric_type': MetricType[metric_type],
            'index_type': IndexType[index_type],
            'index_file_size': index_file_size,
            'index_param': {'nlist': nlist},
        }

        self.client.create_collection(params)
        self.client.create_index(collection_name, params['index_type'], params['index_param'])

        # Partitions: dropping a partition later does not require rebuilding the collection.
        if partitions is not None:
            for part in partitions:
                self.client.create_partition(collection_name, partition_tag=part)

        logger.info(f"{self.client.get_collection_info(collection_name)}")


class Collection(object):
    """Bind a Milvus client to a single collection name.

    Each method forwards to the client call of the same name, passing
    ``self.name`` as the collection argument.
    """

    def __init__(self, name=None, client=None):
        self.name = name  # collection name passed to every client call
        self.client = client  # Milvus client (or compatible object)

    def __str__(self):
        # has_collection returns (status, exists); warn if missing, but still render.
        _, ok = self.client.has_collection(self.name)
        if not ok:
            logger.warning(f"{self.name}  doesn't exist")
        return f"Collection({self.name})"

    def insert(self, df_vec, batch_size=100000, partition_tag=None):
        """Insert vectors in batches and return the ids the client reports.

        :param df_vec: table with a 'vector' column and an optional 'id'
            column; assumed to be a pandas DataFrame (sliced with ``.iloc``).
        :param batch_size: maximum number of rows sent per ``client.insert`` call.
        :param partition_tag: target partition, or None.
        :return: int64 numpy array of ids aligned with the rows of ``df_vec``.
        :raises ValueError: if ``df_vec`` has no 'vector' column.
        :raises RuntimeError: if a batch's returned status is not OK.
        """
        # todo: multi-process insertion
        if 'vector' not in df_vec:
            raise ValueError("df_vec must contain a 'vector' column")

        total = len(df_vec)
        with_ids = 'id' in df_vec  # identical for every batch, so check once
        # int64 rather than float: floats cannot hold integer ids above 2**53 exactly.
        ids = np.zeros(total, dtype=np.int64)

        for start in tqdm(range(0, total, batch_size)):
            df = df_vec.iloc[start:start + batch_size]
            # client.insert returns a (status, ids) pair, like the other client calls here.
            status, batch_ids = self.client.insert(
                self.name,
                df['vector'].tolist(),
                ids=df['id'].tolist() if with_ids else None,
                partition_tag=partition_tag,
            )
            if not status.OK():
                raise RuntimeError(f"insert into {self.name} failed: {status}")
            # Each batch writes at its own start offset in the result array.
            ids[start:start + len(df)] = batch_ids

        return ids

    def search(self, topk, vectors, nprobe=64, partition_tags=None):
        """Search ``topk`` neighbours of ``vectors``; returns the results, dropping the status."""
        status, results = self.client.search(
            self.name, topk, vectors,
            partition_tags=partition_tags,
            params={"nprobe": nprobe},
        )
        return results

    def create_partition(self, partition_tag):
        """Create partition ``partition_tag`` in this collection."""
        return self.client.create_partition(self.name, partition_tag)

    def drop_partition(self, partition_tag):
        """Drop partition ``partition_tag`` from this collection."""
        return self.client.drop_partition(self.name, partition_tag)

    def get_entity_by_id(self, ids, fields=None):
        # NOTE(review): ``fields`` is forwarded positionally as the third argument;
        # confirm it matches the installed client's get_entity_by_id signature.
        return self.client.get_entity_by_id(self.name, ids, fields)

    def delete_entity_by_id(self, ids):
        """Delete the entities with the given ids."""
        return self.client.delete_entity_by_id(self.name, ids)

    @property
    def count(self):
        # Raw client response (not unwrapped with [1], unlike info/stats).
        return self.client.count_entities(self.name)

    @property
    def info(self):
        """Collection info: second element of the client's response."""
        return self.client.get_collection_info(self.name)[1]

    @property
    def stats(self):
        """Collection stats: second element of the client's response."""
        return self.client.get_collection_stats(self.name)[1]

    @property
    def partitions(self):
        # Raw client response (not unwrapped).
        return self.client.list_partitions(self.name)
