#include "common.h"     // gpt_params and gpt_params_parse from the llama.cpp examples
#include "llama.h"
#include "build-info.h" // assumed: generated by the llama.cpp build, provides BUILD_NUMBER and BUILD_COMMIT

#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    gpt_params params;
    params.seed          = 42;
    params.n_threads     = 4;
    params.repeat_last_n = 64;
    params.prompt        = "The quick brown fox";

    if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

    fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);

    if (params.n_predict < 0) {
        params.n_predict = 16;
    }
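
    // map the parsed options onto the llama_context_params used to create both contexts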
    auto lparams = llama_context_default_params();

    lparams.n_ctx     = params.n_ctx;
    lparams.seed      = params.seed;
    lparams.f16_kv    = params.memory_f16;
    lparams.use_mmap  = params.use_mmap;
    lparams.use_mlock = params.use_mlock;
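
    // n_past counts tokens the context has already seen; last_n_tokens_data accumulates the token
    // history on the application side, separate from the state blob saved below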
    auto n_past = 0;
    auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);

    // init
    auto ctx = llama_init_from_file(params.model.c_str(), lparams);
    auto tokens = std::vector<llama_token>(params.n_ctx);
    auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);

    if (n_prompt_tokens < 1) {
        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
        return 1;
    }
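
    // a single llama_eval call over the prompt tokens builds the KV cache that gets serialized below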
    // evaluate prompt
    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);

    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
    n_past += n_prompt_tokens;
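
    // ask how many bytes the serialized context state needs and allocate a buffer for it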
    const size_t state_size = llama_get_state_size(ctx);
    uint8_t * state_mem = new uint8_t[state_size];

    // Save state (rng, logits, embedding and kv_cache) to file
    {
        FILE *fp_write = fopen("dump_state.bin", "wb");
        llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
        fwrite(state_mem, 1, state_size, fp_write);
        fclose(fp_write);
    }
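
    // the state blob covers rng, logits, embedding and kv_cache, but not the application-side
    // token history, so keep a copy of that (and of n_past) as well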
    // save state (last tokens)
    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
    const auto n_past_saved = n_past;

    // first run
    printf("\n%s", params.prompt.c_str());
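
    // first pass: sample params.n_predict tokens, feeding each one back into the context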
    for (auto i = 0; i < params.n_predict; i++) {
        auto logits = llama_get_logits(ctx);
        auto n_vocab = llama_n_vocab(ctx);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
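        // wrap every vocabulary token and its logit in a llama_token_data_array (sorted = false) and sample from it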
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
        auto next_token = llama_sample_token(ctx, &candidates_p);
        auto next_token_str = llama_token_to_str(ctx, next_token);
        last_n_tokens_data.push_back(next_token);

        printf("%s", next_token_str);
        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            return 1;
        }
        n_past += 1;
    }

    printf("\n\n");
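
    // everything needed to resume now lives in dump_state.bin plus the saved n_past and token
    // history, so the first context can be destroyed entirely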
    // free old model
    llama_free(ctx);

    // load new model
    auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
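
    // the dump is only valid for a context created with the same model and parameters,
    // which is what the state-size check below guards against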
    // Load state (rng, logits, embedding and kv_cache) from file
    {
        FILE *fp_read = fopen("dump_state.bin", "rb");
        if (state_size != llama_get_state_size(ctx2)) {
            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
            return 1;
        }

        const size_t ret = fread(state_mem, 1, state_size, fp_read);
        if (ret != state_size) {
            fprintf(stderr, "\n%s : failed to read state\n", __func__);
            return 1;
        }

        llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
        fclose(fp_read);
    }

    delete[] state_mem;

    // restore state (last tokens)
    last_n_tokens_data = last_n_tokens_data_saved;
    n_past = n_past_saved;
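
    // with rng, kv cache and token history restored, this run should reproduce the first run's continuation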
    // second run
    for (auto i = 0; i < params.n_predict; i++) {
        auto logits = llama_get_logits(ctx2);
        auto n_vocab = llama_n_vocab(ctx2);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
        auto next_token = llama_sample_token(ctx2, &candidates_p);
        auto next_token_str = llama_token_to_str(ctx2, next_token);
        last_n_tokens_data.push_back(next_token);

        printf("%s", next_token_str);
        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            return 1;
        }
        n_past += 1;
    }

    printf("\n\n");

    return 0;
}