Commit 8c28418 · Parent(s): 2833bac
Update app.py

app.py CHANGED
@@ -39,16 +39,16 @@ def generate(
     ctx,
     image_features,
     token_count=200,
-    temperature=0
+    temperature=1.0,
     top_p=0.3,
     presencePenalty = 0.1,
     countPenalty = 0.1,
 ):
     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
-
-
-
-
+                     alpha_frequency = countPenalty,
+                     alpha_presence = presencePenalty,
+                     token_ban = [], # ban the generation of some tokens
+                     token_stop = [0, 261]) # stop generation whenever you see any token here
     ctx = ctx.strip()
     all_tokens = []
     out_last = 0
@@ -56,9 +56,11 @@ def generate(
     occurrence = {}
     for i in range(int(token_count)):
         if i == 0:
+            prefix_ids = pipeline.encode("User: ")
+            prefix_embs = model.w['emb.weight'][prefix_ids]
             input_ids = pipeline.encode(ctx)
             text_embs = model.w['emb.weight'][input_ids]
-            input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
+            input_embs = torch.cat((prefix_embs, image_features, text_embs), dim=0)[-ctx_limit:]
             out, state = model.forward(embs=input_embs, state=None)
         else:
             input_ids = [token]
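
The first hunk raises the default temperature to 1.0 (still clamped to at least 0.2 by the max() call) and spells out the remaining PIPELINE_ARGS: a frequency penalty, a presence penalty, an empty token_ban list, and token_stop = [0, 261] so generation ends when either stop id appears. The snippet below is a minimal, self-contained sketch of how these fields are commonly applied per step in RWKV-style sampling loops; it is not code from this Space, and the vocabulary size, occurrence counts, and greedy argmax pick are illustrative assumptions.

import torch

# Sketch (assumption, not this Space's code): applying the sampling arguments
# added in this commit at one decoding step.
alpha_presence, alpha_frequency = 0.1, 0.1   # presencePenalty / countPenalty in the diff
token_ban, token_stop = [], [0, 261]         # token_stop matches the new value in the commit

vocab_size = 65536
out = torch.randn(vocab_size)                # stand-in for the model's logits at this step
occurrence = {11: 3, 42: 1}                  # token id -> times generated so far (dummy values)

for tok in token_ban:
    out[tok] = -float("inf")                 # make banned tokens unselectable
for tok, count in occurrence.items():
    # flat cost for having appeared at all, plus a cost that grows with repetition count
    out[tok] -= alpha_presence + count * alpha_frequency

token = int(torch.argmax(out))               # greedy pick here; the app samples with top_p instead
if token in token_stop:                      # e.g. end-of-text (0) or the newly added 261
    print("stop token generated; end of answer")
else:
    occurrence[token] = occurrence.get(token, 0) + 1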
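
The second hunk changes how the very first forward pass is fed: the prompt embeddings now begin with an encoded "User: " prefix, followed by the image features and then the text embeddings, and the concatenation is still truncated to the last ctx_limit positions. Below is a small sketch of just that concatenation order; the tensor shapes, emb_dim, and ctx_limit are made-up stand-ins for the real encoder outputs and model dimensions.

import torch

# Sketch of the new input assembly order: "User: " prefix, image features, text prompt.
ctx_limit = 2048
emb_dim = 512

prefix_embs = torch.randn(4, emb_dim)        # embeddings of the tokenized "User: " prefix
image_features = torch.randn(256, emb_dim)   # projected image tokens from the vision encoder
text_embs = torch.randn(32, emb_dim)         # embeddings of the user's text prompt

# After the commit: prefix first, then image, then text; keep only the last ctx_limit rows.
input_embs = torch.cat((prefix_embs, image_features, text_embs), dim=0)[-ctx_limit:]
print(input_embs.shape)                      # torch.Size([292, 512]) with the sizes above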