Update app.py
app.py CHANGED
@@ -6,12 +6,9 @@ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
 import gradio as gr
 from gradio_rangeslider import RangeSlider
 import time
-import numba
-from numba import objmode
 
 is_stopped = False
 
-@numba.jit(nopython=True)
 def temperature_sampling(logits, temperature):
     logits = logits / temperature
     probabilities = torch.softmax(logits, dim=-1)
@@ -23,29 +20,6 @@ def stop_generation():
     is_stopped = True
     return "Generation stopped."
 
-@numba.jit(nopython=False)
-def generate_sequence(length, vocab_mlm, seq, new_seq, τ, input_text):
-    for i in range(length):
-        if is_stopped:
-            return "output.csv", pd.DataFrame()
-
-        _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-        idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
-        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
-        attn_idx = torch.tensor(attn_idx).to(device)
-        mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
-        mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
-
-        logits = model(idx_seq, idx_msa, attn_idx)
-        mask_logits = logits[0, mask_position.item(), :]
-
-        predicted_token_id = temperature_sampling(mask_logits, τ)
-        predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
-        input_text[mask_position.item()] = predicted_token
-        padded_seq[mask_position.item()] = predicted_token.strip()
-        new_seq = padded_seq
-    return input_text
-
 def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
     if seed =='random':
         seed = random.randint(0,100000)
@@ -131,11 +105,31 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
 
         padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
         input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
+
         gen_length = len(input_text)
         length = gen_length - sum(1 for x in input_text if x != '[MASK]')
-
-
-
+        for i in range(length):
+            if is_stopped:
+                return "output.csv", pd.DataFrame()
+
+            _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+            idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
+            idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
+            attn_idx = torch.tensor(attn_idx).to(device)
+
+            mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
+            mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
+
+            logits = model(idx_seq, idx_msa, attn_idx)
+            mask_logits = logits[0, mask_position.item(), :]
+
+            predicted_token_id = temperature_sampling(mask_logits, τ)
+
+            predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
+            input_text[mask_position.item()] = predicted_token
+            padded_seq[mask_position.item()] = predicted_token.strip()
+            new_seq = padded_seq
+            generated_seq = input_text
 
         generated_seq[1] = "[MASK]"
         input_ids = vocab_mlm.__getitem__(generated_seq)
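
The hunk above truncates temperature_sampling right after the softmax. For reference, a minimal self-contained sketch of what such a helper typically looks like; the torch.multinomial draw and the integer return value are assumptions, not part of this commit:

import torch

def temperature_sampling(logits, temperature):
    # Lower temperature sharpens the distribution; higher temperature flattens it.
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    # Assumed tail: draw one token id from the resulting categorical distribution.
    predicted_token_id = torch.multinomial(probabilities, num_samples=1)
    return predicted_token_id.item()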
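
For context, the loop inlined into CTXGen follows an iterative masked-language-model decoding pattern: pick a random [MASK] position, score it with the model, sample a replacement at temperature τ, and repeat until no masks remain. A toy, runnable sketch of that pattern with a stand-in scorer and vocabulary; every name below is hypothetical, not this repo's model or vocab_mlm:

import random
import torch

vocab = ["A", "C", "G", "T"]  # hypothetical token set

def score(tokens):
    # Stand-in for the real model: uniform logits over the vocabulary.
    return torch.zeros(len(vocab))

def fill_masks(tokens, temperature=1.0):
    # Repeatedly choose a random masked slot, sample a token, and write it back.
    while "[MASK]" in tokens:
        position = random.choice([i for i, t in enumerate(tokens) if t == "[MASK]"])
        probs = torch.softmax(score(tokens) / temperature, dim=-1)
        tokens[position] = vocab[torch.multinomial(probs, 1).item()]
    return tokens

print(fill_masks(["A", "[MASK]", "G", "[MASK]"]))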