NuExtract / ml.py
Alexandre-Numind's picture
Update ml.py
7a28cee verified
raw
history blame contribute delete
No virus
1.16 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
import torch
import json
import json
import re
import numpy as np
def create_prompt(text, template, examples=()):
    """Build a NuExtract-style extraction prompt.

    The prompt contains the JSON ``template`` (re-serialised with 4-space
    indentation), one ``### Example:`` section per non-empty example, and the
    input ``text``, terminated by the ``<|output|>`` marker after which the
    model is expected to write its answer.

    Args:
        text: The raw text to extract information from.
        template: JSON string describing the fields to extract.
        examples: Iterable of JSON strings. Falsy entries (e.g. ``""`` or
            ``None``) are skipped; any number of examples is accepted.

    Returns:
        The assembled prompt string.

    Raises:
        json.JSONDecodeError: If ``template`` or a non-empty example is not
            valid JSON.
    """
    # Round-trip through json so every JSON block uses the same 4-space layout.
    parts = [
        "<|input|>\n### Template:\n",
        json.dumps(json.loads(template), indent=4),
        "\n",
    ]
    for example in examples or ():
        if not example:
            continue
        # Each example gets its own section, in the order given.
        parts.append("### Example:\n" + json.dumps(json.loads(example), indent=4) + "\n")
    parts.append("### Text:\n" + text + "\n<|output|>")
    return "".join(parts)
def generate_answer_short(prompt, model, tokenizer, *, max_length=3000, max_new_tokens=1500):
    """Run ``model`` on ``prompt`` and return the extracted answer text.

    Args:
        prompt: Prompt string ending in the ``<|output|>`` marker, as built by
            ``create_prompt``.
        model: Causal LM exposing ``generate(**inputs, max_new_tokens=...)``
            and a ``device`` attribute; inputs are moved to ``model.device``.
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors="pt"`` and truncation to ``max_length`` tokens.
        max_length: Maximum number of prompt tokens kept by the tokenizer.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded text after the first ``<|output|>`` marker (up to any
        second occurrence), cut at the first ``<|end-output|>`` if present.

    Raises:
        IndexError: If the decoded sequence contains no ``<|output|>`` marker.
    """
    # NOTE(review): truncation presumably keeps the start of the prompt, so an
    # over-long prompt would lose its trailing "<|output|>" marker and the
    # lookup below would then fail -- confirm the tokenizer's truncation side.
    model_input = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_length
    ).to(model.device)  # follow the model instead of assuming "cuda"
    with torch.no_grad():
        output_ids = model.generate(**model_input, max_new_tokens=max_new_tokens)
    gen = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    answer = gen.split("<|output|>")[1]
    print(answer)  # raw model output, kept for debugging
    return answer.split("<|end-output|>")[0]