Commit 
							
							·
						
						3d06b91
	
1
								Parent(s):
							
							9ddbf48
								
Add unfat config
Browse files- .gitignore +5 -0
 - main.py +157 -0
 - pyproject.toml +10 -0
 
    	
        .gitignore
    ADDED
    
    | 
         @@ -0,0 +1,5 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            /output
         
     | 
| 2 | 
         
            +
            __pycache__
         
     | 
| 3 | 
         
            +
            .vim
         
     | 
| 4 | 
         
            +
            /dist
         
     | 
| 5 | 
         
            +
            unfat.egg-info
         
     | 
    	
        main.py
    ADDED
    
    | 
         @@ -0,0 +1,157 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import os
         
     | 
| 2 | 
         
            +
            import re
         
     | 
| 3 | 
         
            +
            from typing import cast, Any
         
     | 
| 4 | 
         
            +
            from datasets import load_dataset, Dataset as HfDataset
         
     | 
| 5 | 
         
            +
            from unfat.extract import Extractor
         
     | 
| 6 | 
         
            +
            from unfat.client import OpenAiCompatClient
         
     | 
| 7 | 
         
            +
            from unfat.datasets import Dataset, Prompts, hub_prompts, HubSplit
         
     | 
| 8 | 
         
            +
            from unfat.together import llama_3_1_70b_together
         
     | 
| 9 | 
         
            +
            from unfat.lora import LoraSettings
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            def gen_prompts(
         
     | 
| 12 | 
         
            +
                ds_name: str,
         
     | 
| 13 | 
         
            +
                text_field: str,
         
     | 
| 14 | 
         
            +
                start_regex: re.Pattern | None = None,
         
     | 
| 15 | 
         
            +
                end_regex: re.Pattern | None = None,
         
     | 
| 16 | 
         
            +
            ):
         
     | 
| 17 | 
         
            +
                ds = cast(HfDataset, load_dataset(ds_name, split="train"))
         
     | 
| 18 | 
         
            +
                def items():
         
     | 
| 19 | 
         
            +
                    for row in ds:
         
     | 
| 20 | 
         
            +
                        casted = cast(dict[Any, Any], row)
         
     | 
| 21 | 
         
            +
                        text = casted[text_field]
         
     | 
| 22 | 
         
            +
                        if start_regex and end_regex:
         
     | 
| 23 | 
         
            +
                            yield end_regex.sub("", start_regex.sub("", text))
         
     | 
| 24 | 
         
            +
                        elif start_regex:
         
     | 
| 25 | 
         
            +
                            yield start_regex.sub("", text)
         
     | 
| 26 | 
         
            +
                        elif end_regex:
         
     | 
| 27 | 
         
            +
                            yield end_regex.sub("", text)
         
     | 
| 28 | 
         
            +
                        else:
         
     | 
| 29 | 
         
            +
                            yield text
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                return Prompts(
         
     | 
| 32 | 
         
            +
                    output_path=f"hub/{ds_name}.jsonl",
         
     | 
| 33 | 
         
            +
                    count=lambda: len(ds),
         
     | 
| 34 | 
         
            +
                    items=items,
         
     | 
| 35 | 
         
            +
                )
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
            def extract_prompts_from_convos(
         
     | 
| 38 | 
         
            +
                ds_name: str,
         
     | 
| 39 | 
         
            +
                messages_field: str,
         
     | 
| 40 | 
         
            +
                role_field: str,
         
     | 
| 41 | 
         
            +
                content_field: str,
         
     | 
| 42 | 
         
            +
                user_role: str,
         
     | 
| 43 | 
         
            +
            ):
         
     | 
| 44 | 
         
            +
                ds = cast(HfDataset, load_dataset(ds_name, split="train"))
         
     | 
| 45 | 
         
            +
                def items():
         
     | 
| 46 | 
         
            +
                    for row in ds:
         
     | 
| 47 | 
         
            +
                        casted = cast(dict[Any, Any], row)
         
     | 
| 48 | 
         
            +
                        for message in casted[messages_field]:
         
     | 
| 49 | 
         
            +
                            if message[role_field] == user_role:
         
     | 
| 50 | 
         
            +
                                yield message[content_field]
         
     | 
| 51 | 
         
            +
                                break
         
     | 
| 52 | 
         
            +
                return Prompts(
         
     | 
| 53 | 
         
            +
                    output_path=f"hub/{ds_name}.jsonl",
         
     | 
| 54 | 
         
            +
                    count=lambda: len(ds),
         
     | 
| 55 | 
         
            +
                    items=items,
         
     | 
| 56 | 
         
            +
                )
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            def main():
         
     | 
| 59 | 
         
            +
                output_dir = "output"
         
     | 
| 60 | 
         
            +
                rp_english = extract_prompts_from_convos(
         
     | 
| 61 | 
         
            +
                    ds_name="OdiaGenAI/roleplay_english",
         
     | 
| 62 | 
         
            +
                    messages_field="conversations",
         
     | 
| 63 | 
         
            +
                    role_field="from",
         
     | 
| 64 | 
         
            +
                    content_field="value",
         
     | 
| 65 | 
         
            +
                    user_role="user",
         
     | 
| 66 | 
         
            +
                )
         
     | 
| 67 | 
         
            +
                bluemoon = extract_prompts_from_convos(
         
     | 
| 68 | 
         
            +
                    ds_name="xDAN2099/RolePlay-Mixed-Bluemoon-Limarp",
         
     | 
| 69 | 
         
            +
                    messages_field="conversations",
         
     | 
| 70 | 
         
            +
                    role_field="from",
         
     | 
| 71 | 
         
            +
                    content_field="value",
         
     | 
| 72 | 
         
            +
                    user_role="human",
         
     | 
| 73 | 
         
            +
                )
         
     | 
| 74 | 
         
            +
                roleplay_prompts = gen_prompts(
         
     | 
| 75 | 
         
            +
                    ds_name="AlekseyKorshuk/roleplay-io",
         
     | 
| 76 | 
         
            +
                    text_field="input_text",
         
     | 
| 77 | 
         
            +
                    start_regex=re.compile(r'^User: '),
         
     | 
| 78 | 
         
            +
                    end_regex=re.compile(r'Bot:\s*$'),
         
     | 
| 79 | 
         
            +
                )
         
     | 
| 80 | 
         
            +
                roleplay_instr_prompts = gen_prompts(
         
     | 
| 81 | 
         
            +
                    ds_name="iamketan25/roleplay-instructions-dataset",
         
     | 
| 82 | 
         
            +
                    text_field="prompt",
         
     | 
| 83 | 
         
            +
                    start_regex=re.compile(r'^Human: '),
         
     | 
| 84 | 
         
            +
                    end_regex=re.compile(r'Assistant:\s*$'),
         
     | 
| 85 | 
         
            +
                )
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                extractor = Extractor(
         
     | 
| 88 | 
         
            +
                    max_concurrent=50,
         
     | 
| 89 | 
         
            +
                    output_dir=output_dir,
         
     | 
| 90 | 
         
            +
                    client=OpenAiCompatClient(
         
     | 
| 91 | 
         
            +
                        base_url="https://glhf.chat/api/openai/v1",
         
     | 
| 92 | 
         
            +
                        api_key=os.environ["GLHF_API_KEY"],
         
     | 
| 93 | 
         
            +
                        model="hf:TheDrummer/Behemoth-123B-v1.2",
         
     | 
| 94 | 
         
            +
                        retries=20,
         
     | 
| 95 | 
         
            +
                    ),
         
     | 
| 96 | 
         
            +
                    dataset=Dataset(
         
     | 
| 97 | 
         
            +
                        train=[
         
     | 
| 98 | 
         
            +
                            hub_prompts(
         
     | 
| 99 | 
         
            +
                                name="mlabonne/harmful_behaviors",
         
     | 
| 100 | 
         
            +
                                text_field="text",
         
     | 
| 101 | 
         
            +
                                split="train",
         
     | 
| 102 | 
         
            +
                            ),
         
     | 
| 103 | 
         
            +
                            roleplay_instr_prompts,
         
     | 
| 104 | 
         
            +
                            roleplay_prompts,
         
     | 
| 105 | 
         
            +
                            rp_english,
         
     | 
| 106 | 
         
            +
                            bluemoon,
         
     | 
| 107 | 
         
            +
                            hub_prompts(
         
     | 
| 108 | 
         
            +
                                name="TheDrummer/AmoralQA-v2",
         
     | 
| 109 | 
         
            +
                                text_field="prompt",
         
     | 
| 110 | 
         
            +
                                split="train",
         
     | 
| 111 | 
         
            +
                            ),
         
     | 
| 112 | 
         
            +
                            hub_prompts(
         
     | 
| 113 | 
         
            +
                                name="vicgalle/OpenHermesPreferences-roleplay",
         
     | 
| 114 | 
         
            +
                                text_field="prompt",
         
     | 
| 115 | 
         
            +
                                split="train",
         
     | 
| 116 | 
         
            +
                            ),
         
     | 
| 117 | 
         
            +
                            hub_prompts(
         
     | 
| 118 | 
         
            +
                                name="mrcuddle/DPO_Pairs_Roleplay-Alpaca",
         
     | 
| 119 | 
         
            +
                                text_field="prompt",
         
     | 
| 120 | 
         
            +
                                split="train",
         
     | 
| 121 | 
         
            +
                            ),
         
     | 
| 122 | 
         
            +
                            hub_prompts(
         
     | 
| 123 | 
         
            +
                                name="ResplendentAI/theory_of_mind_fixed_output",
         
     | 
| 124 | 
         
            +
                                text_field="instruction",
         
     | 
| 125 | 
         
            +
                                split="train",
         
     | 
| 126 | 
         
            +
                            ),
         
     | 
| 127 | 
         
            +
                            hub_prompts(
         
     | 
| 128 | 
         
            +
                                name="mlabonne/harmless_alpaca",
         
     | 
| 129 | 
         
            +
                                text_field="text",
         
     | 
| 130 | 
         
            +
                                split=HubSplit(name="train", max_rows=1000),
         
     | 
| 131 | 
         
            +
                            ),
         
     | 
| 132 | 
         
            +
                        ],
         
     | 
| 133 | 
         
            +
                    ),
         
     | 
| 134 | 
         
            +
                )
         
     | 
| 135 | 
         
            +
                extractor.run()
         
     | 
| 136 | 
         
            +
                dataset = extractor.output_dataset()
         
     | 
| 137 | 
         
            +
                together_config = llama_3_1_70b_together(
         
     | 
| 138 | 
         
            +
                    output_dir=output_dir,
         
     | 
| 139 | 
         
            +
                    dataset=dataset,
         
     | 
| 140 | 
         
            +
                    api_key=os.environ["TOGETHER_API_KEY"],
         
     | 
| 141 | 
         
            +
                    settings=LoraSettings(
         
     | 
| 142 | 
         
            +
                        rank=32,
         
     | 
| 143 | 
         
            +
                        alpha=16,
         
     | 
| 144 | 
         
            +
                        dropout=0.01,
         
     | 
| 145 | 
         
            +
                        num_epochs=2,
         
     | 
| 146 | 
         
            +
                        learning_rate=4e-4,
         
     | 
| 147 | 
         
            +
                        evals_per_epoch=0,
         
     | 
| 148 | 
         
            +
                        wandb_project="behemoth-distill",
         
     | 
| 149 | 
         
            +
                        wandb_api_key=os.environ["WANDB_API_KEY"],
         
     | 
| 150 | 
         
            +
                    )
         
     | 
| 151 | 
         
            +
                )
         
     | 
| 152 | 
         
            +
                files = together_config.upload_files()
         
     | 
| 153 | 
         
            +
                together_config.finetune(files)
         
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
             
     | 
| 156 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 157 | 
         
            +
                main()
         
     | 
    	
        pyproject.toml
    ADDED
    
    | 
         @@ -0,0 +1,10 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            [project]
         
     | 
| 2 | 
         
            +
            name = "behemoth-lora"
         
     | 
| 3 | 
         
            +
            version = "0.1.0"
         
     | 
| 4 | 
         
            +
            description = "Add your description here"
         
     | 
| 5 | 
         
            +
            readme = "README.md"
         
     | 
| 6 | 
         
            +
            requires-python = ">=3.11"
         
     | 
| 7 | 
         
            +
            dependencies = [
         
     | 
| 8 | 
         
            +
                "datasets>=3.3.2",
         
     | 
| 9 | 
         
            +
                "unfat>=0.0.13",
         
     | 
| 10 | 
         
            +
            ]
         
     |