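# Streamlit demo: flags news articles as reliable or unreliable with a
# 4-layer DistilBERT (khizon/distilbert-unreliable-news-eng-4L) fine-tuned
# on the NELA-GT-2018 site split.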
import json
import os
import random

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
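
# The tokenizer is the generic cased DistilBERT vocab; the classifier weights
# come from the fine-tuned checkpoint. The class mapping (0 = unreliable,
# 1 = reliable) is inferred from the UI code below, not from the checkpoint.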
def init_model():
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    model = DistilBertForSequenceClassification.from_pretrained(
        'khizon/distilbert-unreliable-news-eng-4L', num_labels=2
    )
    return tokenizer, model
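
# One-time helper: pull the NELA-GT-2018 site-split folder from Google Drive
# with gdown. The app never calls this; run it manually if test_sub.jsonl is
# missing.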
def download_dataset():
    url = 'https://drive.google.com/drive/folders/11mRvsHAkggFEJvG4axH4mmWI6FHMQp7X?usp=sharing'
    data = 'data/nela_gt_2018_site_split'
    os.system(f'gdown --folder {url} -O {data}')
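
# Read a JSON-Lines file (one JSON object per line) into a flat DataFrame.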
def jsonl_to_df(file_path):
    with open(file_path) as f:
        lines = f.read().splitlines()
    df_inter = pd.DataFrame(lines, columns=['json_element'])
    # Parse each line once and flatten any nested fields into columns.
    return pd.json_normalize(df_inter['json_element'].apply(json.loads))
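
# Load the held-out articles and one-hot encode the gold label into the
# label_0 / label_1 columns that predict() expects.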
def load_test_df():
    test_df = jsonl_to_df('test_sub.jsonl')
    test_df = pd.get_dummies(test_df, columns=['label'])
    return test_df
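
# Score a single dataset row (a Series with title, content, and the one-hot
# labels) and report the predicted label plus whether it matched the gold one.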
def predict(model, tokenizer, data):
    # The row Series has mixed dtypes, so convert the label pair explicitly.
    labels = torch.tensor(data[['label_0', 'label_1']].to_numpy(dtype='float32'))
    encoding = tokenizer.encode_plus(
        data['title'],
        ' [SEP] ' + data['content'],
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt'
    )
    with torch.no_grad():  # inference only, no gradients needed
        output = model(**encoding)
    return correct_preds(output['logits'], labels)
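
# Classify a user-supplied title/content pair. The literal ' [SEP] ' mirrors
# how predict() joins title and body; encode_plus adds its own special tokens
# on top, which presumably matches how the checkpoint was trained.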
def predict_new(model, tokenizer, title, content):
    encoding = tokenizer.encode_plus(
        title,
        ' [SEP] ' + content,
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt'
    )
    with torch.no_grad():
        output = model(**encoding)
    preds = F.softmax(output['logits'], dim=1)
    p_idx = torch.argmax(preds, dim=1)
    return 'reliable' if p_idx.item() == 1 else 'unreliable'
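
# Softmax the logits, take the argmax, and compare it against the argmax of
# the one-hot gold labels.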
def correct_preds(preds, labels):
    preds = F.softmax(preds, dim=1)
    p_idx = torch.argmax(preds, dim=1)   # predicted class, shape (1,)
    l_idx = torch.argmax(labels, dim=0)  # gold class from the one-hot pair
    pred_label = 'reliable' if p_idx.item() == 1 else 'unreliable'
    correct = p_idx.item() == l_idx.item()
    return pred_label, correct
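
# Streamlit UI: either sample a random test article (with most of its words
# masked out) or classify pasted text.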
if __name__ == '__main__':
    df = load_test_df()
    tokenizer, model = init_model()

    st.title('Unreliable News classifier')
    mode = st.radio('', ('Test article', 'Input own article'))
    if mode == 'Test article':
        if st.button('Get random article'):
            idx = np.random.randint(0, len(df))
            sample = df.iloc[idx]
            prediction, correct = predict(model, tokenizer, sample)
            label = 'reliable' if sample['label_1'] > sample['label_0'] else 'unreliable'

            st.header(sample['title'])
            if correct:
                st.success(f'Prediction: {prediction}')
            else:
                st.error(f'Prediction: {prediction}')
            st.caption(f'Source: {sample["source"]} ({label})')

            # Mask roughly half of the words (always keeping the first) so
            # the reader sees the article's shape without its full text.
            masked = []
            for i, word in enumerate(sample['content'].split()):
                if random.randint(0, 99) > 45 and i > 0:
                    word = '▒' * len(word)
                masked.append(word)
            st.markdown(' '.join(masked))
    else:
        title = st.text_input('Article title', 'Test title')
        content = st.text_area('Article content', 'Lorem ipsum')
        if st.button('Submit'):
            pred = predict_new(model, tokenizer, title, content)
            st.markdown(f'Prediction: {pred}')