CPM-Distill:经过知识蒸馏的小型文本生成模型


本文介绍知识蒸馏技术及基于PaddleNLP加载CPM-Distill模型实现文本生成。知识蒸馏是模型压缩方法,以“教师-学生网络”思想,让简单模型拟合复杂模型输出,效果优于从头训练。CPM-Distill由GPT-2 Large蒸馏得到,文中还给出安装依赖、加载模型、解码方法及文本生成示例。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

引入

  • 近些年来,随着 Bert 这样的大规模预训练模型的问世,NLP 领域的模型也逐渐变得越来越大了
  • 受限于算力水平,如此大规模的模型要应用在实际的部署场景都是不太实际的
  • 因此需要通过一些方式对大规模的模型进行压缩,使其能够在部署场景下达到一个相对可用的速度
  • 常见的模型压缩方法有:剪枝、量化、知识蒸馏等
  • 最近 CPM(Chinese Pre-Trained Models)项目又开源了一个使用知识蒸馏得到的小型文本生成模型 CPM-Distill
  • 本次项目就简单介绍一下知识蒸馏技术并且通过 PaddleNLP 套件加载 CPM-Distill 模型实现文本生成

相关项目

  • Paddle2.0:构建一个经典的文本生成模型GPT-2
  • 文本生成:使用GPT-2加载CPM-LM模型实现简单的问答机器人
  • 文本生成:让AI帮你写文章吧
  • 【AI创造营】PaddleHub 配合 PaddleNLP 实现简单的文本生成

相关资料

  • 论文:
    • CPM: A Large-scale Generative Chinese Pre-trained Language Model
    • Distilling the Knowledge in a Neural Network
  • 官方实现:TsinghuaAI/CPM-Distill

模型压缩技术

知识蒸馏(Knowledge Distillation)

  • 知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法。

  • 由 Hinton 在 2015 年 Distilling the Knowledge in a Neural Network 的论文首次提出了知识蒸馏的并尝试在 CV 领域中使用,旨在把大模型学到的知识灌输到小模型中,以达到缩小模型的目标,示意图如下:

  • 说人话就是指用一个简单模型去拟合复杂模型的输出,这个输出也叫做“软标签”,当然也可以加入真实数据作为“硬标签”一同训练。
  • 使用知识蒸馏技术相比直接从头训练的效果一般会更好一些,因为教师模型能够指导学生模型收敛到一个更佳的位置。

  • 知识蒸馏技术除了可以用来将网络从大网络转化成一个小网络,并保留接近于大网络的性能;
  • 也可以将多个网络的学到的知识转移到一个网络中,使得单个网络的性能接近 emsemble 的结果。

蒸馏模型信息

  • 教师模型为 GPT-2 Large,具体的模型参数如下:
teacher_model = GPTModel(
    vocab_size=30000,
    hidden_size=2560,
    num_hidden_layers=32,
    num_attention_heads=32,
    intermediate_size=10240,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)
  • 学生模型为 GPT-2 Small,具体的模型参数如下:
teacher_model = GPTModel(
    vocab_size=30000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)

蒸馏 loss

  • 将大模型和小模型每个位置上输出之间的 KL 散度作为蒸馏 loss,同时加上原来的 language model loss。总 loss 如下:

其中 LlmLlm 为 GPT-2 原始的 language modeling loss。

安装依赖

In [ ]
!pip install paddlenlp==2.0.1 sentencepiece==0.1.92

加载模型

In [1]
import paddlefrom paddlenlp.transformers import GPTModel, GPTForPretraining, GPTChineseTokenizer# tokenizer 与 CPM-LM 模型一致tokenizer = GPTChineseTokenizer.from_pretrained('gpt-cpm-large-cn')# 实例化 GPT2-small 模型gpt = GPTModel(
    vocab_size=30000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=1024,
    type_vocab_size=1,
    initializer_range=0.02,
    pad_token_id=0,
    topo=None)# 加载预训练模型参数params = paddle.load('data/data92160/gpt-cpm-small-cn-distill.pdparams')# 设置参数gpt.set_dict(params)# 使用 GPTForPretraining 向模型中添加输出层model = GPTForPretraining(gpt)# 将模型设置为评估模式model.eval()
[2025-05-28 19:38:04,469] [    INFO] - Found /home/aistudio/.paddlenlp/models/gpt-cpm-large-cn/gpt-cpm-cn-sentencepiece.model

模型解码

In [40]
import paddleimport numpy as np# Greedy Searchdef greedy_search(text, max_len=32, end_word=None):
    # # 终止标志
    if end_word is not None:
        stop_id = tokenizer.encode(end_word)['input_ids']
        length = len(stop_id)    else:
        stop_id = [tokenizer.eod_token_id]
        length = len(stop_id)    
    # 初始预测
    ids = tokenizer.encode(text)['input_ids']
    input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
    output, cached_kvs = model(input_id, use_cache=True)
    next_token = int(np.argmax(output[0, -1].numpy()))
    ids.append(next_token)    # 使用缓存进行继续预测
    for i in range(max_len-1):
        input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
        output, cached_kvs = model(input_id, use_cache=True, cache=cached_kvs)
        next_token = int(np.argmax(output[0, -1].numpy()))
        ids.append(next_token)        # 根据终止标志停止预测
        if ids[-length:]==stop_id:            if end_word is None:
               ids = ids[:-1]            break
    
    return tokenizer.convert_ids_to_string(ids)
In [39]
import paddleimport numpy as np# top_k and top_p filteringdef top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.shape[-1])  # Safety check
    logits_np = logits.numpy()    if top_k > 0:        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits_np < np.sort(logits_np)[-top_k]
        logits_np[indices_to_remove] = filter_value    if top_p < 1.0:
        sorted_logits = paddle.sort(logits, descending=True)
        sorted_indices = paddle.argsort(logits, descending=True).numpy()
        cumulative_probs = paddle.cumsum(paddle.nn.functional.softmax(sorted_logits, axis=-1), axis=-1).numpy()        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits_np[indices_to_remove] = filter_value    return paddle.to_tensor(logits_np)# Nucleus Sampledef nucleus_sample(text, max_len=32, end_word=None, repitition_penalty=1.0, temperature=1.0, top_k=0, top_p=1.0):
    # 终止标志
    if end_word is not None:
        stop_id = tokenizer.encode(end_word)['input_ids']
        length = len(stop_id)    else:
        stop_id = [tokenizer.eod_token_id]
        length = len(stop_id)    # 初始预测
    ids = tokenizer.encode(text)['input_ids']
    input_id = paddle.to_tensor(np.array(ids).reshape(1, -1).astype('int64'))
    output, cached_kvs = model(input_id, use_cache=True)
    next_token_logits = output[0, -1, :]    for id in set(ids):
        next_token_logits[id] /= repitition_penalty
    next_token_logits = next_token_logits / temperature
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
    ids += [int(next_token)]    # 使用缓存进行继续预测
    for i in range(max_len-1):
        input_id = paddle.to_tensor(np.array([next_token]).reshape(1, -1).astype('int64'))
        output, cached_kvs = model(input_id, use_cache=True, cache=cached_kvs)
        next_token_logits = output[0, -1, :]        for id in set(ids):
            next_token_logits[id] /= repitition_penalty
        next_token_logits = next_token_logits / temperature
        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
        next_token = paddle.multinomial(paddle.nn.functional.softmax(filtered_logits, axis=-1), num_samples=1).numpy()
        ids += [int(next_token)]        # 根据终止标志停止预测
        if ids[-length:]==stop_id:            if end_word is None:
               ids = ids[:-1]            break

    return tokenizer.convert_ids_to_string(ids)

文本生成

In [41]
# 输入文本inputs = input('请输入文本:')print(inputs)# 使用 Nucleus Sample 进行文本生成outputs = greedy_search(
    inputs, # 输入文本
    max_len=128, # 最大生成文本的长度
    end_word=None)# 打印输出print(outputs)
请输入文本:请在此处输入你的姓名
请在此处输入你的姓名,然后点击“确定”,就可以开始游戏了。
游戏目标:在限定时间内,成功地把所有的牌都通通打完。
In [43]
# 输入文本inputs = input('请输入文本:')print(inputs)for x in range(5):    # 使用 Nucleus Sample 进行文本生成
    outputs = nucleus_sample(
        inputs, # 输入文本
        max_len=128, # 最大生成文本的长度
        end_word='。', # 终止符号
        repitition_penalty=1.0, # 重复度抑制
        temperature=1.0, # 温度
        top_k=3000, # 取前k个最大输出再进行采样
        top_p=0.9 # 抑制概率低于top_p的输出再进行采样
    )    # 打印输出
    print(outputs)
请输入文本:请在此处输入你的姓名
请在此处输入你的姓名、学校、专业及学科,并在社交媒体上公布你的个人简介。
请在此处输入你的姓名或者电话,对方会及时通知你。
请在此处输入你的姓名、民族及籍贯信息,当您找到 CADULI 的联系方式后,我们会按您所选择的申请中心,以电子邮件的形式向您发送邮件。
请在此处输入你的姓名和电话号码,由资深*接待员进行介绍,因为此处有不少中国的大老板,英文能看。
请在此处输入你的姓名、联系电话、银行卡号和手机号。


# 请输入  # 并在  # 帮你  # 不太  # 多个  # 首次  # 都是  # 是一种  # 加载  # git  # 请在  # gpt  # nlp  # bert  # red  # 压缩技术  # ai  # cad 


相关栏目: 【 Google疑问12 】 【 Facebook疑问10 】 【 网络优化91478 】 【 技术知识72672 】 【 云计算0 】 【 GEO优化84317 】 【 优选文章0 】 【 营销推广36048 】 【 网络运营41350 】 【 案例网站102563 】 【 AI智能45237


相关推荐: 轻松生成二维码:免费AI工具终极指南  5分钟教你用AI将你的研究数据生成可视化的图表和摘要  百度输入法ai面板怎么关 百度输入法ai面板隐藏技巧  Gemini 辅助进行多平台社交媒体内容调度  Midjourney怎样用参数调色彩饱和度_Midjourney饱和度调整【方法】  AI赋能营销:角色、策略与工具选择全指南  怎么用ai做证件照换底色 AI一键抠图与背景色替换【方法】  AI面试作弊与反作弊:求职者与企业的博弈  Tamilnad Mercantile Bank TMB:如何在线下载账户报表  如何用文心一言写简历 快速生成高含金量求职简历方法  稿定设计AI抠图怎么修复瑕疵_稿定设计AI瑕疵修复与手动微调【步骤】  唐库AI拆书工具怎样设置拆书深度_唐库AI拆书工具深度调节与内容详略控制【技巧】  宝可梦朱紫:如何高效刷闪异色宝可梦,提升游戏体验  OpenArt:终极AI内容创作平台,图像、视频和角色一致性  Tune AI: 革新音乐创作,AI音乐平台深度测评  Comet浏览器:使用ChatGPT增强您的搜索体验  文本分类与聚类:网络安全中的自然语言处理应用  5分钟教你用AI给黑白老照片上色,让回忆变得鲜活  谷歌 Gemini AI 助手详解:功能、应用与隐私设置  Higgsfield WAN 2.5:AI视频生成工具新纪元  Mac百度输入法ai怎么关 Mac版百度ai助手禁用教程  百度ai助手任务栏怎么关 百度ai助手任务栏图标隐藏  ChatGPT一键生成PPT怎么加目录_ChatGPTPPT目录添加【步骤】  Speerise亮面体操服测评:舒适与时尚的完美结合  如何用AI生成正则表达式?再也不怕复杂的文本匹配  Descript音频编辑终极指南:技巧、AI工具与专业效果  Feelin网页版在线入口 Feelin官方网站导航  去哪旅行ai抢票助手怎样添加备选车次_去哪旅行ai抢票助手备选车次设置与切换【攻略】  斑马AI怎样设置专注模式_斑马AI专注时段与干扰屏蔽【指南】  壹伴AI智能排版如何自动生成文章配图_壹伴AI智能排版配图生成与版权说明【教程】  Shopify着陆页:用AI工具快速提升营销效果  京东旅行AI能否抢返程票_京东AI返程票预约与自动抢购【技巧】  AI写作鱼怎么一键生成论文大纲_AI写作鱼大纲生成与逻辑梳理【技巧】  探索占星术:揭秘 कुंडली 中的 शुक्र,财富与运势的钥匙  在线歌曲歌词生成器:创意歌词轻松创作指南  唐库AI拆书工具怎么查看拆书进度_唐库AI拆书工具进度查看与异常排查【方法】  AI虚拟女友:终极浪漫伴侣还是数字陷阱?  ChatGPT图像生成器完全指南:文化影响、伦理挑战与商业变革  律师视角下的生成式AI:信息爆炸时代的法律实践与未来展望  汽车“以旧换新”补贴升级:2026年置换最高补1.5万元  创客贴AI排版如何批量处理图文_创客贴AI排版批量操作与效率提升【方法】  探索泰勒·斯威夫特《August》的深层含义:歌词解析与情感分析  AI任务管理器终极评测:找到最适合你的效率神器  批改网AI检测工具怎样设置检测维度_批改网AI检测工具维度勾选与权重调整【技巧】  历史影像解密:唇语专家如何还原一战士兵对话?  歌曲分析:The Killers乐队的《Mr. Brightside》歌词深度解析  如何用AI帮你创作节日贺卡文案?让祝福与众不同  AI电子书写作终极指南:ChatGPT和Canva实战教程  ChatGPT一键生成PPT怎么加动画_ChatGPTPPT动画添加【指南】  AI社交媒体自动化:n8n与HeyGen打造个性化内容引擎 

 2025-07-18

了解您产品搜索量及市场趋势,制定营销计划

同行竞争及网站分析保障您的广告效果

点击免费数据支持

提交您的需求,1小时内享受我们的专业解答。

南京市珐之弘网络技术有限公司


南京市珐之弘网络技术有限公司

南京市珐之弘网络技术有限公司专注海外推广十年,是谷歌推广.Facebook广告全球合作伙伴,我们精英化的技术团队为企业提供谷歌海外推广+外贸网站建设+网站维护运营+Google SEO优化+社交营销为您提供一站式海外营销服务。

 87067657

 13565296790

 87067657@qq.com

Notice

We and selected third parties use cookies or similar technologies for technical purposes and, with your consent, for other purposes as specified in the cookie policy.
You can consent to the use of such technologies by closing this notice, by interacting with any link or button outside of this notice or by continuing to browse otherwise.