import csv
import os

import yaml

from macls.trainer import StealTrainer
from macls.predict import MAClsPredictor


def create_UrbanSound8K_list(audio_path, metadata_path, list_path):
    """Write train/test/label list files from the UrbanSound8K metadata CSV.

    Every data row of the metadata CSV is written to ``train_list.txt``.
    Every 50th data row (the 1st, 51st, 101st, ...) is additionally written
    to ``test_list.txt``, so the test list is a subset of the training list,
    not a disjoint hold-out. ``label_list.txt`` receives one class name per
    line, where line ``i`` holds the name of class id ``i``.

    Each train/test line is tab-separated: ``<audio file path>\\t<class id>``,
    where the path is ``<audio_path>/fold<fold>/<slice_file_name>`` with
    backslashes normalised to forward slashes.

    Args:
        audio_path: Directory containing the ``fold*`` sub-directories.
        metadata_path: Path to the metadata CSV. The first row is treated as
            a header; columns 0, 5 and 6 and the last column are read as the
            file name, fold number, class id and class name respectively.
            Blank lines are ignored.
        list_path: Output directory; created if it does not exist.

    Raises:
        ValueError: If the class ids found are not exactly ``0 .. n-1``,
            which would make the line-indexed label list inconsistent.
    """
    os.makedirs(list_path, exist_ok=True)

    # newline='' is what the csv module expects for files it reads.
    with open(metadata_path, 'r', encoding='utf-8', newline='') as f:
        rows = list(csv.reader(f))

    labels = {}
    sound_sum = 0
    with open(os.path.join(list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train, \
            open(os.path.join(list_path, 'test_list.txt'), 'w', encoding='utf-8') as f_test:
        # rows[0] is the CSV header row.
        for data in rows[1:]:
            if not data:
                # csv.reader yields [] for blank lines; skip them instead of crashing.
                continue
            class_id = int(data[6])
            # Remember the class name the first time each id is seen.
            labels.setdefault(class_id, data[-1])
            sound_path = os.path.join(audio_path, f'fold{data[5]}', data[0]).replace('\\', '/')
            entry = f'{sound_path}\t{data[6]}\n'

            f_train.write(entry)
            # Every 50th sample is also copied into the test list; it is
            # intentionally NOT removed from the training list.
            if sound_sum % 50 == 0:
                f_test.write(entry)
            sound_sum += 1

    # The label list is indexed by line number, so ids must be 0..n-1.
    expected_ids = list(range(len(labels)))
    found_ids = sorted(labels)
    if found_ids != expected_ids:
        raise ValueError(
            f'class ids must be contiguous starting at 0, got {found_ids}')

    with open(os.path.join(list_path, 'label_list.txt'), 'w', encoding='utf-8') as f_label:
        f_label.writelines(f'{labels[i]}\n' for i in expected_ids)


def _apply_config_overrides(cfg, lr=None, max_epoch=None, batch_size=None):
    """Apply optional hyper-parameter overrides to a loaded config dict in place.

    Args:
        cfg: Config dict as loaded from the training YAML.
        lr: If not None, written to ``optimizer_conf.optimizer_args.lr`` and,
            when the scheduler args already contain ``max_lr``, to that too.
        max_epoch: If not None, written to ``train_conf.max_epoch``.
        batch_size: If not None, written to ``dataset_conf.dataLoader.batch_size``.

    Returns:
        The same ``cfg`` object, for convenience.
    """
    if lr is not None:
        lr = float(lr)
        cfg['optimizer_conf']['optimizer_args']['lr'] = lr
        scheduler_args = cfg['optimizer_conf'].get('scheduler_args')
        # Only update max_lr when the configured scheduler already declares it.
        # NOTE(review): scheduler_args are presumably passed as keyword args to
        # the scheduler, so injecting an unexpected max_lr could break other
        # schedulers — confirm against the trainer.
        if isinstance(scheduler_args, dict) and 'max_lr' in scheduler_args:
            scheduler_args['max_lr'] = lr
    if max_epoch is not None:
        cfg['train_conf']['max_epoch'] = int(max_epoch)
    if batch_size is not None:
        cfg['dataset_conf']['dataLoader']['batch_size'] = int(batch_size)
    return cfg


def run_steal_training(dataset_path,
                       save_model_path,
                       blackbox_model_path,
                       configs='configs/cam++.yml',
                       data_augment_configs='configs/augmentation.yml',
                       log_dir='log_steal',
                       use_gpu=False,
                       lr=None,
                       max_epoch=None,
                       batch_size=None,
                       blackbox_configs='/home/iceboy/Shtrain/AudioClassification-Pytorch/configs/eres2net.yml'):
    """Train a substitute model by querying a black-box audio classifier.

    Steps: build UrbanSound8K list files under ``./dataset``, apply the
    optional overrides to the training config and write it to
    ``configs/tmp/cam_tmp.yml`` (overwritten on every call), construct a
    ``StealTrainer`` and a black-box ``MAClsPredictor``, then run training.

    Args:
        dataset_path: Dataset root directory
            (e.g. /home/iceboy/Shtrain/UrbanSound8K); must contain ``audio/``
            and ``metadata/UrbanSound8K.csv``.
        save_model_path: Directory where the trained model is saved
            (e.g. models_steal).
        blackbox_model_path: Directory of the black-box model (e.g.
            /home/iceboy/Shtrain/AudioClassification-Pytorch/models/ERes2Net_Fbank/best_model/).
        configs: Training config file path (default configs/cam++.yml).
        data_augment_configs: Data augmentation config file.
        log_dir: Log directory.
        use_gpu: Whether to use the GPU (applies to trainer and predictor).
        lr: Learning rate (overrides the config file).
        max_epoch: Maximum number of training epochs (overrides the config file).
        batch_size: Batch size (overrides the config file).
        blackbox_configs: Config file used to build the black-box predictor.
            Defaults to the previously hard-coded path for backward
            compatibility.

    Returns:
        Whatever ``StealTrainer.train`` returns.
    """
    # 1. Build the data list files.
    create_UrbanSound8K_list(
        audio_path=os.path.join(dataset_path, 'audio'),
        metadata_path=os.path.join(dataset_path, 'metadata/UrbanSound8K.csv'),
        list_path='dataset'
    )

    # 2. Load the config and apply the overrides.
    with open(configs, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    _apply_config_overrides(cfg, lr=lr, max_epoch=max_epoch, batch_size=batch_size)

    # Save the modified config to a temporary file for the trainer to read.
    os.makedirs('configs/tmp', exist_ok=True)
    tmp_config_path = 'configs/tmp/cam_tmp.yml'
    with open(tmp_config_path, 'w', encoding='utf-8') as f:
        yaml.dump(cfg, f, allow_unicode=True)

    # 3. Build the trainer.
    trainer = StealTrainer(configs=tmp_config_path,
                           use_gpu=use_gpu,
                           data_augment_configs=data_augment_configs,
                           overwrites=None)

    # 4. Build the black-box predictor.
    predictor_steal = MAClsPredictor(
        configs=blackbox_configs,
        model_path=blackbox_model_path,
        use_gpu=use_gpu
    )

    # 5. Start training.
    logs = trainer.train(save_model_path=save_model_path,
                         log_dir=log_dir,
                         resume_model=None,
                         pretrained_model=None,
                         predictor_steal=predictor_steal)
    return logs

# -------------------------
# Usage example
# -------------------------
if __name__ == '__main__':
    # Example settings for a local run; adjust the paths to your machine.
    example_kwargs = dict(
        dataset_path='/home/iceboy/Shtrain/UrbanSound8K',
        save_model_path='models_steal2',
        blackbox_model_path='/home/iceboy/Shtrain/AudioClassification-Pytorch/models/ERes2Net_Fbank/best_model/',
        use_gpu=False,
        lr=0.001,
        max_epoch=80,
        batch_size=32,
    )
    training_logs = run_steal_training(**example_kwargs)
    print(training_logs)