import functools
from macls.trainer import StealTrainer
from macls.utils.utils import print_arguments
from macls.predict import MAClsPredictor


def run_steal_attack(
        model_path='/home/iceboy/Shtrain/AudioClassification-Pytorch/models/ERes2Net_Fbank/best_model/',
        resume_model='/home/iceboy/Shtrain/AudioClassification-Pytorch/models_steal/CAMPPlus_Fbank/best_model',
        attack_iters=10,
        configs='configs/cam++.yml',
        data_augment_configs='configs/augmentation.yml',
        local_rank=0,
        use_gpu=False,
        save_model_path='models_steal2/',
        log_dir='log_steal2/',
        pretrained_model=None,
        overwrites=None,
        blackbox_configs='/home/iceboy/Shtrain/AudioClassification-Pytorch/configs/eres2net.yml',
):
    """Run the black-box model parameter-stealing attack experiment.

    Builds a ``StealTrainer`` for the substitute model, wraps the black-box
    model in an ``MAClsPredictor``, and hands both to
    ``StealTrainer.evaluate`` together with the substitute model location and
    the number of attack iterations.

    Args:
        model_path (str): Location of the black-box (victim) model.
        resume_model (str): Location of the substitute model; forwarded to
            ``trainer.evaluate``.
        attack_iters (int): Number of attack iterations.
        configs (str): Config file path used to build the ``StealTrainer``.
        data_augment_configs (str): Data-augmentation config file path.
        local_rank (int): Rank parameter for multi-GPU training.
        use_gpu (bool): Whether to use the GPU (trainer and black-box
            predictor).
        save_model_path (str): Model save path.
        log_dir (str): Log directory.
        pretrained_model (str | None): Pretrained model path.
        overwrites (str | None): Config overrides passed to the trainer.
        blackbox_configs (str): Config file describing the black-box model;
            it must match the model stored at ``model_path``. Defaults to the
            ERes2Net config that used to be hard-coded here.
    """
    from types import SimpleNamespace

    # Bundle the run settings so they can be printed for the experiment log.
    # NOTE(review): local_rank, save_model_path, log_dir and pretrained_model
    # are only printed here; they are not forwarded to the trainer in this
    # function — confirm whether StealTrainer should receive them.
    args = SimpleNamespace(
        configs=configs,
        data_augment_configs=data_augment_configs,
        local_rank=local_rank,
        use_gpu=use_gpu,
        save_model_path=save_model_path,
        log_dir=log_dir,
        # `args` is only used for printing, so record the substitute model
        # actually passed to evaluate() instead of a misleading None.
        resume_model=resume_model,
        pretrained_model=pretrained_model,
        overwrites=overwrites,
    )

    print_arguments(args=args)

    # Trainer for the substitute model.
    trainer = StealTrainer(configs=args.configs,
                           use_gpu=args.use_gpu,
                           data_augment_configs=args.data_augment_configs,
                           overwrites=args.overwrites)

    # Black-box (victim) predictor; its config must describe the model at
    # model_path.
    predictor_steal = MAClsPredictor(
        configs=blackbox_configs,
        model_path=model_path,
        use_gpu=args.use_gpu
    )

    trainer.evaluate(resume_model=resume_model,
                     predictor_steal=predictor_steal,
                     attack_iters=attack_iters)

# Run the experiment only when executed as a script, so that importing this
# module (e.g. to reuse run_steal_attack) does not launch a full attack.
if __name__ == "__main__":
    run_steal_attack(
        model_path='/home/iceboy/Shtrain/AudioClassification-Pytorch/models/ERes2Net_Fbank/best_model/',
        resume_model='/home/iceboy/Shtrain/AudioClassification-Pytorch/models_steal/CAMPPlus_Fbank/best_model',
        attack_iters=20,
        configs='/home/iceboy/Shtrain/AudioClassification-Pytorch/macls/configs/cam++.yml',
        use_gpu=False,
        save_model_path='models_attack/',
        log_dir='logs_attack/'
    )
