Update README.md
Browse files
README.md
CHANGED
|
@@ -43,7 +43,14 @@ model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
|
|
| 43 |
# Download the vocab_remi.pkl file
|
| 44 |
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
|
| 45 |
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
# Load the tokenizer dictionary
|
| 49 |
with open(tokenizer_path, "rb") as f:
|
|
@@ -57,12 +64,20 @@ model.load_state_dict(torch.load(model_path, map_location=device))
|
|
| 57 |
model.eval()
|
| 58 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
|
|
|
|
|
|
|
| 61 |
inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
|
| 62 |
input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
| 63 |
input_ids = input_ids.to(device)
|
| 64 |
attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
| 65 |
attention_mask = attention_mask.to(device)
|
|
|
|
|
|
|
| 66 |
output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
|
| 67 |
output_list = output[0].tolist()
|
| 68 |
generated_midi = r_tokenizer.decode(output_list)
|
|
|
|
| 43 |
# Download the vocab_remi.pkl file
|
| 44 |
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="vocab_remi.pkl")
|
| 45 |
|
| 46 |
+
if torch.cuda.is_available():
|
| 47 |
+
device = 'cuda'
|
| 48 |
+
elif torch.backends.mps.is_available():
|
| 49 |
+
device = 'mps'
|
| 50 |
+
else:
|
| 51 |
+
device = 'cpu'
|
| 52 |
+
|
| 53 |
+
print(f"Using device: {device}")
|
| 54 |
|
| 55 |
# Load the tokenizer dictionary
|
| 56 |
with open(tokenizer_path, "rb") as f:
|
|
|
|
| 64 |
model.eval()
|
| 65 |
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
|
| 66 |
|
| 67 |
+
print('Model loaded.')
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Enter the text prompt and tokenize it
|
| 71 |
src = "A melodic electronic song with ambient elements, featuring piano, acoustic guitar, alto saxophone, string ensemble, and electric bass. Set in G minor with a 4/4 time signature, it moves at a lively Presto tempo. The composition evokes a blend of relaxation and darkness, with hints of happiness and a meditative quality."
|
| 72 |
+
print('Generating for prompt: ' + src)
|
| 73 |
+
|
| 74 |
inputs = tokenizer(src, return_tensors='pt', padding=True, truncation=True)
|
| 75 |
input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
|
| 76 |
input_ids = input_ids.to(device)
|
| 77 |
attention_mask =nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
|
| 78 |
attention_mask = attention_mask.to(device)
|
| 79 |
+
|
| 80 |
+
# Generate the output tokens and decode them to MIDI
|
| 81 |
output = model.generate(input_ids, attention_mask, max_len=2000,temperature = 1.0)
|
| 82 |
output_list = output[0].tolist()
|
| 83 |
generated_midi = r_tokenizer.decode(output_list)
|