Update app.py
Browse files
app.py
CHANGED
|
@@ -100,7 +100,7 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
|
|
| 100 |
if is_stopped:
|
| 101 |
return pd.DataFrame(), "output.csv"
|
| 102 |
|
| 103 |
-
_, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq,
|
| 104 |
idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
|
| 105 |
idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
|
| 106 |
attn_idx = torch.tensor(attn_idx).to(device)
|
|
@@ -115,7 +115,8 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
|
|
| 115 |
predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
|
| 116 |
input_text[mask_position.item()] = predicted_token
|
| 117 |
padded_seq[mask_position.item()] = predicted_token.strip()
|
| 118 |
-
|
|
|
|
| 119 |
generated_seq = input_text
|
| 120 |
|
| 121 |
generated_seq[1] = "[MASK]"
|
|
|
|
| 100 |
if is_stopped:
|
| 101 |
return pd.DataFrame(), "output.csv"
|
| 102 |
|
| 103 |
+
_, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
|
| 104 |
idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
|
| 105 |
idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
|
| 106 |
attn_idx = torch.tensor(attn_idx).to(device)
|
|
|
|
| 115 |
predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
|
| 116 |
input_text[mask_position.item()] = predicted_token
|
| 117 |
padded_seq[mask_position.item()] = predicted_token.strip()
|
| 118 |
+
new_seq = padded_seq
|
| 119 |
+
|
| 120 |
generated_seq = input_text
|
| 121 |
|
| 122 |
generated_seq[1] = "[MASK]"
|