写在最前面
这次该我汇报啦
许愿明天讲的顺利,问的都会
课堂讨论
讲+提问1个小时
但是在讨论的过程中,感觉逐步抽丝挖掘到了核心原理:
之前的理解:借助代码-LLM中的编码丰富结构化代码信息
最后的理解:如果能设置一个方法,让大模型能对自己输出的有所理解,那么效果会更好。这篇论文是通过代码结构和提示来实现这个的,理论上文字也可以
汇报
CODEIE:代码生成大模型能更好的进行少样本信息提取这项工作
接下来我将从这四个板块展开介绍
研究背景
这篇论文于2023年5月发布在arXiv,随后9月发表在ACL-NLP顶刊。其作者来自复旦大学和华师大学。
命名实体识别(NER)和关系抽取(RE)
首先,让我们了解一些自然语言处理(NLP)的背景知识。
信息抽取
的目标是从未经结构化处理的文本中提取出结构化信息。这个领域涵盖了各种任务,包括命名实体识别(NER)和关系抽取(RE)。
NER
用于识别文本中的命名实体,例如人名、地名和组织名,
而RE
则用于提取文本中实体之间的关系信息。
相关工作
- 为了在一个统一框架内处理这些不同的任务,近期的研究建议将输出结构线性化为非结构化字符串,并采用
序列生成模型
来解决信息抽取(IE)任务。
尽管这种线性化方法在具有充足训练数据的情况下取得了良好的成果,但在少样本情况下性能不佳。 - 鉴于
大型语言模型
具有强大的少样本学习能力,本论文旨在充分利用它们来解决少样本IE任务,特别是NER和RE任务。
通常,对于文本分类等NLP任务,以前的工作将任务重新构建为文本到文本生成的形式,并利用自然语言-llm(如GPT-3)来生成答案。 - 然而,由于IE任务具有复杂的内部结构,以往的线性化方法的输出结构通常显得不够自然,导致了预训练时的输出格式与推理时的输出格式不匹配。
因此,在使用这些扁平化方法时,通常需要复杂的解码策略
来将输出后处理为有效的结构。
作者动机
总结一下为什么作者选择了这一方法
。
表1总结了用于IE任务的中等规模模型、NL-LLM和Code-LLM之间的高层次差异。
此前的工作未能在一个统一的框架下充分利用大型模型进行少样本学习,特别是在处理结构化任务方面。这篇论文的模型成功弥补了这两个限制。
为什么这两个限制会对结果产生重大影响
呢?
首先,大型模型具有适应少样本数据的能力
其次,由于线性输出不够自然,通常需要更复杂的解码策略。
研究方案
因此,这篇论文提出了一种全新的思路,通过使用带有结构化代码风格提示的Code-LLM,来弥合预训练和推理阶段之间的输出差异,从而实现了IE任务的统一框架并获得更出色的结果。
这一方法的核心思想是将这两个IE任务框架化为代码生成任务
,并借助代码-LLM中的编码丰富结构化代码信息
,从而使这些IE任务有更好的结果。
实例
这是一个示例:
通过代码风格提示和Python字典的键,如“文本”和“类型”,可以组合它们成一个与NER(命名实体识别)示例等价的Python函数。
研究方案
下面我们对具体方案展开介绍。
方案预览
论文涵盖了NER和RE两项任务,
首先将原始IE任务转化为代码样式,其中(换PPT)Python函数名称表示任务,文档字符串说明任务目标,初始化空列表用于保存输出,描述性注释提供了提示以将命名实体放入列表中。(换PPT)这些元素被组合为“代码提示x”。
“结构化目标y”,将每个基本信息单元(NER的一对实体和RE的三元组)表示为Python字典,yc为字典列表。
对于NER任务,键包括“text”和“type”,值分别是实体跨度和实体类型。
对于RE任务,键包括实体类型和关系类型。然后,这些输入被传递给代码生成大模型,并得到输出。
左侧显示的文字是对前面图表的简明描述。
此外,GPT-3和CodeX都是OpenAI的模型,而CodeX是在GPT-3的基础上进行改进的,这两者有着相同的起源。
由于大型模型API的黑盒特性,无法对这些大型模型进行微调,因此这篇论文致力于探索上下文学习的方法,包括使用标记样本。
上下文提示学习是如何具体体现的呢?
在这篇论文中,任务被转化为代码表示,然后将它们连接在一起,构建了一个上下文演示,其中包括x1y1x2y2直到xnyn,最后有一个准备预测的xc。这个上下文被输入到模型中,生成输出yc,其格式与y1y2yn相似,通常保持了Python语法,容易还原成原始结构。
鉴于少样本训练容易受到高方差的影响,该论文为每个实验采用不同的随机种子运行了三次,并报告了度量指标的均值和标准差。
这篇论文已经开源了这项研究的代码,如果有兴趣的朋友可以前去查看。
实验
接下来,我们对这篇文章的实验部分进行梳理
数据集和基线模型
在实验结果部分,论文涵盖了七个NLP任务的数据集,采用了中等规模的预训练模型作为基线
评价指标
评价指标是常规的NER和RE任务性能度量指标。
实体的偏移量和实体类型与黄金实体匹配,那么实体跨度预测就是正确的。
如果关系类型正确且其实体对应的偏移量和类型正确,则关系预测是正确的。
实验方案对比
结果表明:
1、(表3)LLMs (GPT-3和Codex)在少样本设置下,比中等大小的模型(T5和UIE)实现了优越的性能。
2、比较不同提示设计的效果
(表4)突出显示的分别是“text”和“code”提示类型,code提示效果更好。
(图3)Codex胜过GPT-3,代码提示优于文本提示。Codex比GPT3效果好,代码提示比文本提示好。
并且在1次学习设置下,CODEIE将基准上的性能提高了60%以上,表现了强大的少样本学习能力
值得注意的是,代码提示对GPT-3更有益,尽管它并没有专门针对代码数据进行训练。
3、控制变量对比实验
作者进行了一些控制变量的对比实验,以探讨导致模型性能优越的因素。
第一个是格式一致性Format Consistency
介绍下条件困惑度
,这是一种衡量生成的文本在给定条件下生成文本的可预测性,也就是在给定上下文前缀的条件下,模型生成下一个字符的概率的准确性的度量。
较低的条件困惑度值表示生成的文本更符合所期望的条件。
图4,在7个数据集上,文本提示和代码提示的输入格式和模型之间的条件困惑度。
第二个是模型忠实度
分为两个指标:
1、一个是结构忠实度Structure Fidelity
,顾名思义是生成文本的结构
图5:比较提示学习和不同组合LLM的结构错误率,output的形式不对
2、一个是语义忠实度Semantic Fidelity
,生成文本的语义忠诚度
表5:实验中检测到的语义错误样本,output中语义不对,比如预定义实体类型中不存在的实体类型。
结果表明,GPT-3倾向于生成自由形式的结果,Codex更忠实于上下文中提供的演示,因此对于IE任务更可预测
第三个,细粒度性能Fine-grained Performance
结果表明(a)代码提示提高了模型的查准率和查全率;
(b)与GPT- 3相比,Codex在NER任务上实现了更高的召回率和相当的精度,并在RE任务上实现了更高的精度和召回率。
研究总结
最后对这篇论文进行总结。
这篇论文提出的方法相对于其他顶刊论文来说,更加简单有效。它通过领域迁移,将文本生成转化为代码生成,设计上下文提示学习以替代仅提供API的大型模型微调。
未来的工作可考虑:
考虑设计更良好的代码格式提示。
目前是在黑盒模型GPT3和Codex上进行实验,之后可以在开源模型上进一步微调。
以及,在非英文数据集(如中文数据集)上探索本文模型的实用性。
这就是《CODEIE: Large Code Generation Models are Better Few-Shot Information Extractors》的主要内容和关键观点。感谢大家的聆听。
补充
条件困惑度
条件复杂度是一种用于评估模型在给定上下文下生成下一个标记的难度和质量的度量,
模型的目标是尽可能减少条件复杂度,以获得更高质量的生成结果
。
语言模型的条件复杂度和代码模型的条件复杂度通常都基于困惑度(Perplexity)来计算,但有一些细微的差异,具体取决于模型的类型和应用领域。
1. 语言模型的条件复杂度:
语言模型的条件复杂度用于评估模型在给定上下文下生成下一个单词的质量。它通常采用以下方式计算:
- 给定一组文本数据(通常是测试集),将每个句子划分为多个标记(例如,单词或字符)。
- 模型接受一个前缀文本(通常是句子的一部分)作为输入,并尝试生成下一个标记。
- 通过将模型生成的标记与实际下一个标记进行比较,计算困惑度。困惑度通常使用交叉熵损失来计算。
数学表达式如下:
P e r p l e x i t y = exp ( − 1 N ∑ i = 1 N log P ( w i ∣ w 1 , w 2 , … , w i − 1 ) ) Perplexity = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, w_2, \ldots, w_{i-1})\right)Perplexity=exp(−N1i=1∑NlogP(wi∣w1,w2,…,wi−1))
其中:
- P ( w i ∣ w 1 , w 2 , … , w i − 1 ) P(w_i | w_1, w_2, \ldots, w_{i-1})P(wi∣w1,w2,…,wi−1)表示在给定前缀 w 1 , w 2 , … , w i − 1 w_1, w_2, \ldots, w_{i-1}w1,w2,…,wi−1的条件下,模型生成标记 w i w_iwi的概率。
- N NN表示测试集中的标记总数。
2. 代码模型的条件复杂度:
对于代码生成模型,条件复杂度的计算方式与语言模型类似,但有一些不同之处:
- 给定一组源代码或代码片段,将其分解为标记(例如,编程语言中的令牌或代码块)。
- 模型接受前缀代码或上下文,并尝试生成下一个代码标记。
- 使用交叉熵或其他损失函数计算条件复杂度,以评估模型在给定上下文下生成下一个代码标记的质量。
计算方式在代码模型中的具体实施可能因任务和模型架构而异,但基本原理与语言模型相似。
代码
https://github.com/artpli/CodeIE
我们的代码主要是从 UIE 和 CoCoGen 代码仓库修改而来的。
我们更新了源文件的初始版本。有关数据处理和代码的更多信息将在后续更新中提供。
通知:
一个不太好的消息是,Codex 模型现在已被 OpenAI 弃用,这将对复制本文产生重大影响。一些可能的解决方案包括
申请 OpenAI的研究人员访问计划
或访问 Azure OpenAI 服务上的 Codex
。由于我们使用的是 OpenAI 的闭源API,因此我们不知道它们背后的技术细节,例如使用的特定预训练语料库。因此,在评估我们的论文时可能存在潜在的数据污染问题。如果可能的话,我们会在更多的开源模型上评估我们的方法。
Codex被弃用了,但OpenAI建议所有用户从Codex切换到GPT-3.5 Turbo,它既可以完成编码任务,又可以补充灵活的自然语言功能。
代码目录
api
openai_api_wrapper.py
import os from typing import Dict, Any import openai from src.prompt.constants import END, END_LINE openai.api_key = os.getenv("OPENAI_API_KEY") class OpenaiAPIWrapper: @staticmethod def call(prompt: str, max_tokens: int, engine: str) -> dict: response = openai.Completion.create( engine=engine, prompt=prompt, temperature=0.0, max_tokens=max_tokens, top_p=1, frequency_penalty=0, presence_penalty=0, stop=[END, END_LINE], # logprobs=3, best_of=1 ) return response @staticmethod def parse_response(response) -> Dict[str, Any]: text = response["choices"][0]["text"] return text
这段代码是一个Python脚本,它包含了一个名为OpenaiAPIWrapper
的类,该类提供了与OpenAI的GPT-3 API交互的功能。这个类的主要目的是通过给定的提示(prompt)来请求GPT-3生成文本,并解析生成的文本
。
具体来说,这段代码的功能如下:
- 导入必要的模块和库,包括
os
、openai
,以及一些类型提示(typing
模块)和其他自定义模块。 - 设置OpenAI API密钥,它是在环境变量中查找的,这个密钥用于身份验证并授权访问OpenAI的GPT-3服务。
- 定义了一个名为
OpenaiAPIWrapper
的类,该类包含两静态方法。
call
方法接受三个参数:prompt
(提示文本,用于生成文本的输入)、max_tokens
(生成的文本的最大令牌数)、engine
(指定GPT-3的引擎)。它使用这些参数通过OpenAI API发送请求,并返回生成的文本作为响应。parse_response
方法接受一个响应对象,从中提取生成的文本,并返回它作为字符串。
这个代码段的主要目的是:为了通过OpenAI的API与GPT-3交互,以便生成自然语言文本,然后将生成的文本提取出来以供后续处理或显示。
query_openai_over_tasks.py
给定一个提示文件(prompt file)和一个任务文件(task file),任务文件包含以下字段:
- input_prompt:用于提示Codex的代码。
- reference_code:期望的完成代码。
- reference_graph:期望的图形(可能是与代码相关的图形)。
对于每个input_prompt,运行Codex的推断,并将以下字段添加到输出文件:
- generated_code:生成的代码。
- generated_graph:生成的图形。
文件可以包含其他元数据,但上述字段是必需的。
""" Given a prompt file and path to a task file with the following fields: 1. input_prompt: the code used to prompt codex 2. reference_code: expected completed code 3. reference_graph: expected graph Runs inference over codex for each input_prompt, and adds the following fields to the output file: 4. generated_code: generated code 5. generated_graph: generated graph The file can contain other metadata, but the fields above are required. """ import os import sys sys.path.append(os.getcwd()) from datetime import datetime import shutil import time import openai import pandas as pd from tqdm import tqdm import logging import os import pickle from src.converters.structure_converter import StructureConverter from src.converters.get_converter import ConverterFactory from openai_api_wrapper import OpenaiAPIWrapper from src.prompt.constants import END from src.utils.file_utils import load_yaml,load_schema logging.basicConfig(level=logging.INFO) def run(task_file_path: str, num_tasks: int, start_idx: int, output_file_path: str, prompt_path: str, keep_writing_output: bool, engine: str, max_tokens:int, max_requests_per_min: int, schema_path:str, map_config_path:str, start_cut_num:int): tasks = pd.read_json(task_file_path, orient='records', lines=True) converter = ConverterFactory.get_converter(args.job_type,schema_folder=schema_path, map_config_path=map_config_path) if num_tasks != -1: tasks = tasks.iloc[start_idx: start_idx+num_tasks] fixed_prompt_text = read_prompt(prompt_path) results = [] cache = load_cache(output_file_path) num_requests = 0 time_begin = time.time() failed_list = [] max_failed_time = 10 max_failed_taskes = 10 for task_idx, task in tqdm(tasks.iterrows(), total=len(tasks)): is_success = False tmp_failed_time = 0 while is_success is False and tmp_failed_time < max_failed_time: cut_prompt_examples_list = [None, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] cut_prompt_examples_list = cut_prompt_examples_list[start_cut_num:] for cut_prompt_examples in cut_prompt_examples_list: try: num_requests += 1 request_per_minute = maintain_request_per_minute( num_requests=num_requests, time_begin=time_begin, max_requests_per_min=max_requests_per_min, task_idx=task_idx) logging.info("\n") logging.info( f"Task {task_idx} > request/minute = {request_per_minute:.2f}") task_results = run_task(task=task, fixed_prompt_text=fixed_prompt_text, cache=cache, converter=converter, cut_prompt_examples=cut_prompt_examples, task_idx=task_idx, engine=engine, max_tokens=max_tokens) task_results['id'] = task_idx results.append(task_results) is_success = True break except openai.error.InvalidRequestError as e: logging.info( f"InvalidRequestError: {e}, trying with shorter prompt (cut_prompt_examples={cut_prompt_examples + 1 if cut_prompt_examples is not None else 1})") # sleep for a bit to further avoid rate limit exceeded exceptions if cut_prompt_examples != cut_prompt_examples_list[-1]: time.sleep(5) continue else: tmp_failed_time = max_failed_time logging.info(f"Failed too many times: {tmp_failed_time}") except Exception as e: # something else went wrong logging.info(f"Task {task_idx} failed: {e}") tmp_failed_time += 1 time.sleep(5 * tmp_failed_time) logging.info(f"Restart task {task_idx}") break if is_success and keep_writing_output: pd.DataFrame(results).to_json( output_file_path, orient='records', lines=True) if is_success == False: failed_list.append(task_idx) logging.info(f"Task {task_idx} failed {max_failed_time} times, skipped and recorded.") if failed_list != []: print ("failed list:\n", failed_list) if len(failed_list) > max_failed_taskes: print ("too many failed taskes. exit().") exit(0) print( f"Ran {len(results)} out of {len(tasks)} tasks ({len(results) / len(tasks):.2%})") pd.DataFrame(results).to_json( output_file_path, orient='records', lines=True) if failed_list != []: print ("failed list:\n", failed_list) output_path = output_file_path.rstrip('.jsonl') + '_failed_list.pkl' with open(output_path,"w") as fout: pickle.dump(failed_list,fout) print ("failed list saved into: ", output_path) def run_task(task: dict, fixed_prompt_text: str, cache: dict, converter: StructureConverter, task_idx: int, engine: str, max_tokens: int, cut_prompt_examples: int = None) -> dict: """Runs the task, and returns the results. Args: task (dict): The task input fixed_prompt_text (str): Used for cases where the input prompt is fixed cache (dict): cache of previous results converter (GraphPythonConverter): A graph-python converter to parse results cut_prompt_examples (int, optional): If provided, the first `cut_prompt_examples` examples are deleted. Prevents 4096 errors. Defaults to None. Returns: dict: A dictionary with the results. """ start_time = time.time() prompt_text = fixed_prompt_text if fixed_prompt_text is not None else task['prompt'] if cut_prompt_examples is not None: prompt_text_parts = prompt_text.split(END) prompt_text = END.join(prompt_text_parts[cut_prompt_examples:]) if task['input_prompt'] in cache: logging.info( f"Task {task_idx} > Using cached result for {task['input_prompt']}") codex_response = cache[task['input_prompt']]["codex_response"] else: codex_response = query_codex(task, prompt_text, engine, max_tokens=max_tokens) completed_code = get_completed_code(task, codex_response) task_results = {k: v for (k, v) in task.items()} task_results["codex_response"] = codex_response task_results["generated_code"] = completed_code task_results["elapsed_time"] = time.time() - start_time return task_results def maintain_request_per_minute(num_requests: int, time_begin: float, max_requests_per_min: int, task_idx: int) -> float: request_per_minute = get_request_per_minute(num_requests, time_begin) logging.info("\n") while request_per_minute > max_requests_per_min: logging.info( f"Task {task_idx} > Sleeping! (Requests/minute = {request_per_minute:.2f} > {max_requests_per_min:.2f})") time.sleep(1) request_per_minute = get_request_per_minute( num_requests, time_begin) return request_per_minute def read_prompt(prompt_path): if prompt_path is None: return None with open(prompt_path, "r") as f: prompt = f.read() return prompt def load_cache(output_file_path: str): """We don't want to query codex repeatedly for the same input. If an output file exists, this function creates a "cache" of the results. The cache is implemented as a hashmap keyed by `input_prompt`, and maps to the entire output entry Args: output_file_path (str): _description_ """ if not os.path.exists(output_file_path): return {} else: # make a backup of the file already there shutil.copyfile(output_file_path, output_file_path + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")) shutil.copy(output_file_path, output_file_path + ".bak") cache_data = pd.read_json( output_file_path, orient='records', lines=True) cache = {row['input_prompt']: row.to_dict() for _, row in cache_data.iterrows()} return cache def query_codex(task: dict, prompt_text: str, engine: str, max_tokens: int): prompt = f"{prompt_text} {task['input_prompt']}" response = OpenaiAPIWrapper.call( prompt=prompt, max_tokens=max_tokens, engine=engine) return response def get_completed_code(task: dict, codex_response: dict) -> str: completed_code = OpenaiAPIWrapper.parse_response(codex_response) all_code = f"{task['input_prompt']}{completed_code}" # NOTE: space is already taken care of, no need to add it again, otherwise indentation will be off return all_code def get_request_per_minute(num_request: int, begin_time: float) -> float: elapsed_time = time.time() - begin_time request_per_minute = (num_request / elapsed_time) * 60 return request_per_minute if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--task_file_path", type=str, required=True) parser.add_argument("--num_tasks", type=int, required=True) parser.add_argument("--start_idx", type=int, required=True) parser.add_argument("--output_file_path", type=str, required=True) parser.add_argument("--prompt_path", type=str, required=False, default=None) parser.add_argument("--job_type", type=str, required=True, choices=ConverterFactory.supported_converters) parser.add_argument("--keep_writing_output", action="store_true", default=True) parser.add_argument("--engine", type=str, required=True) parser.add_argument("--max_requests_per_min", type=int, default=10) parser.add_argument("--max_tokens", type=int, default=280) parser.add_argument("--schema_path", type=str, required=True) parser.add_argument("--map_config_path", type=str, required=True) parser.add_argument("--start_cut_num", type=int, default=0) args = parser.parse_args() run(task_file_path=args.task_file_path, num_tasks=args.num_tasks,start_idx=args.start_idx, output_file_path=args.output_file_path, prompt_path=args.prompt_path, keep_writing_output=args.keep_writing_output, engine=args.engine, max_requests_per_min=args.max_requests_per_min, max_tokens=args.max_tokens,schema_path=args.schema_path, map_config_path=args.map_config_path,start_cut_num=args.start_cut_num)
分段解读 def run:记录任务的执行情况,保存任务的结果,以及对于失败任务的记录和保存
def run(task_file_path: str, num_tasks: int, start_idx: int, output_file_path: str, prompt_path: str, keep_writing_output: bool, engine: str, max_tokens:int, max_requests_per_min: int, schema_path:str, map_config_path:str, start_cut_num:int): tasks = pd.read_json(task_file_path, orient='records', lines=True) converter = ConverterFactory.get_converter(args.job_type,schema_folder=schema_path, map_config_path=map_config_path)
这段代码定义了一个名为run
的Python函数,该函数接受多个参数,并执行一些任务。以下是每个参数的解释:
task_file_path
(str):任务文件的路径,包含任务信息。num_tasks
(int):任务的数量。start_idx
(int):任务的起始索引。output_file_path
(str):输出文件的路径,用于将结果写入。prompt_path
(str):提示文件的路径,可能包含用于代码生成的输入提示。keep_writing_output
(bool):一个布尔值,指示是否保持写入输出。可能用于控制写入的方式。engine
(str):引擎名称,用于执行任务,可能与某种代码生成引擎相关。max_tokens
(int):最大令牌数,可能用于限制生成的代码的长度。max_requests_per_min
(int):每分钟的最大请求数。schema_path
(str):模式文件的路径,可能包含任务的数据模式。map_config_path
(str):映射配置文件的路径,用于数据转换或映射。start_cut_num
(int):一个整数,可能用于指示某种切割操作的起始数量。
在函数内部,它首先读取了一个任务文件并将其存储在名为tasks
的DataFrame中。然后,它使用ConverterFactory
类中的get_converter
方法来创建一个converter
对象,该对象用于根据作业类型(args.job_type
)执行某些转换任务,并使用给定的模式文件和映射配置文件。
if num_tasks != -1: tasks = tasks.iloc[start_idx: start_idx+num_tasks] fixed_prompt_text = read_prompt(prompt_path) results = [] cache = load_cache(output_file_path) num_requests = 0 time_begin = time.time() failed_list = [] max_failed_time = 10 max_failed_taskes = 10
在给定的代码段中,如果num_tasks
参数的值不等于-1,那么它会对任务进行切片操作,保留从start_idx
到start_idx+num_tasks
的子集。
然后,代码从一个名为prompt_path
的文件中读取提示文本(这个提示文本包含用于代码生成的输入提示),并将其存储在fixed_prompt_text
变量中。
接下来,代码初始化了一些变量,包括:
results
:用于存储任务的结果的空列表。cache
:通过加载一个输出文件来初始化的,这个文件用于缓存结果。num_requests
:用于跟踪已经发出的请求数量的变量。time_begin
:用于记录时间的变量,可能用于计算运行时间。
此外,还定义了以下变量:
failed_list
:用于存储失败任务的列表。max_failed_time
:最大失败时间,可能用于指示失败任务的最大允许时间。max_failed_taskes
:最大失败任务数,可能用于指示允许的最大失败任务数量。
这些变量在接下来的代码中可能会用于跟踪和处理任务的执行以及失败的情况。接下来的代码段可能包括任务的执行和结果的收集,以及处理失败任务的逻辑。
for task_idx, task in tqdm(tasks.iterrows(), total=len(tasks)): is_success = False tmp_failed_time = 0 while is_success is False and tmp_failed_time < max_failed_time: cut_prompt_examples_list = [None, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] cut_prompt_examples_list = cut_prompt_examples_list[start_cut_num:] for cut_prompt_examples in cut_prompt_examples_list: try: num_requests += 1 request_per_minute = maintain_request_per_minute( num_requests=num_requests, time_begin=time_begin, max_requests_per_min=max_requests_per_min, task_idx=task_idx) logging.info("\n") logging.info( f"Task {task_idx} > request/minute = {request_per_minute:.2f}") task_results = run_task(task=task, fixed_prompt_text=fixed_prompt_text, cache=cache, converter=converter, cut_prompt_examples=cut_prompt_examples, task_idx=task_idx, engine=engine, max_tokens=max_tokens) task_results['id'] = task_idx results.append(task_results) is_success = True break except openai.error.InvalidRequestError as e: logging.info( f"InvalidRequestError: {e}, trying with shorter prompt (cut_prompt_examples={cut_prompt_examples + 1 if cut_prompt_examples is not None else 1})") # sleep for a bit to further avoid rate limit exceeded exceptions if cut_prompt_examples != cut_prompt_examples_list[-1]: time.sleep(5) continue else: tmp_failed_time = max_failed_time logging.info(f"Failed too many times: {tmp_failed_time}") except Exception as e: # something else went wrong logging.info(f"Task {task_idx} failed: {e}") tmp_failed_time += 1 time.sleep(5 * tmp_failed_time) logging.info(f"Restart task {task_idx}") break if is_success and keep_writing_output: pd.DataFrame(results).to_json( output_file_path, orient='records', lines=True) if is_success == False: failed_list.append(task_idx) logging.info(f"Task {task_idx} failed {max_failed_time} times, skipped and recorded.") if failed_list != []: print ("failed list:\n", failed_list) if len(failed_list) > max_failed_taskes: print ("too many failed taskes. exit().") exit(0)
这段代码是一个循环,用于处理任务列表中的每个任务。以下是代码的主要逻辑:
- 循环迭代任务列表中的每个任务(通过
for task_idx, task in tqdm(tasks.iterrows(), total=len(tasks))
实现迭代)。 - 在每次迭代中,设置了一个
is_success
标志,用于跟踪任务是否成功完成。还有一个tmp_failed_time
变量,用于跟踪任务失败的时间。 - 在一个while循环内,当
is_success
为False
且tmp_failed_time
小于max_failed_time
时,会尝试执行任务。该while循环将重试直到任务成功或达到最大失败时间。 - 在内部循环中,任务会被分成多个部分,通过
cut_prompt_examples_list
来控制。每个部分都会尝试执行。如果任务成功完成,is_success
标志将设置为True
,并跳出内部循环。如果任务失败,则会捕获异常,并根据不同的异常类型采取不同的措施,包括减少任务提示的长度、等待一段时间,以及记录失败的任务。 - 如果
is_success
为True
,将任务结果添加到results
列表中,然后写入到输出文件中。 - 如果
is_success
为False
,表示任务在允许的失败时间内无法成功执行,将任务索引添加到failed_list
列表中,并记录失败的任务。如果failed_list
不为空,它会将失败的任务索引打印出来。如果失败的任务数超过了max_failed_taskes
,则终止程序。
总之,这段代码负责循环处理任务列表中的每个任务,监视任务的成功与失败,并根据情况采取相应的措施,包括重试任务、记录失败的任务以及终止程序。这个循环是任务执行和处理失败任务的核心部分。
print( f"Ran {len(results)} out of {len(tasks)} tasks ({len(results) / len(tasks):.2%})") pd.DataFrame(results).to_json( output_file_path, orient='records', lines=True) if failed_list != []: print ("failed list:\n", failed_list) output_path = output_file_path.rstrip('.jsonl') + '_failed_list.pkl' with open(output_path,"w") as fout: pickle.dump(failed_list,fout) print ("failed list saved into: ", output_path)
这段代码是在处理所有任务后执行的一些收尾工作:
- 首先,它打印出已经运行了多少个任务,以及总共有多少任务,并以百分比形式表示已完成任务的比例。
- 然后,它将任务结果存储为JSON文件(output_file_path),使用
pd.DataFrame(results).to_json
方法将results
列表的内容写入文件。这个文件将包含每个任务的生成结果。 - 接下来,如果存在失败的任务(
failed_list
不为空),它会打印出失败任务的索引,并将这个失败列表保存为一个.pkl
文件。保存的文件名是在原始output_file_path
的基础上,去掉扩展名.jsonl
并添加_failed_list.pkl
后缀。 - 最后,它会打印出失败任务列表已保存的文件路径。
这段代码用于记录任务的执行情况,保存任务的结果,以及对于失败任务的记录和保存。它提供了任务执行的总结信息,并保存了任务的结果和失败列表。
def run_task:执行一个给定任务,与Codex进行交互,获取生成的代码,记录执行时间,并返回任务的结果
def run_task(task: dict, fixed_prompt_text: str, cache: dict, converter: StructureConverter, task_idx: int, engine: str, max_tokens: int, cut_prompt_examples: int = None) -> dict: """Runs the task, and returns the results. Args: task (dict): The task input fixed_prompt_text (str): Used for cases where the input prompt is fixed cache (dict): cache of previous results converter (GraphPythonConverter): A graph-python converter to parse results cut_prompt_examples (int, optional): If provided, the first `cut_prompt_examples` examples are deleted. Prevents 4096 errors. Defaults to None. Returns: dict: A dictionary with the results. """ start_time = time.time() prompt_text = fixed_prompt_text if fixed_prompt_text is not None else task['prompt'] if cut_prompt_examples is not None: prompt_text_parts = prompt_text.split(END) prompt_text = END.join(prompt_text_parts[cut_prompt_examples:]) if task['input_prompt'] in cache: logging.info( f"Task {task_idx} > Using cached result for {task['input_prompt']}") codex_response = cache[task['input_prompt']]["codex_response"] else: codex_response = query_codex(task, prompt_text, engine, max_tokens=max_tokens) completed_code = get_completed_code(task, codex_response) task_results = {k: v for (k, v) in task.items()} task_results["codex_response"] = codex_response task_results["generated_code"] = completed_code task_results["elapsed_time"] = time.time() - start_time return task_results
这段代码定义了一个名为run_task
的函数,它用于运行一个任务并返回结果。以下是函数的主要参数和功能:
task
(dict):任务的输入,是一个包含任务信息的字典。fixed_prompt_text
(str):用于固定输入提示文本的字符串,如果不为None
,则会使用它而不是任务字典中的prompt
字段。cache
(dict):包含先前结果的缓存字典。converter
(StructureConverter):用于解析结果的转换器对象。cut_prompt_examples
(int,可选):一个整数,如果提供,将删除提示文本的前cut_prompt_examples
个示例以防止4096错误。默认为None
。
函数的主要功能如下:
- 计时开始,记录开始时间。
- 根据输入参数,确定要使用的提示文本,如果
fixed_prompt_text
不为None
,则使用它,否则使用任务字典中的prompt
字段。如果提供了cut_prompt_examples
,则截取提示文本的一部分以防止4096错误。 - 检查缓存中是否已经有了任务的结果,如果有,则从缓存中获取Codex的响应,否则使用
query_codex
函数查询Codex以获取响应。 - 获取Codex响应后,通过
get_completed_code
函数获取生成的代码。 - 将任务结果存储在一个字典中,包括任务的所有输入信息、Codex的响应、生成的代码和任务执行所花费的时间。
- 最后,返回包含任务结果的字典。
这个函数的主要目的是执行一个给定任务,与Codex进行交互,获取生成的代码,记录执行时间,并返回任务的结果。
def maintain_request_per_minute:控制每分钟的请求数以避免超出最大请求数
def maintain_request_per_minute(num_requests: int, time_begin: float, max_requests_per_min: int, task_idx: int) -> float: request_per_minute = get_request_per_minute(num_requests, time_begin) logging.info("\n") while request_per_minute > max_requests_per_min: logging.info( f"Task {task_idx} > Sleeping! (Requests/minute = {request_per_minute:.2f} > {max_requests_per_min:.2f})") time.sleep(1) request_per_minute = get_request_per_minute( num_requests, time_begin) return request_per_minute
这段代码定义了一个名为maintain_request_per_minute
的函数,用于控制每分钟的请求数以避免超出最大请求数。以下是函数的主要参数和功能:
num_requests
(int):已经发出的请求数。time_begin
(float):开始计时的时间。max_requests_per_min
(int):每分钟的最大请求数限制。task_idx
(int):任务的索引,用于记录日志。
函数的主要功能如下:
- 计算当前每分钟的请求速率,通过调用
get_request_per_minute
函数,传递已发出的请求数和开始计时的时间。 - 如果当前的请求速率超过了最大请求数限制,就会进入一个while循环。在循环中,它会记录日志,指示正在等待以减少请求速率。
- 在每次循环中,它会等待1秒,然后再次计算请求速率。这样,它会一直等待,直到请求速率不再超过最大请求数限制。
- 一旦请求速率在允许的范围内,它会返回当前请求速率。
这个函数的目的是确保不会超出每分钟的最大请求数限制,以遵守请求速率的规则。如果请求速率太高,它会等待一段时间,直到速率降到允许的水平。
def read_prompt:读取提示
def read_prompt(prompt_path): if prompt_path is None: return None with open(prompt_path, "r") as f: prompt = f.read() return prompt
这段代码定义了一个名为read_prompt
的函数,用于从文件中读取提示文本。以下是函数的主要参数和功能:
prompt_path
:提示文本文件的路径。
函数的主要功能如下:
- 首先,它检查
prompt_path
是否为None
。如果prompt_path
为None
,则函数直接返回None
,表示没有可用的提示文本。 - 如果
prompt_path
不为None
,则使用with
语句打开文件,读取文件中的文本内容,并将其存储在prompt
变量中。 - 最后,函数返回读取的提示文本。
这个函数的目的是从指定的文件中读取提示文本,并将其返回,以供后续任务使用。如果文件路径为None
,则返回None
表示没有可用的提示文本。
def load_cache:创建一个缓存以存储查询结果,以避免重复查询相同的输入
def load_cache(output_file_path: str): """We don't want to query codex repeatedly for the same input. If an output file exists, this function creates a "cache" of the results. The cache is implemented as a hashmap keyed by `input_prompt`, and maps to the entire output entry Args: output_file_path (str): _description_ """ if not os.path.exists(output_file_path): return {} else: # make a backup of the file already there shutil.copyfile(output_file_path, output_file_path + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")) shutil.copy(output_file_path, output_file_path + ".bak") cache_data = pd.read_json( output_file_path, orient='records', lines=True) cache = {row['input_prompt']: row.to_dict() for _, row in cache_data.iterrows()} return cache
这段代码定义了一个名为load_cache
的函数,用于创建一个缓存以存储查询结果,以避免重复查询相同的输入。以下是函数的主要参数和功能:
output_file_path
:用于存储缓存的文件路径。
函数的主要功能如下:
- 首先,它检查是否存在指定的
output_file_path
文件。如果文件不存在,它会返回一个空的缓存字典(空的哈希映射)。 - 如果文件存在,它会执行以下操作:
- 创建一个具有当前时间戳作为后缀的备份文件,以便保留文件的备份。
- 创建一个名为
output_file_path.bak
的文件,作为原始文件的备份副本。 - 从
output_file_path
文件中读取数据,以解析缓存的内容。 - 创建一个字典
cache
,其中每个缓存条目的键是input_prompt
,值是整个输出条目的字典表示。
- 最后,函数返回创建的缓存字典。
这个函数的目的是在指定的文件中创建一个缓存,用于存储查询的结果。如果文件不存在,它返回一个空的缓存字典。如果文件存在,它会创建一个包含缓存数据的字典,以便在以后的查询中可以快速查找和检索已缓存的结果。同时,它也会对原始文件进行备份,以便在需要时可以还原。