#!/usr/bin/env bash
# Pretraining launcher for MERT (fairseq + hydra), single- or multi-node.
# Usage:
#   bash run_training_sglNodes.sh <worker_rank> <platform> <yaml_name> <training_setting> <master_addr> <port>
# e.g.
#   bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node

# Rank of this distributed node worker.
# NOTE(review): the original comment says that with two nodes of 4 GPUs each,
# WORKER_RANK should be 0 and 4 (the starting GPU index), but the launch code
# multiplies it by NPROCES_PER_NODE — so it is treated as the NODE index
# (0, 1, ...). Confirm against how callers invoke this script.
WORKER_RANK=${1:-'0'}
PLATFORM=${2:-'shef'}
YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'}
TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'}
MASTER_PROC_ADD=${5:-'127.0.0.1'}
DIST_PORT=${6:-'39685'}

echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}"

# Project root: the directory the script is launched from.
MAP_PROJ_DIR=$(pwd)
echo "${MAP_PROJ_DIR}"

# Defaults; individual configs override these in the case block below.
NNODS=1
# Fixed: NPROCES_PER_NODE previously had no default, so configs that did not
# set it (the 95M ones) broke the world-size arithmetic with an empty operand.
NPROCES_PER_NODE=8
MAX_TOKENS=1000000 # batch size tuned for an 80GB A100
NUM_WOKERS=0
run_command_prefix=' '

# Loading folders
# 1. tsv files for audio paths
DATA_DIR=${MAP_PROJ_DIR}/data/music4all_sh #audio_manifest
# 2. working folder for saving checkpoints and loading config files
# (fixed: a stray leading '/' used to produce a '//…' path)
CONFIG_DIR=${MAP_PROJ_DIR}/mert_fairseq/config/pretrain
# 3. clustering labels for training data
LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset
FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq
SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/

# set 75 for the RVQ-VAE model
LABEL_RATE=75
# Per-config overrides: which clustered-label streams to train against,
# node count, label frame rate, and max tokens per batch.
case "$YAML_NAME_WITHOUT_EXT" in
  MERT_RVQ-VAE_CQT_95M)
    TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
    NNODS=1
    LABEL_RATE=75
    MAX_TOKENS=1800000
    ;;
  MERT_RVQ-VAE_CQT_95M_bestrq)
    TASK_LABELS_POSTFIX='["rq_0"]'
    NNODS=1
    LABEL_RATE=75
    MAX_TOKENS=1200000
    ;;
  MERT_RVQ-VAE_CQT_330M)
    TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
    NNODS=1
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=720000
    ;;
  MERT_RVQ-VAE_CQT_330M_multinodes)
    TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
    NNODS=4
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=600000
    ;;
  MERT_RVQ-VAE_CQT_330M_multinodes_debug2node)
    TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
    NNODS=2
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=600000
    ;;
  MERT_RVQ-VAE_CQT_330M_multinodes_debug1node)
    TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
    NNODS=1
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=600000
    ;;
  *)
    # Fixed: '${$YAML_NAME_WITHOUT_EXT}' was a bash "bad substitution" error,
    # so the error branch itself crashed before printing the config name.
    echo "Unknown running config: ${YAML_NAME_WITHOUT_EXT}" >&2
    exit 1
    ;;
esac
echo "running ${YAML_NAME_WITHOUT_EXT} .."

mkdir -p "${SAVE_DIR}"
echo "checkpoint save at: ${SAVE_DIR}"
cd "${SAVE_DIR}" || exit 1

# Total processes across all nodes, and this node's starting global rank.
DISTRIBUTED_WORLD_SIZE=$(( NNODS * NPROCES_PER_NODE ))
ACTUAL_WORKER_RANK=$(( WORKER_RANK * NPROCES_PER_NODE ))
echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}"

DATE_SUFFIX=$(date +"%Y-%m-%d_%H-%M")
CKPT_SAVE_DIR="${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT}"

# Launch one process per GPU via torch.distributed.
# Fixes vs. the previous version:
#   - --node_rank used undefined ${INDEX}  -> ${WORKER_RANK}
#   - --master_addr used undefined ${CHIEF_IP} -> ${MASTER_PROC_ADD}
#   - --master_port was hard-coded 25521   -> ${DIST_PORT} (as parsed/echoed)
#   - --nproc_per_node was hard-coded 8    -> ${NPROCES_PER_NODE}
#   - task.label_dir used undefined ${LABEL_DIR} -> ${LABEL_ROOT_DIR}
#     (the label directory prepared above; TODO confirm against the configs)
OMP_NUM_THREADS=6 ${run_command_prefix} \
  python -u -m torch.distributed.launch --use_env \
  --nproc_per_node="${NPROCES_PER_NODE}" --nnodes="${NNODS}" --node_rank="${WORKER_RANK}" \
  --master_addr="${MASTER_PROC_ADD}" --master_port="${DIST_PORT}" \
  "${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py" -m \
  --config-dir "${CONFIG_DIR}" --config-name "${YAML_NAME_WITHOUT_EXT}" \
  common.user_dir="${MAP_PROJ_DIR}/mert_fairseq" \
  common.tensorboard_logdir="${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS}" \
  task.data="${DATA_DIR}" \
  task.label_dir="${LABEL_ROOT_DIR}" \
  task.labels="${TASK_LABELS_POSTFIX}" \
  dataset.num_workers="${NUM_WOKERS}" \
  dataset.max_tokens="${MAX_TOKENS}" \
  dataset.disable_validation=true \
  model.label_rate="${LABEL_RATE}" \
  checkpoint.save_dir="${CKPT_SAVE_DIR}" \
  checkpoint.restore_file="checkpoint_last.pt"