1.全局状态
例子来解释
import gradio as gr scores = [] def track_score(score): scores.append(score) top_scores = sorted(scores, reverse=True)[:3] return top_scores demo = gr.Interface( track_score, gr.Number(label="Score"), gr.JSON(label="Top Scores") ) demo.launch()
如上所述,scores,就可以在某函数中访问。
- 多用户访问,每次访问的分数都保存到scores列表
- 并并返回前三的分数
2.会话状态
Gradio 支持的另一种数据持久化类型是会话状态,其中数据在页面会话中跨多个提交持久化。但是,数据_不会_在模型的不同用户之间共享。要在会话状态中存储数据,您需要做三件事:
- 将一个额外的参数传递到您的函数中,该参数表示界面的状态。
- 在函数结束时,返回状态的更新值作为额外的返回值。
- 创建时添加
'state'
输入和输出组件'state'``Interface
聊天机器人是一个您需要会话状态的示例 - 您想要访问用户以前提交的内容,但您不能将聊天历史存储在全局变量中,因为那样聊天历史会在不同用户之间混乱。
import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium") model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium") def user(message, history): return "", history + [[message, None]] # bot_message = random.choice(["Yes", "No"]) # history[-1][1] = bot_message # time.sleep(1) # return history # def predict(input, history=[]): # # tokenize the new input sentence def bot(history): user_message = history[-1][0] new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt') # append the new user input tokens to the chat history bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1) # generate a response history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist() # convert the tokens to text, and then split the responses into lines response = tokenizer.decode(history[0]).split("<|endoftext|>") response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list return history with gr.Blocks() as demo: chatbot = gr.Chatbot() msg = gr.Textbox() clear = gr.Button("Clear") msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( bot, chatbot, chatbot ) clear.click(lambda: None, None, chatbot, queue=False) demo.launch()