Transformers 4.37 中文文档(九十九)(7)https://developer.aliyun.com/article/1564047
StoppingCriteria
StoppingCriteria 可用于更改生成过程何时停止(除了 EOS 标记)。请注意,这仅适用于我们的 PyTorch 实现。
class transformers.StoppingCriteria
( )
所有可以在生成过程中应用的停止标准的抽象基类。
如果您的停止标准取决于scores
输入,请确保将return_dict_in_generate=True, output_scores=True
传递给generate
。
__call__
( input_ids: LongTensor scores: FloatTensor **kwargs )
参数
input_ids
(torch.LongTensor
,形状为(batch_size, sequence_length)
) — 词汇表中输入序列标记的索引。
可以使用 AutoTokenizer 来获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode()和 PreTrainedTokenizer.call
()。
什么是输入 ID?scores
(torch.FloatTensor
,形状为(batch_size, config.vocab_size)
) — 语言建模头的预测分数。这些可以是 SoftMax 之前每个词汇标记的分数,也可以是 SoftMax 之后每个词汇标记的分数。如果此停止标准取决于scores
输入,请确保将return_dict_in_generate=True, output_scores=True
传递给generate
。kwargs
(Dict[str, Any]
,可选) — 其他特定停止标准的关键字参数。
class transformers.StoppingCriteriaList
( iterable = () )
__call__
( input_ids: LongTensor scores: FloatTensor **kwargs )
参数
input_ids
(torch.LongTensor
,形状为(batch_size, sequence_length)
) — 词汇表中输入序列标记的索引。
可以使用 AutoTokenizer 获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode()和 PreTrainedTokenizer.call
()。
什么是输入 ID?scores
(torch.FloatTensor
of shape(batch_size, config.vocab_size)
) — 语言建模头的预测分数。这些可以是 SoftMax 之前每个词汇标记的分数,也可以是 SoftMax 之后每个词汇标记的分数。如果此停止标准取决于scores
输入,请确保您传递return_dict_in_generate=True, output_scores=True
给generate
。kwargs
(Dict[str, Any]
, 可选) — 其他特定停止标准的 kwargs。
class transformers.MaxLengthCriteria
( max_length: int max_position_embeddings: Optional = None )
参数
max_length
(int
) — 输出序列在标记数量上可以具有的最大长度。max_position_embeddings
(int
, 可选) — 模型的最大长度,由模型的config.max_position_embeddings
属性定义。
这个类可以用来在生成的标记数超过max_length
时停止生成。请注意,对于仅解码器类型的 transformers,这将包括初始提示的标记。
__call__
( input_ids: LongTensor scores: FloatTensor **kwargs )
参数
input_ids
(torch.LongTensor
of shape(batch_size, sequence_length)
) — 输入序列标记在词汇表中的索引。
可以使用 AutoTokenizer 获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode()和 PreTrainedTokenizer.call
()。
什么是输入 ID?scores
(torch.FloatTensor
of shape(batch_size, config.vocab_size)
) — 语言建模头的预测分数。这些可以是 SoftMax 之前每个词汇标记的分数,也可以是 SoftMax 之后每个词汇标记的分数。如果此停止标准取决于scores
输入,请确保您传递return_dict_in_generate=True, output_scores=True
给generate
。kwargs
(Dict[str, Any]
, 可选) — 其他特定停止标准的 kwargs。
class transformers.MaxTimeCriteria
( max_time: float initial_timestamp: Optional = None )
参数
max_time
(float
) — 生成的最大允许时间(以秒为单位)。initial_time
(float
, 可选, 默认为time.time()
) — 允许生成的开始时间。
这个类可以用来在完整生成超过一定时间时停止生成。默认情况下,当初始化此函数时,时间将开始计算。您可以通过传递initial_time
来覆盖这一点。
__call__
( input_ids: LongTensor scores: FloatTensor **kwargs )
参数
input_ids
(torch.LongTensor
of shape(batch_size, sequence_length)
) — 输入序列标记在词汇表中的索引。
可以使用 AutoTokenizer 获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call
()。
什么是输入 ID?scores
(torch.FloatTensor
of shape(batch_size, config.vocab_size)
) — 语言建模头的预测分数。这些可以是 SoftMax 之前每个词汇标记的分数,也可以是 SoftMax 之后每个词汇标记的分数。如果这个停止标准依赖于scores
输入,确保你传递return_dict_in_generate=True, output_scores=True
给generate
。kwargs
(Dict[str, Any]
, 可选) — 其他特定的停止标准参数。
约束
Constraint 可以用来强制生成结果中包含特定的标记或序列。请注意,这仅适用于我们的 PyTorch 实现。
class transformers.Constraint
( )
所有可以在生成过程中应用的约束的抽象基类。它必须定义约束如何被满足。
所有继承 Constraint 的类必须遵循的要求
completed = False while not completed: _, completed = constraint.update(constraint.advance())
将始终终止(停止)。
advance
( ) → export const metadata = 'undefined';token_ids(torch.tensor)
返回
token_ids(torch.tensor
)
必须是一个可索引的标记列表的张量,而不是某个整数。
调用时,返回一个标记,这个标记会使这个约束更接近被满足一步。
copy
( stateful = False ) → export const metadata = 'undefined';constraint(Constraint)
返回
constraint(Constraint
)
与被调用的相同的约束。
创建这个约束的一个新实例。
does_advance
( token_id: int )
读取一个标记并返回它是否推进了进度。
remaining
( )
返回 advance()
完成这个约束还需要多少步骤。
reset
( )
重置这个约束的状态到初始化状态。我们会在约束的实现被不想要的标记中断时调用这个方法。
test
( )
测试这个约束是否已经正确定义。
update
( token_id: int ) → export const metadata = 'undefined';stepped(bool)
返回
stepped(bool
)
这个约束是否变得更接近被满足一步。completed(bool
): 这个约束是否已经被这个生成的标记完全满足。reset (bool
): 这个约束是否已经被这个生成的标记重置了进度。
读取一个标记并返回指示其推进程度的布尔值。这个函数会更新这个对象的状态,不像 does_advance(self, token_id: int)
。
这不是为了测试某个特定的标记是否会推进进度;而是为了更新它的状态,就好像它已经被生成了。如果 token_id != desired token(参考 PhrasalConstraint 中的 else 语句),这变得很重要。
class transformers.PhrasalConstraint
( token_ids: List )
参数
token_ids
(List[int]
)— 必须由输出生成的 token 的 id。
Constraint 强制要求输出中包含一个有序的 token 序列。
class transformers.DisjunctiveConstraint
( nested_token_ids: List )
参数
nested_token_ids
(List[List[int]]
)— 一个单词列表,其中每个单词都是一个 id 列表。通过从单词列表中生成一个单词来满足此约束。
一个特殊的 Constraint,通过满足几个约束中的一个来实现。
class transformers.ConstraintListState
( constraints: List )
参数
constraints
(List[Constraint]
)— 必须由 beam 评分器满足的 Constraint 对象列表。
用于跟踪 beam 评分器通过一系列约束的进度的类。
advance
( )
要生成的 token 列表,以便我们可以取得进展。这里的“列表”并不意味着将完全满足约束的 token 列表。
给定约束c_i = {t_ij | j == # of tokens}
,如果我们不处于通过特定约束c_i
进行进度的中间阶段,我们返回:
[t_k1 for k in indices of unfulfilled constraints]
如果我们处于约束的中间阶段,那么我们返回:[t_ij]
,其中i
是正在进行的约束的索引,j
是约束的下一步。
虽然我们不关心哪个约束先被满足,但如果我们正在满足一个约束,那么这是我们唯一会返回的。
reset
( token_ids: Optional )
token_ids:到目前为止生成的 token,以重置通过约束的进度状态。
BeamSearch
class transformers.BeamScorer
( )
所有用于 beam_search()和 beam_sample()的 beam 评分器的抽象基类。
process
( input_ids: LongTensor next_scores: FloatTensor next_tokens: LongTensor next_indices: LongTensor **kwargs ) → export const metadata = 'undefined';UserDict
参数
input_ids
(形状为(batch_size * num_beams, sequence_length)
的torch.LongTensor
)— 词汇表中输入序列 token 的索引。
可以使用任何继承自 PreTrainedTokenizer 的类来获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode()和 PreTrainedTokenizer.call
()。
什么是输入 ID?next_scores
(形状为(batch_size, 2 * num_beams)
的torch.FloatTensor
)— 前2 * num_beams
个未完成的 beam 假设的当前分数。next_tokens
(形状为(batch_size, 2 * num_beams)
的torch.LongTensor
)— 与前2 * num_beams
个未完成的 beam 假设对应的input_ids
的 tokens。next_indices
(形状为(batch_size, 2 * num_beams)
的torch.LongTensor
)— 指示next_tokens
对应于哪个 beam 假设的 beam 索引。pad_token_id
(int
,可选)— 填充标记的 id。eos_token_id
(Union[int, List[int]]
,可选)— 结束序列标记的 id。可选择使用列表设置多个结束序列标记。beam_indices
(torch.LongTensor
,可选)— 指示每个标记对应于哪个 beam 假设的 beam 索引。group_index
(int
,可选)— beam 组的索引。与 group_beam_search()一起使用。
返回值
UserDict
由上述字段组成的字典:
next_beam_scores
(形状为(batch_size * num_beams)
的torch.FloatTensor
)— 所有未完成 beam 的更新分数。next_beam_tokens
(形状为(batch_size * num_beams)
的torch.FloatTensor
)— 要添加到未完成 beam_hypotheses 的下一个标记。next_beam_indices
(形状为(batch_size * num_beams)
的torch.FloatTensor
)— 指示下一个标记应添加到哪个 beam 的 beam 索引。
finalize
( input_ids: LongTensor next_scores: FloatTensor next_tokens: LongTensor next_indices: LongTensor max_length: int **kwargs ) → export const metadata = 'undefined';torch.LongTensor of shape (batch_size * num_return_sequences, sequence_length)
参数
input_ids
(形状为(batch_size * num_beams, sequence_length)
的torch.LongTensor
)— 词汇表中输入序列标记的索引。
可以使用任何继承自 PreTrainedTokenizer 的类来获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode()和 PreTrainedTokenizer.call
()。
什么是输入 ID?final_beam_scores
(形状为(batch_size * num_beams)
的torch.FloatTensor
)— 所有未完成 beam 的最终分数。final_beam_tokens
(形状为(batch_size * num_beams)
的torch.FloatTensor
)— 要添加到未完成 beam_hypotheses 的最后一个标记。final_beam_indices
(形状为(batch_size * num_beams)
的torch.FloatTensor
)— 指示final_beam_tokens
应添加到哪个 beam 的 beam 索引。pad_token_id
(int
,可选)— 填充标记的 id。eos_token_id
(Union[int, List[int]]
,可选)— 结束序列标记的 id。可选择使用列表设置多个结束序列标记。
返回值
torch.LongTensor
的形状为(batch_size * num_return_sequences, sequence_length)
生成的序列。第二维(sequence_length)要么等于max_length
,要么如果所有批次由于eos_token_id
而提前完成,则要短。
class transformers.BeamSearchScorer
( batch_size: int num_beams: int device: device length_penalty: Optional = 1.0 do_early_stopping: Union = False num_beam_hyps_to_keep: Optional = 1 num_beam_groups: Optional = 1 max_length: Optional = None )
参数
batch_size
(int
)—input_ids
的批量大小,用于并行运行标准 beam 搜索解码。num_beams
(int
)— beam 搜索的 beam 数量。device
(torch.device
)— 定义此BeamSearchScorer
实例将分配到的设备类型(例如,"cpu"
或"cuda"
)。length_penalty
(float
,可选,默认为 1.0)— 用于基于 beam 的生成的长度的指数惩罚。它作为指数应用于序列长度,然后用于分割序列的分数。由于分数是序列的对数似然(即负数),length_penalty
> 0.0 促进更长的序列,而length_penalty
< 0.0 鼓励更短的序列。do_early_stopping
(bool
或str
,可选,默认为False
) — 控制基于 beam 的方法(如 beam-search)的停止条件。接受以下值:True
,生成在有num_beams
个完整候选时停止;False
,应用启发式方法,当很难找到更好的候选时停止生成;"never"
,仅当不能有更好的候选时,beam 搜索过程才会停止(经典 beam 搜索算法)。num_beam_hyps_to_keep
(int
,可选,默认为 1) — 在调用 finalize()时应返回的 beam 假设数量。num_beam_groups
(int
,可选,默认为 1) — 将num_beams
分成多个组以确保不同组的 beam 之间的多样性。有关更多详细信息,请参阅此论文。max_length
(int
,可选) — 要生成的序列的最大长度。
BeamScorer 实现标准 beam 搜索解码。
部分改编自Facebook 的 XLM beam 搜索代码。
多样性 beam 搜索算法和实现的参考Ashwin Kalyan 的 DBS 实现
process
( input_ids: LongTensor next_scores: FloatTensor next_tokens: LongTensor next_indices: LongTensor pad_token_id: Optional = None eos_token_id: Union = None beam_indices: Optional = None group_index: Optional = 0 decoder_prompt_len: Optional = 0 )
finalize
( input_ids: LongTensor final_beam_scores: FloatTensor final_beam_tokens: LongTensor final_beam_indices: LongTensor max_length: int pad_token_id: Optional = None eos_token_id: Union = None beam_indices: Optional = None decoder_prompt_len: Optional = 0 )
class transformers.ConstrainedBeamSearchScorer
( batch_size: int num_beams: int constraints: List device: device length_penalty: Optional = 1.0 do_early_stopping: Union = False num_beam_hyps_to_keep: Optional = 1 num_beam_groups: Optional = 1 max_length: Optional = None )
参数
batch_size
(int
) — 并行运行标准 beam 搜索解码的input_ids
的批处理大小。num_beams
(int
) — beam 搜索的 beam 数量。constraints
(List[Constraint]
) — 以Constraint
对象表示的正约束列表,必须在生成输出中满足。有关更多信息,请阅读 Constraint 的文档。device
(torch.device
) — 定义此BeamSearchScorer
实例将分配到的设备类型(例如,"cpu"
或"cuda"
)。length_penalty
(float
,可选,默认为 1.0) — 用于基于 beam 的生成的长度的指数惩罚。它作为序列长度的指数应用,然后用于分割序列的分数。由于分数是序列的对数似然(即负数),length_penalty
> 0.0 促进更长的序列,而length_penalty
< 0.0 鼓励更短的序列。do_early_stopping
(bool
或str
,可选,默认为False
) — 控制基于 beam 的方法(如 beam-search)的停止条件。接受以下值:True
,生成在有num_beams
个完整候选时停止;False
,应用启发式方法,当很难找到更好的候选时停止生成;"never"
,仅当不能有更好的候选时,beam 搜索过程才会停止(经典 beam 搜索算法)。num_beam_hyps_to_keep
(int
,可选,默认为 1) — 在调用 finalize()时应返回的 beam 假设数量。num_beam_groups
(int
,可选,默认为 1)— 将num_beams
分成几组以确保不同组束之间的多样性。有关更多详细信息,请参阅此论文。max_length
(int
,可选)— 要生成的序列的最大长度。
BeamScorer 实现受限束搜索解码。
process
( input_ids: LongTensor next_scores: FloatTensor next_tokens: LongTensor next_indices: LongTensor scores_for_all_vocab: FloatTensor pad_token_id: Optional = None eos_token_id: Union = None beam_indices: Optional = None decoder_prompt_len: Optional = 0 ) → export const metadata = 'undefined';UserDict
参数
input_ids
(torch.LongTensor
,形状为(batch_size * num_beams, sequence_length)
)— 词汇表中输入序列标记的索引。
可以使用任何继承自 PreTrainedTokenizer 的类来获取索引。有关详细信息,请参阅 PreTrainedTokenizer.encode()和 PreTrainedTokenizer.call
()。
什么是输入 ID?next_scores
(torch.FloatTensor
,形状为(batch_size, 2 * num_beams)
)— 前2 * num_beams
个未完成束假设的当前分数。next_tokens
(torch.LongTensor
,形状为(batch_size, 2 * num_beams)
)— 与前2 * num_beams
个未完成束假设对应的标记的input_ids
。next_indices
(torch.LongTensor
,形状为(batch_size, 2 * num_beams)
)— 指示next_tokens
对应的束假设的束索引。scores_for_all_vocab
(torch.FloatTensor
,形状为(batch_size * num_beams, sequence_length)
)— 每个束假设的词汇表中所有标记的分数。pad_token_id
(int
,可选)— 填充标记的 ID。eos_token_id
(Union[int, List[int]]
,可选)— 结束序列标记的 ID。可以选择使用列表设置多个结束序列标记。beam_indices
(torch.LongTensor
,可选)— 指示每个标记对应的束假设的束索引。decoder_prompt_len
(int
,可选)— 包含在输入到解码器中的提示长度。
返回
UserDict
由上述字段组成的字典:
next_beam_scores
(torch.FloatTensor
,形状为(batch_size * num_beams)
)— 所有未完成束的更新分数。next_beam_tokens
(torch.FloatTensor
,形状为(batch_size * num_beams)
)— 要添加到未完成束假设的下一个标记。next_beam_indices
(torch.FloatTensor
,形状为(batch_size * num_beams)
)— 指示下一个标记应添加到哪个束中的束索引。
finalize
( input_ids: LongTensor final_beam_scores: FloatTensor final_beam_tokens: LongTensor final_beam_indices: LongTensor max_length: int pad_token_id: Optional = None eos_token_id: Union = None beam_indices: Optional = None decoder_prompt_len: Optional = 0 )
实用工具
transformers.top_k_top_p_filtering
( logits: FloatTensor top_k: int = 0 top_p: float = 1.0 filter_value: float = -inf min_tokens_to_keep: int = 1 )
参数
top_k
(int
,可选,默认为 0)— 如果大于 0,则仅保留具有最高概率的前 k 个标记(top-k 过滤)top_p
(float
,可选,默认为 1.0)— 如果小于 1.0,则仅保留累积概率大于等于 top_p 的前几个标记(nucleus 过滤)。Nucleus 过滤在 Holtzman 等人的论文中有描述(arxiv.org/abs/1904.09751
)。min_tokens_to_keep
(int
,可选,默认为 1)— 输出中每个批次示例保留的最小标记数。
使用 top-k 和/或 nucleus(top-p)过滤对数分布
来自:gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
transformers.tf_top_k_top_p_filtering
( logits top_k = 0 top_p = 1.0 filter_value = -inf min_tokens_to_keep = 1 )
参数
top_k
(int
,可选,默认为 0)— 如果> 0,则仅保留具有最高概率的前 k 个标记(top-k 过滤)top_p
(float
,可选,默认为 1.0)— 如果<1.0,则仅保留累积概率>= top_p 的前几个标记(nucleus 过滤)。Nucleus 过滤在 Holtzman 等人的论文中有描述(arxiv.org/abs/1904.09751
)min_tokens_to_keep
(int
,可选,默认为 1)— 输出中每个批示例要保留的最小标记数。
使用 top-k 和/或 nucleus(top-p)过滤对数分布
来自:gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
流媒体器
class transformers.TextStreamer
( tokenizer: AutoTokenizer skip_prompt: bool = False **decode_kwargs )
参数
tokenizer
(AutoTokenizer
)— 用于解码标记的标记器。skip_prompt
(bool
,可选,默认为False
)— 是否跳过提示以执行.generate()
。例如,对于聊天机器人很有用。decode_kwargs
(dict
,可选)— 传递给标记器的decode
方法的其他关键字参数。
简单的文本流媒体器,一旦形成完整单词,就会将标记打印到标准输出。
流媒体类的 API 仍在开发中,可能会在未来发生变化。
示例:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer >>> tok = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt") >>> streamer = TextStreamer(tok) >>> # Despite returning the usual output, the streamer will also print the generated text to stdout. >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20) An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
end
( )
刷新任何剩余的缓存并将换行符打印到标准输出。
on_finalized_text
( text: str stream_end: bool = False )
将新文本打印到标准输出。如果流结束,也会打印一个换行符。
put
( value )
接收标记,解码它们,并在形成完整单词时立即将它们打印到标准输出。
class transformers.TextIteratorStreamer
( tokenizer: AutoTokenizer skip_prompt: bool = False timeout: Optional = None **decode_kwargs )
参数
tokenizer
(AutoTokenizer
)— 用于解码标记的标记器。skip_prompt
(bool
,可选,默认为False
)— 是否跳过提示以执行.generate()
。例如,对于聊天机器人很有用。timeout
(float
,可选)— 文本队列的超时时间。如果为None
,队列将无限期阻塞。在单独的线程中调用.generate()
时,有助于处理异常。decode_kwargs
(dict
,可选)— 传递给标记器的decode
方法的其他关键字参数。
将可打印文本存储在队列中的流,供下游应用程序用作迭代器。这对于那些从以非阻塞方式访问生成的文本中受益的应用程序很有用(例如,在交互式 Gradio 演示中)。
流媒体类的 API 仍在开发中,可能会在未来发生变化。
示例:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer >>> from threading import Thread >>> tok = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt") >>> streamer = TextIteratorStreamer(tok) >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way. >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20) >>> thread = Thread(target=model.generate, kwargs=generation_kwargs) >>> thread.start() >>> generated_text = "" >>> for new_text in streamer: ... generated_text += new_text >>> generated_text 'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
on_finalized_text
( text: str stream_end: bool = False )
将新文本放入队列中。如果流结束,也将停止信号放入队列中。
缓存
class transformers.Cache
( )
所有缓存的基类。实际数据结构对于每个子类是特定的。
update
( key_states: Tensor value_states: Tensor layer_idx: int cache_kwargs: Optional = None )
参数
key_states
(torch.Tensor
)— 要缓存的新键状态。value_states
(torch.Tensor
)— 要缓存的新值状态。layer_idx
(int
)— 用于缓存状态的层的索引。cache_kwargs
(Dict[str, Any]
,optional
) — 缓存子类的额外参数。这些参数针对每个子类是特定的,并允许创建新类型的缓存。
使用新的key_states
和value_states
更新层layer_idx
的缓存。
class transformers.DynamicCache
( )
随着生成更多令牌,缓存会动态增长。这是生成模型的默认设置。
它将键和值状态存储为张量列表,每个层一个张量。每个张量的预期形状为[batch_size, num_heads, seq_len, head_dim]
。
update
( key_states: Tensor value_states: Tensor layer_idx: int cache_kwargs: Optional = None )
参数
key_states
(torch.Tensor
) — 要缓存的新键状态。value_states
(torch.Tensor
) — 要缓存的新值状态。layer_idx
(int
) — 要为其缓存状态的层的索引。cache_kwargs
(Dict[str, Any]
,optional
) — 缓存子类的额外参数。在DynamicCache
中不使用额外参数。
使用新的key_states
和value_states
更新层layer_idx
的缓存。
get_seq_length
( layer_idx: Optional = 0 )
返回缓存状态的序列长度。可以选择传递层索引。
reorder_cache
( beam_idx: LongTensor )
为束搜索重新排序缓存,给定所选的束索引。
to_legacy_cache
( )
将DynamicCache
实例转换为其在传统缓存格式中的等效形式。
from_legacy_cache
( past_key_values: Optional = None )
将传统缓存格式中的缓存转换为等效的DynamicCache
。
class transformers.SinkCache
( window_length: int num_sink_tokens: int )
参数
window_length
(int
) — 上下文窗口的长度。num_sink_tokens
(int
) — 沉没令牌的数量。有关更多信息,请参阅原始论文。
如Attention Sinks 论文中描述的缓存。它允许模型在超出其上下文窗口长度的情况下生成,而不会失去对话流畅性。随着丢弃过去的令牌,模型将失去生成依赖于被丢弃上下文的令牌的能力。
它将键和值状态存储为张量列表,每个层一个张量。每个张量的预期形状为[batch_size, num_heads, seq_len, head_dim]
。
update
( key_states: Tensor value_states: Tensor layer_idx: int cache_kwargs: Optional = None )
参数
key_states
(torch.Tensor
) — 要缓存的新键状态。value_states
(torch.Tensor
) — 要缓存的新值状态。layer_idx
(int
) — 要为其缓存状态的层的索引。cache_kwargs
(Dict[str, Any]
,optional
) — 缓存子类的额外参数。在SinkCache
中可以使用以下参数:sin
,cos
和partial_rotation_size
。这些参数用于使用 RoPE 的模型,在令牌移位时重新计算旋转。
使用新的key_states
和value_states
更新层layer_idx
的缓存。
get_seq_length
( layer_idx: Optional = 0 )
返回缓存状态的序列长度。可以选择传递层索引。
reorder_cache
( beam_idx: LongTensor )
重新排列缓存以进行波束搜索,给定所选的波束索引。