欢迎访问宙启技术站
智能推送

如何在Python中使用from_pretrained()加载预训练的文本分类模型

发布时间:2024-01-03 19:44:35

在Python中,使用from_pretrained()方法可以加载预训练的文本分类模型。这个方法是在transformers库中的AutoModelForSequenceClassification类中定义的。下面是一个使用预训练模型来进行文本分类的示例代码:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# 加载预训练的tokenizer和模型
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')

# 输入文本
text = "This is an example sentence."

# 进行tokenize和padding处理
inputs = tokenizer.encode_plus(
    text,
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors='pt'
)

# 运行模型
outputs = model(**inputs)

# 获取预测结果
predictions = outputs.logits.argmax(dim=1)

# 打印结果
print('Predicted label:', predictions.item())

上面的代码中,我们首先使用AutoTokenizer.from_pretrained()方法加载预训练的tokenizer。这里使用的是bert-base-uncased模型,它是一个基于BERT的文本分类模型,其中的uncased表示模型没有区分大小写。

接下来,我们使用AutoModelForSequenceClassification.from_pretrained()方法加载预训练的文本分类模型。这个方法会自动下载并加载模型的权重。

然后,我们定义了一个输入文本text,我们要对这个文本进行分类。

接着,我们使用tokenizer的encode_plus()方法对输入文本进行tokenize和padding处理,以便与模型的输入格式相匹配。在这个例子中,我们将输入文本的长度限制为128,并指定padding的方式为max_length

最后,我们将处理后的输入数据传递给模型,并获取模型的输出结果。输出结果是一个张量(tensor),其中的元素表示不同标签的得分。

最后,我们使用argmax()方法获取得分最高的标签,并打印预测结果。

需要注意的是,from_pretrained()方法在首次使用时会自动下载模型的权重和配置文件。因此,第一次运行代码可能需要一些时间来完成下载。之后,模型的权重和配置文件将被缓存,不需要再次下载。