Transformers 4.37 中文文档(十三)(7)https://developer.aliyun.com/article/1564952
自动模型用于生成口罩
class transformers.AutoModelForMaskGeneration
( *args **kwargs )
TFAutoModelForMaskGeneration
class transformers.TFAutoModelForMaskGeneration
( *args **kwargs )
AutoModelForSeq2SeqLM
class transformers.AutoModelForSeq2SeqLM
( *args **kwargs )
这是一个通用的模型类,当使用 from_pretrained()类方法或 from_config()类方法创建时,将实例化为库的模型类之一(带有序列到序列语言建模头)。
这个类不能直接使用__init__()实例化(会抛出错误)。
from_config
( **kwargs )
参数
config(PretrainedConfig)-选择要实例化的模型类基于配置类:
- BartConfig 配置类:BartForConditionalGeneration(BART 模型)
- BigBirdPegasusConfig 配置类:BigBirdPegasusForConditionalGeneration(BigBird-Pegasus 模型)
- BlenderbotConfig 配置类:BlenderbotForConditionalGeneration(Blenderbot 模型)
- BlenderbotSmallConfig 配置类:BlenderbotSmallForConditionalGeneration(BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类:EncoderDecoderModel(编码器解码器模型)
- FSMTConfig 配置类: FSMTForConditionalGeneration (FairSeq 机器翻译模型)
- GPTSanJapaneseConfig 配置类: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)
- LEDConfig 配置类: LEDForConditionalGeneration (LED 模型)
- LongT5Config 配置类: LongT5ForConditionalGeneration (LongT5 模型)
- M2M100Config 配置类: M2M100ForConditionalGeneration (M2M100 模型)
- MBartConfig 配置类: MBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: MT5ForConditionalGeneration (MT5 模型)
- MarianConfig 配置类: MarianMTModel (Marian 模型)
- MvpConfig 配置类: MvpForConditionalGeneration (MVP 模型)
- NllbMoeConfig 配置类: NllbMoeForConditionalGeneration (NLLB-MOE 模型)
- PLBartConfig 配置类: PLBartForConditionalGeneration (PLBart 模型)
- PegasusConfig 配置类: PegasusForConditionalGeneration (Pegasus 模型)
- PegasusXConfig 配置类: PegasusXForConditionalGeneration (PEGASUS-X 模型)
- ProphetNetConfig 配置类: ProphetNetForConditionalGeneration (ProphetNet 模型)
- SeamlessM4TConfig 配置类: SeamlessM4TForTextToText (SeamlessM4T 模型)
- SeamlessM4Tv2Config 配置类:SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 模型)
- SwitchTransformersConfig 配置类:SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)
- T5Config 配置类:T5ForConditionalGeneration (T5 模型)
- UMT5Config 配置类:UMT5ForConditionalGeneration (UMT5 模型)
- XLMProphetNetConfig 配置类:XLMProphetNetForConditionalGeneration (XLM-ProphetNet 模型)
从配置实例化库中的模型类(带有序列到序列语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained() 来加载模型权重。
示例:
>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM >>> # Download configuration from huggingface.co and cache. >>> config = AutoConfig.from_pretrained("t5-base") >>> model = AutoModelForSeq2SeqLM.from_config(config)
from_pretrained
( *model_args **kwargs )
参数
pretrained_model_name_or_path(str或os.PathLike) — 可以是以下之一:
- 一个字符串,即在 huggingface.co 上托管的预训练模型的 模型 id。有效的模型 id 可以位于根级别,如
bert-base-uncased,或者在用户或组织名称下进行命名空间化,如dbmdz/bert-base-german-cased。 - 一个指向使用 save_pretrained() 保存的模型权重的 目录 的路径,例如
./my_model_directory/。 - 一个指向 tensorflow 索引检查点文件 的路径或 url(例如,
./tf_model/model.ckpt.index)。在这种情况下,from_tf应设置为True,并且应将配置对象提供为config参数。使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型后,此加载路径比较慢。
model_args(额外的位置参数,可选) — 将传递给底层模型__init__()方法。config(PretrainedConfig, 可选) — 用于模型的配置,而不是自动加载的配置。当以下情况时,配置可以自动加载:
- 模型是库提供的模型(使用预训练模型的 模型 id 字符串加载)。
- 模型是使用 save_pretrained() 保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path并且在目录中找到名为 config.json 的配置 JSON 文件来加载模型。
state_dict(Dict[str, torch.Tensor], 可选) — 用于替代从保存的权重文件加载的状态字典的状态字典。
如果您想从预训练配置创建模型,但加载自己的权重,可以使用此选项。不过,在这种情况下,您应该检查使用 save_pretrained()和 from_pretrained()是否不是更简单的选项。cache_dir(str或os.PathLike,可选) — 下载的预训练模型配置应该缓存在其中的目录路径,如果不使用标准缓存。from_tf(bool, 可选, 默认为False) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path参数的文档字符串)。force_download(bool, 可选, 默认为False) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。resume_download(bool, 可选, 默认为False) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。proxies(Dict[str, str], 可选) — 一个代理服务器字典,按协议或端点使用,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}。这些代理在每个请求中使用。output_loading_info(bool,可选, 默认为False) — 是否还返回一个包含缺失键、意外键和错误消息的字典。local_files_only(bool,可选, 默认为False) — 是否仅查看本地文件(例如,不尝试下载模型)。revision(str, 可选, 默认为"main") — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision可以是 git 允许的任何标识符。trust_remote_code(bool, 可选, 默认为False) — 是否允许在 Hub 上定义自定义模型的代码。此选项应仅对您信任的存储库设置为True,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。code_revision(str, 可选, 默认为"main") — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision可以是 git 允许的任何标识符。kwargs(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True)。根据是否提供或自动加载了config,行为会有所不同:
- 如果提供了配置
config,**kwargs将直接传递给底层模型的__init__方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs的每个键将用提供的kwargs值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__函数。
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类基于配置对象的 model_type 属性进行选择(如果可能,作为参数传递或从 pretrained_model_name_or_path 加载),或者当缺失时,通过在 pretrained_model_name_or_path 上使用模式匹配来回退:
bart— BartForConditionalGeneration (BART 模型)bigbird_pegasus— BigBirdPegasusForConditionalGeneration (BigBird-Pegasus 模型)blenderbot— BlenderbotForConditionalGeneration (Blenderbot 模型)blenderbot-small— BlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)encoder-decoder— EncoderDecoderModel (编码器解码器模型)fsmt— FSMTForConditionalGeneration (FairSeq 机器翻译模型)gptsan-japanese— GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese 模型)led— LEDForConditionalGeneration (LED 模型)longt5— LongT5ForConditionalGeneration (LongT5 模型)m2m_100— M2M100ForConditionalGeneration (M2M100 模型)marian— MarianMTModel (Marian 模型)mbart— MBartForConditionalGeneration (mBART 模型)mt5— MT5ForConditionalGeneration (MT5 模型)mvp— MvpForConditionalGeneration (MVP 模型)nllb-moe— NllbMoeForConditionalGeneration (NLLB-MOE 模型)pegasus— PegasusForConditionalGeneration (Pegasus 模型)pegasus_x— PegasusXForConditionalGeneration (PEGASUS-X 模型)plbart— PLBartForConditionalGeneration (PLBart 模型)prophetnet— ProphetNetForConditionalGeneration (ProphetNet 模型)seamless_m4t— SeamlessM4TForTextToText (SeamlessM4T 模型)seamless_m4t_v2— SeamlessM4Tv2ForTextToText (SeamlessM4Tv2 模型)switch_transformers— SwitchTransformersForConditionalGeneration (SwitchTransformers 模型)t5— T5ForConditionalGeneration (T5 模型)umt5— UMT5ForConditionalGeneration (UMT5 模型)xlm-prophetnet— XLMProphetNetForConditionalGeneration (XLM-ProphetNet 模型)
默认情况下,模型处于评估模式,使用 model.eval()(例如,dropout 模块被停用)。要训练模型,您应该首先使用 model.train() 将其设置回训练模式
示例:
>>> from transformers import AutoConfig, AutoModelForSeq2SeqLM >>> # Download model and configuration from huggingface.co and cache. >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> # Update configuration during loading >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base", output_attentions=True) >>> model.config.output_attentions True >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) >>> config = AutoConfig.from_pretrained("./tf_model/t5_tf_model_config.json") >>> model = AutoModelForSeq2SeqLM.from_pretrained( ... "./tf_model/t5_tf_checkpoint.ckpt.index", from_tf=True, config=config ... )
TFAutoModelForSeq2SeqLM
class transformers.TFAutoModelForSeq2SeqLM
( *args **kwargs )
这是一个通用的模型类,在使用 from_pretrained() 类方法或 from_config() 类方法创建时,将作为库中的模型类之一实例化(带有序列到序列语言建模头)。
这个类不能直接使用 __init__() 实例化(会报错)。
from_config
( **kwargs )
参数
config(PretrainedConfig) — 根据配置类选择要实例化的模型类:
- BartConfig 配置类: TFBartForConditionalGeneration (BART 模型)
- BlenderbotConfig 配置类: TFBlenderbotForConditionalGeneration (Blenderbot 模型)
- BlenderbotSmallConfig 配置类: TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类: TFEncoderDecoderModel (编码器解码器模型)
- LEDConfig 配置类: TFLEDForConditionalGeneration (LED 模型)
- MBartConfig 配置类: TFMBartForConditionalGeneration (mBART 模型)
- MT5Config 配置类: TFMT5ForConditionalGeneration (MT5 模型)
- MarianConfig 配置类: TFMarianMTModel (Marian 模型)
- PegasusConfig 配置类: TFPegasusForConditionalGeneration (Pegasus 模型)
- T5Config 配置类:TFT5ForConditionalGeneration(T5 模型)
从配置中实例化库中的一个模型类(带有序列到序列语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()加载模型权重。
示例:
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM >>> # Download configuration from huggingface.co and cache. >>> config = AutoConfig.from_pretrained("t5-base") >>> model = TFAutoModelForSeq2SeqLM.from_config(config)
from_pretrained
( *model_args **kwargs )
参数
pretrained_model_name_or_path(str或os.PathLike)— 可以是:
- 一个字符串,托管在 huggingface.co 模型存储库中的预训练模型的模型 ID。有效的模型 ID 可以位于根级别,如
bert-base-uncased,或命名空间下的用户或组织名称,如dbmdz/bert-base-german-cased。 - 一个包含使用 save_pretrained()保存的模型权重的目录的路径,例如,
./my_model_directory/。 - 一个PyTorch 状态字典保存文件的路径或 URL(例如,
./pt_model/pytorch_model.bin)。在这种情况下,from_pt应设置为True,并且应将配置对象提供为config参数。这种加载路径比使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型要慢。
model_args(额外的位置参数,可选)— 将传递给底层模型__init__()方法。config(PretrainedConfig,可选)— 模型使用的配置,而不是自动加载的配置。当以下情况时,配置可以自动加载:
- 该模型是由库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 该模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path加载模型,并在目录中找到名为config.json的配置 JSON 文件。
cache_dir(str或os.PathLike,可选)— 如果不使用标准缓存,则应将下载的预训练模型配置缓存在其中的目录路径。from_pt(bool,可选,默认为False)— 从 PyTorch 检查点保存文件加载模型权重(参见pretrained_model_name_or_path参数的文档字符串)。force_download(bool,可选,默认为False)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。resume_download(bool,可选,默认为False)— 是否删除未完全接收的文件。如果存在这样的文件,将尝试恢复下载。proxies(Dict[str, str],可选)— 一个按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}。这些代理在每个请求中使用。output_loading_info(bool,可选,默认为False) — 是否还返回一个包含缺失键、意外键和错误消息的字典。local_files_only(bool,可选,默认为False) — 是否仅查看本地文件(例如,不尝试下载模型)。revision(str, 可选, 默认为"main") — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 id,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision可以是 git 允许的任何标识符。trust_remote_code(bool, 可选, 默认为False) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅对您信任的存储库设置为True,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。code_revision(str, 可选, 默认为"main") — 用于 Hub 上的代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 id,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision可以是 git 允许的任何标识符。kwargs(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并初始化模型(例如,output_attentions=True)。根据是否提供或自动加载了config,行为会有所不同:
- 如果提供了
config,**kwargs将直接传递给底层模型的__init__方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs将首先传递给配置类的初始化函数(from_pretrained())。kwargs的每个键对应一个配置属性,将用提供的kwargs值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__函数。
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类是根据配置对象的 model_type 属性选择的(如果可能,作为参数传递或从 pretrained_model_name_or_path 加载),或者当缺少时,通过在 pretrained_model_name_or_path 上使用模式匹配来回退:
bart— TFBartForConditionalGeneration (BART 模型)blenderbot— TFBlenderbotForConditionalGeneration (Blenderbot 模型)blenderbot-small— TFBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)encoder-decoder— TFEncoderDecoderModel (编码器解码器模型)led— TFLEDForConditionalGeneration (LED 模型)marian— TFMarianMTModel (Marian 模型)mbart— TFMBartForConditionalGeneration (mBART 模型)mt5— TFMT5ForConditionalGeneration (MT5 模型)pegasus— TFPegasusForConditionalGeneration (Pegasus 模型)t5— TFT5ForConditionalGeneration (T5 模型)
示例:
>>> from transformers import AutoConfig, TFAutoModelForSeq2SeqLM >>> # Download model and configuration from huggingface.co and cache. >>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> # Update configuration during loading >>> model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-base", output_attentions=True) >>> model.config.output_attentions True >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) >>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json") >>> model = TFAutoModelForSeq2SeqLM.from_pretrained( ... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config ... )
FlaxAutoModelForSeq2SeqLM
class transformers.FlaxAutoModelForSeq2SeqLM
( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将作为库的模型类之一实例化(带有序列到序列语言建模头)。
这个类不能直接使用__init__()进行实例化(会抛出错误)。
from_config
( **kwargs )
参数
config(PretrainedConfig)— 选择要实例化的模型类基于配置类:
- BartConfig 配置类:FlaxBartForConditionalGeneration(BART 模型)
- BlenderbotConfig 配置类:FlaxBlenderbotForConditionalGeneration(Blenderbot 模型)
- BlenderbotSmallConfig 配置类:FlaxBlenderbotSmallForConditionalGeneration(BlenderbotSmall 模型)
- EncoderDecoderConfig 配置类:FlaxEncoderDecoderModel(编码器解码器模型)
- LongT5Config 配置类:FlaxLongT5ForConditionalGeneration(LongT5 模型)
- MBartConfig 配置类:FlaxMBartForConditionalGeneration(mBART 模型)
- MT5Config 配置类:FlaxMT5ForConditionalGeneration(MT5 模型)
- MarianConfig 配置类:FlaxMarianMTModel(Marian 模型)
- PegasusConfig 配置类:FlaxPegasusForConditionalGeneration(Pegasus 模型)
- T5Config 配置类:FlaxT5ForConditionalGeneration(T5 模型)
从配置实例化库的模型类之一(带有序列到序列语言建模头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()加载模型权重。
示例:
>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM >>> # Download configuration from huggingface.co and cache. >>> config = AutoConfig.from_pretrained("t5-base") >>> model = FlaxAutoModelForSeq2SeqLM.from_config(config)
from_pretrained
( *model_args **kwargs )
参数
pretrained_model_name_or_path(str或os.PathLike)— 可以是:
- 一个字符串,预训练模型的模型标识符,托管在 huggingface.co 上的模型存储库中。有效的模型标识符可以位于根级别,如
bert-base-uncased,或者在用户或组织名称下命名空间,如dbmdz/bert-base-german-cased。 - 指向使用 save_pretrained()保存的模型权重的目录的路径,例如,
./my_model_directory/。 - 指向PyTorch 状态字典保存文件的路径或 url(例如,
./pt_model/pytorch_model.bin)。在这种情况下,应将from_pt设置为True,并将配置对象提供为config参数。使用提供的转换脚本将 PyTorch 模型转换为 TensorFlow 模型并随后加载 TensorFlow 模型的加载路径比较慢。
model_args(额外的位置参数,可选)— 将传递给底层模型__init__()方法。config(PretrainedConfig,可选)— 模型使用的配置,而不是自动加载的配置。当以下情况时,可以自动加载配置:
- 该模型是库提供的模型(使用预训练模型的模型标识符字符串加载)。
- 模型是使用 save_pretrained()保存的,并通过提供保存目录重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path加载模型,并在目录中找到名为config.json的配置 JSON 文件。
cache_dir(str或os.PathLike,可选)— 下载预训练模型配置应缓存的目录路径,如果不使用标准缓存。from_pt(bool,可选,默认为False)— 从 PyTorch 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path参数的文档字符串)。force_download(bool,可选,默认为False)— 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。resume_download(bool,可选,默认为False)— 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。proxies(Dict[str, str],可选)— 一个按协议或端点使用的代理服务器字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}。每个请求都会使用代理。output_loading_info(bool,可选,默认为False)— 是否还返回包含缺失键、意外键和错误消息的字典。local_files_only(bool,可选,默认为False)— 是否仅查看本地文件(例如,不尝试下载模型)。revision(str,可选,默认为"main")— 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision可以是 git 允许的任何标识符。trust_remote_code(bool,可选,默认为False) — 是否允许在 Hub 上定义自定义模型并在其自己的建模文件中执行。此选项应仅在您信任的存储库中设置为True,并且您已经阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。code_revision(str,可选,默认为"main") — 用于 Hub 上代码的特定修订版本,如果代码存储在与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们在 huggingface.co 上使用基于 git 的系统来存储模型和其他工件,所以revision可以是 git 允许的任何标识符。kwargs(额外的关键字参数,可选) — 可用于更新配置对象(在加载后)并初始化模型(例如,output_attentions=True)。根据是否提供或自动加载了config,行为会有所不同:
- 如果提供了
config,**kwargs将直接传递给底层模型的__init__方法(我们假设配置的所有相关更新已经完成) - 如果未提供配置,
kwargs将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs的每个键将用提供的kwargs值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__函数。
从预训练模型实例化库中的一个模型类(带有序列到序列语言建模头)。
要实例化的模型类是根据配置对象的model_type属性选择的(作为参数传递或从pretrained_model_name_or_path加载,如果可能的话),或者当缺少时,通过在pretrained_model_name_or_path上使用模式匹配来回退:
bart— FlaxBartForConditionalGeneration (BART 模型)blenderbot— FlaxBlenderbotForConditionalGeneration (Blenderbot 模型)blenderbot-small— FlaxBlenderbotSmallForConditionalGeneration (BlenderbotSmall 模型)encoder-decoder— FlaxEncoderDecoderModel (编码器解码器模型)longt5— FlaxLongT5ForConditionalGeneration (LongT5 模型)marian— FlaxMarianMTModel (Marian 模型)mbart— FlaxMBartForConditionalGeneration (mBART 模型)mt5— FlaxMT5ForConditionalGeneration (MT5 模型)pegasus— FlaxPegasusForConditionalGeneration (Pegasus 模型)t5— FlaxT5ForConditionalGeneration (T5 模型)
示例:
>>> from transformers import AutoConfig, FlaxAutoModelForSeq2SeqLM >>> # Download model and configuration from huggingface.co and cache. >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> # Update configuration during loading >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-base", output_attentions=True) >>> model.config.output_attentions True >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) >>> config = AutoConfig.from_pretrained("./pt_model/t5_pt_model_config.json") >>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained( ... "./pt_model/t5_pytorch_model.bin", from_pt=True, config=config ... )
AutoModelForSequenceClassification
class transformers.AutoModelForSequenceClassification
( *args **kwargs )
这是一个通用的模型类,当使用 class method 或 class method 创建时,将实例化为库的模型类之一(带有序列分类头)。
这个类不能直接使用__init__()实例化(会抛出错误)。
from_config
( **kwargs )
参数
config(PretrainedConfig)— 选择要实例化的模型类基于配置类:
- AlbertConfig 配置类:AlbertForSequenceClassification(ALBERT 模型)
- BartConfig 配置类:BartForSequenceClassification(BART 模型)
- BertConfig 配置类:BertForSequenceClassification(BERT 模型)
- BigBirdConfig 配置类:BigBirdForSequenceClassification(BigBird 模型)
- BigBirdPegasusConfig 配置类:BigBirdPegasusForSequenceClassification(BigBird-Pegasus 模型)
- BioGptConfig 配置类:BioGptForSequenceClassification(BioGpt 模型)
- BloomConfig 配置类:BloomForSequenceClassification(BLOOM 模型)
- CTRLConfig 配置类:CTRLForSequenceClassification(CTRL 模型)
- CamembertConfig 配置类:CamembertForSequenceClassification(CamemBERT 模型)
- CanineConfig 配置类:CanineForSequenceClassification(CANINE 模型)
- ConvBertConfig 配置类:ConvBertForSequenceClassification(ConvBERT 模型)
- Data2VecTextConfig 配置类:Data2VecTextForSequenceClassification(Data2VecText 模型)
- DebertaConfig 配置类:DebertaForSequenceClassification(DeBERTa 模型)
- DebertaV2Config 配置类:DebertaV2ForSequenceClassification(DeBERTa-v2 模型)
- DistilBertConfig 配置类:DistilBertForSequenceClassification(DistilBERT 模型)
- ElectraConfig 配置类:ElectraForSequenceClassification(ELECTRA 模型)
- ErnieConfig 配置类:ErnieForSequenceClassification(ERNIE 模型)
- ErnieMConfig 配置类:ErnieMForSequenceClassification(ErnieM 模型)
- EsmConfig 配置类:EsmForSequenceClassification(ESM 模型)
- FNetConfig 配置类:FNetForSequenceClassification(FNet 模型)
- FalconConfig 配置类:FalconForSequenceClassification(Falcon 模型)
- FlaubertConfig 配置类:FlaubertForSequenceClassification(FlauBERT 模型)
- FunnelConfig 配置类:FunnelForSequenceClassification(Funnel Transformer 模型)
- GPT2Config 配置类:GPT2ForSequenceClassification(OpenAI GPT-2 模型)
- GPTBigCodeConfig 配置类:GPTBigCodeForSequenceClassification(GPTBigCode 模型)
- GPTJConfig 配置类:GPTJForSequenceClassification(GPT-J 模型)
- GPTNeoConfig 配置类:GPTNeoForSequenceClassification(GPT Neo 模型)
- GPTNeoXConfig 配置类: GPTNeoXForSequenceClassification (GPT NeoX 模型)
- IBertConfig 配置类: IBertForSequenceClassification (I-BERT 模型)
- LEDConfig 配置类: LEDForSequenceClassification (LED 模型)
- LayoutLMConfig 配置类: LayoutLMForSequenceClassification (LayoutLM 模型)
- LayoutLMv2Config 配置类: LayoutLMv2ForSequenceClassification (LayoutLMv2 模型)
- LayoutLMv3Config 配置类: LayoutLMv3ForSequenceClassification (LayoutLMv3 模型)
- LiltConfig 配置类: LiltForSequenceClassification (LiLT 模型)
- LlamaConfig 配置类: LlamaForSequenceClassification (LLaMA 模型)
- LongformerConfig 配置类: LongformerForSequenceClassification (Longformer 模型)
- LukeConfig 配置类: LukeForSequenceClassification (LUKE 模型)
- MBartConfig 配置类: MBartForSequenceClassification (mBART 模型)
- MPNetConfig 配置类: MPNetForSequenceClassification (MPNet 模型)
- MT5Config 配置类: MT5ForSequenceClassification (MT5 模型)
- MarkupLMConfig 配置类: MarkupLMForSequenceClassification (MarkupLM 模型)
- MegaConfig 配置类: MegaForSequenceClassification (MEGA 模型)
- MegatronBertConfig 配置类: MegatronBertForSequenceClassification (Megatron-BERT 模型)
- MistralConfig 配置类: MistralForSequenceClassification (Mistral 模型)
- MixtralConfig 配置类: MixtralForSequenceClassification (Mixtral 模型)
- MobileBertConfig 配置类: MobileBertForSequenceClassification (MobileBERT 模型)
- MptConfig 配置类: MptForSequenceClassification (MPT 模型)
- MraConfig 配置类: MraForSequenceClassification (MRA 模型)
- MvpConfig 配置类: MvpForSequenceClassification (MVP 模型)
- NezhaConfig 配置类: NezhaForSequenceClassification (Nezha 模型)
- NystromformerConfig 配置类: NystromformerForSequenceClassification (Nyströmformer 模型)
- OPTConfig 配置类: OPTForSequenceClassification (OPT 模型)
- OpenAIGPTConfig 配置类: OpenAIGPTForSequenceClassification (OpenAI GPT 模型)
- OpenLlamaConfig 配置类: OpenLlamaForSequenceClassification (OpenLlama 模型)
- PLBartConfig 配置类: PLBartForSequenceClassification (PLBart 模型)
- PerceiverConfig 配置类: PerceiverForSequenceClassification (Perceiver 模型)
- PersimmonConfig 配置类: PersimmonForSequenceClassification (Persimmon 模型)
- PhiConfig 配置类:PhiForSequenceClassification(Phi 模型)
- QDQBertConfig 配置类:QDQBertForSequenceClassification(QDQBert 模型)
- Qwen2Config 配置类:Qwen2ForSequenceClassification(Qwen2 模型)
- ReformerConfig 配置类:ReformerForSequenceClassification(Reformer 模型)
- RemBertConfig 配置类:RemBertForSequenceClassification(RemBERT 模型)
- RoCBertConfig 配置类:RoCBertForSequenceClassification(RoCBert 模型)
- RoFormerConfig 配置类:RoFormerForSequenceClassification(RoFormer 模型)
- RobertaConfig 配置类:RobertaForSequenceClassification(RoBERTa 模型)
- RobertaPreLayerNormConfig 配置类:RobertaPreLayerNormForSequenceClassification(RoBERTa-PreLayerNorm 模型)
- SqueezeBertConfig 配置类:SqueezeBertForSequenceClassification(SqueezeBERT 模型)
- T5Config 配置类:T5ForSequenceClassification(T5 模型)
- TapasConfig 配置类:TapasForSequenceClassification(TAPAS 模型)
- TransfoXLConfig 配置类:TransfoXLForSequenceClassification(Transformer-XL 模型)
- UMT5Config 配置类:UMT5ForSequenceClassification(UMT5 模型)
- XLMConfig 配置类:XLMForSequenceClassification(XLM 模型)
- XLMRobertaConfig 配置类:XLMRobertaForSequenceClassification(XLM-RoBERTa 模型)
- XLMRobertaXLConfig 配置类:XLMRobertaXLForSequenceClassification(XLM-RoBERTa-XL 模型)
- XLNetConfig 配置类:XLNetForSequenceClassification(XLNet 模型)
- XmodConfig 配置类:XmodForSequenceClassification(X-MOD 模型)
- YosoConfig 配置类:YosoForSequenceClassification(YOSO 模型)
从配置实例化库中的一个模型类(带有序列分类头)。
注意:从配置文件加载模型不会加载模型权重。它只影响模型的配置。使用 from_pretrained()来加载模型权重。
示例:
>>> from transformers import AutoConfig, AutoModelForSequenceClassification >>> # Download configuration from huggingface.co and cache. >>> config = AutoConfig.from_pretrained("bert-base-cased") >>> model = AutoModelForSequenceClassification.from_config(config)
from_pretrained
( *model_args **kwargs )
参数
pretrained_model_name_or_path(str或os.PathLike)— 可以是:
- 一个字符串,预训练模型的模型 ID,托管在 huggingface.co 上的模型存储库内。有效的模型 ID 可以位于根级别,如
bert-base-uncased,或者命名空间在用户或组织名称下,如dbmdz/bert-base-german-cased。 - 一个包含使用 save_pretrained()保存的模型权重的目录的路径,例如,
./my_model_directory/。 - 一个TensorFlow 索引检查点文件的路径或 URL(例如,
./tf_model/model.ckpt.index)。在这种情况下,from_tf应设置为True,并且应将配置对象提供为config参数。这种加载路径比使用提供的转换脚本将 TensorFlow 检查点转换为 PyTorch 模型并加载 PyTorch 模型要慢。
model_args(额外的位置参数,可选)— 将传递给底层模型__init__()方法。config(PretrainedConfig,可选)— 用于替代自动加载的配置的模型配置。当以下情况时,配置可以自动加载:
- 该模型是库提供的模型(使用预训练模型的模型 ID字符串加载)。
- 该模型是使用 save_pretrained()保存的,并通过提供保存目录来重新加载。
- 通过提供本地目录作为
pretrained_model_name_or_path加载模型,并在目录中找到名为config.json的配置 JSON 文件。
state_dict(Dict[str, torch.Tensor],可选)— 用于替代从保存的权重文件加载的状态字典的状态字典。
如果要从预训练配置创建模型但加载自己的权重,则可以使用此选项。但在这种情况下,您应该检查使用 save_pretrained()和 from_pretrained()是否不是更简单的选项。cache_dir(stroros.PathLike, optional) — 下载的预训练模型配置应缓存在其中的目录路径,如果不使用标准缓存。from_tf(bool, optional, defaults toFalse) — 从 TensorFlow 检查点保存文件加载模型权重(请参阅pretrained_model_name_or_path参数的文档字符串)。force_download(bool, optional, defaults toFalse) — 是否强制(重新)下载模型权重和配置文件,覆盖缓存版本(如果存在)。resume_download(bool, optional, defaults toFalse) — 是否删除接收不完整的文件。如果存在这样的文件,将尝试恢复下载。proxies(Dict[str, str], optional) — 用于每个请求的代理服务器的协议或端点的字典,例如,{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}。代理将用于每个请求。output_loading_info(bool,optional, defaults toFalse) — 是否还返回包含缺失键、意外键和错误消息的字典。local_files_only(bool,optional, defaults toFalse) — 是否仅查看本地文件(例如,不尝试下载模型)。revision(str, optional, defaults to"main") — 要使用的特定模型版本。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision可以是 git 允许的任何标识符。trust_remote_code(bool, optional, defaults toFalse) — 是否允许在 Hub 上定义自定义模型的建模文件。此选项应仅在您信任的存储库中设置为True,并且您已阅读了代码,因为它将在本地机器上执行 Hub 上存在的代码。code_revision(str, optional, defaults to"main") — 用于 Hub 上代码的特定修订版本,如果代码位于与模型其余部分不同的存储库中。它可以是分支名称、标签名称或提交 ID,因为我们使用基于 git 的系统在 huggingface.co 上存储模型和其他工件,所以revision可以是 git 允许的任何标识符。kwargs(额外的关键字参数,可选) — 可用于更新配置对象(加载后)并启动模型(例如,output_attentions=True)。根据是否提供了config,行为会有所不同:
- 如果提供了
config,**kwargs将直接传递给底层模型的__init__方法(我们假设所有相关的配置更新已经完成) - 如果未提供配置,
kwargs将首先传递给配置类初始化函数(from_pretrained())。与配置属性对应的kwargs的每个键将用提供的kwargs值覆盖该属性。不对应任何配置属性的剩余键将传递给底层模型的__init__函数。
从预训练模型实例化库中的一个模型类(带有序列分类头)。
要实例化的模型类是根据配置对象的 model_type 属性选择的(如果可能,作为参数传递或从 pretrained_model_name_or_path 加载),或者当缺少时,通过在 pretrained_model_name_or_path 上使用模式匹配来回退:
albert— AlbertForSequenceClassification (ALBERT 模型)bart— BartForSequenceClassification (BART 模型)bert— BertForSequenceClassification (BERT 模型)big_bird— BigBirdForSequenceClassification (BigBird 模型)bigbird_pegasus— BigBirdPegasusForSequenceClassification (BigBird-Pegasus 模型)biogpt— BioGptForSequenceClassification (BioGpt 模型)bloom— BloomForSequenceClassification (BLOOM 模型)camembert— CamembertForSequenceClassification (CamemBERT 模型)canine— CanineForSequenceClassification (CANINE 模型)code_llama— LlamaForSequenceClassification (CodeLlama 模型)convbert— ConvBertForSequenceClassification (ConvBERT 模型)ctrl— CTRLForSequenceClassification (CTRL 模型)data2vec-text— Data2VecTextForSequenceClassification (Data2VecText 模型)deberta— DebertaForSequenceClassification (DeBERTa 模型)deberta-v2— DebertaV2ForSequenceClassification (DeBERTa-v2 模型)distilbert— DistilBertForSequenceClassification (DistilBERT 模型)electra— ElectraForSequenceClassification (ELECTRA 模型)ernie— ErnieForSequenceClassification (ERNIE 模型)ernie_m— ErnieMForSequenceClassification (ErnieM 模型)esm— EsmForSequenceClassification (ESM 模型)falcon— FalconForSequenceClassification (Falcon 模型)flaubert— FlaubertForSequenceClassification (FlauBERT 模型)fnet— FNetForSequenceClassification (FNet 模型)funnel— FunnelForSequenceClassification (Funnel Transformer model)gpt-sw3— GPT2ForSequenceClassification (GPT-Sw3 model)gpt2— GPT2ForSequenceClassification (OpenAI GPT-2 model)gpt_bigcode— GPTBigCodeForSequenceClassification (GPTBigCode model)gpt_neo— GPTNeoForSequenceClassification (GPT Neo model)gpt_neox— GPTNeoXForSequenceClassification (GPT NeoX model)gptj— GPTJForSequenceClassification (GPT-J model)ibert— IBertForSequenceClassification (I-BERT model)layoutlm— LayoutLMForSequenceClassification (LayoutLM model)layoutlmv2— LayoutLMv2ForSequenceClassification (LayoutLMv2 model)layoutlmv3— LayoutLMv3ForSequenceClassification (LayoutLMv3 model)led— LEDForSequenceClassification (LED model)lilt— LiltForSequenceClassification (LiLT model)llama— LlamaForSequenceClassification (LLaMA model)longformer— LongformerForSequenceClassification (Longformer model)luke— LukeForSequenceClassification (LUKE model)markuplm— MarkupLMForSequenceClassification (MarkupLM model)mbart— MBartForSequenceClassification (mBART model)mega— MegaForSequenceClassification (MEGA model)megatron-bert— MegatronBertForSequenceClassification (Megatron-BERT model)mistral— MistralForSequenceClassification (Mistral model)mixtral— MixtralForSequenceClassification (Mixtral model)mobilebert— MobileBertForSequenceClassification (MobileBERT model)mpnet— MPNetForSequenceClassification (MPNet model)mpt— MptForSequenceClassification (MPT model)mra— MraForSequenceClassification (MRA 模型)mt5— MT5ForSequenceClassification (MT5 模型)mvp— MvpForSequenceClassification (MVP 模型)nezha— NezhaForSequenceClassification (Nezha 模型)nystromformer— NystromformerForSequenceClassification (Nyströmformer 模型)open-llama— OpenLlamaForSequenceClassification (OpenLlama 模型)openai-gpt— OpenAIGPTForSequenceClassification (OpenAI GPT 模型)opt— OPTForSequenceClassification (OPT 模型)perceiver— PerceiverForSequenceClassification (Perceiver 模型)persimmon— PersimmonForSequenceClassification (Persimmon 模型)phi— PhiForSequenceClassification (Phi 模型)plbart— PLBartForSequenceClassification (PLBart 模型)qdqbert— QDQBertForSequenceClassification (QDQBert 模型)qwen2— Qwen2ForSequenceClassification (Qwen2 模型)reformer— ReformerForSequenceClassification (Reformer 模型)rembert— RemBertForSequenceClassification (RemBERT 模型)roberta— RobertaForSequenceClassification (RoBERTa 模型)roberta-prelayernorm— RobertaPreLayerNormForSequenceClassification (RoBERTa-PreLayerNorm 模型)roc_bert— RoCBertForSequenceClassification (RoCBert 模型)roformer— RoFormerForSequenceClassification (RoFormer 模型)squeezebert— SqueezeBertForSequenceClassification (SqueezeBERT 模型)t5— T5ForSequenceClassification (T5 模型)tapas— TapasForSequenceClassification (TAPAS 模型)transfo-xl— TransfoXLForSequenceClassification (Transformer-XL 模型)umt5— UMT5ForSequenceClassification (UMT5 模型)xlm— XLMForSequenceClassification (XLM 模型)xlm-roberta— XLMRobertaForSequenceClassification (XLM-RoBERTa 模型)xlm-roberta-xl— XLMRobertaXLForSequenceClassification (XLM-RoBERTa-XL 模型)xlnet— XLNetForSequenceClassification (XLNet 模型)xmod— XmodForSequenceClassification (X-MOD 模型)yoso— YosoForSequenceClassification (YOSO 模型)
默认情况下,该模型处于评估模式,使用 model.eval()(例如,dropout 模块被停用)。要训练模型,您应该首先使用 model.train() 将其设置回训练模式
示例:
>>> from transformers import AutoConfig, AutoModelForSequenceClassification >>> # Download model and configuration from huggingface.co and cache. >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased") >>> # Update configuration during loading >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", output_attentions=True) >>> model.config.output_attentions True >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) >>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json") >>> model = AutoModelForSequenceClassification.from_pretrained( ... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config ... )
Transformers 4.37 中文文档(十三)(9)https://developer.aliyun.com/article/1564954