A Deep Dive into How GLM Works

Let's first get familiar with the GLM code, covering data loading, the model structure, and the decoding procedure. For the theory, see the GLM paper walkthrough in my earlier post.

Background

Official repository: the GLM project on GitHub (THUDM/GLM). First of all, GLM is an autoregressive model. Start by downloading the pretrained weights; ten are provided officially (not counting the 100B-scale model):

  • GLM-Base 110M
  • GLM-Large 335M
  • GLM-Large-Chinese 335M
  • GLM-Doc 335M
  • GLM-410M
  • GLM-515M
  • GLM-RoBERTa 335M
  • GLM-2B
  • GLM-10B
  • GLM-10B-Chinese

A note on the environment (Docker setup) will be added later (postponed for now).

Usage

You can try it out quickly on Hugging Face; a Hugging Face usage tutorial will be added later (postponed for now).

Generation

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("BAAI/glm-10b", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("BAAI/glm-10b", trust_remote_code=True)
model = model.half().cuda()
model.eval()

# Inference
inputs = tokenizer("Ng is an adjunct professor at [MASK] (formerly associate professor and Director of its Stanford AI Lab or SAIL ). Also a pioneer in online education, Ng co-founded Coursera and deeplearning.ai.", return_tensors="pt")
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)
inputs = inputs.to('cuda')
outputs = model.generate(**inputs, max_length=512, eos_token_id=tokenizer.eop_token_id)
print(tokenizer.decode(outputs[0].tolist()))

# Training
inputs = tokenizer(
    ["Tsinghua University is located in [MASK].", "One minus one equals zero, is it correct? Answer: [MASK]"],
    return_tensors="pt", padding=True)
inputs = tokenizer.build_inputs_for_generation(inputs, targets=["Beijing", "No"], max_gen_length=8)
inputs = inputs.to('cuda')
outputs = model(**inputs)
loss = outputs.loss
logits = outputs.logits
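
The returned loss can be used directly for fine-tuning. A minimal sketch of one optimizer step (the optimizer and learning rate here are my own assumptions, not from the official scripts; serious fine-tuning should go through the DeepSpeed scripts described below rather than updating fp16 weights directly):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # assumed optimizer and lr
model.train()
outputs = model(**inputs)        # same inputs as built above
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()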

Classification

from transformers import AutoTokenizer, AutoModelForMultipleChoice
tokenizer = AutoTokenizer.from_pretrained("BAAI/glm-10b", trust_remote_code=True)
model = AutoModelForMultipleChoice.from_pretrained("BAAI/glm-10b", trust_remote_code=True)
model = model.half().cuda()
model.eval()

inputs = tokenizer(["Tsinghua University is located in [MASK].",
                    "One minus one equals zero, is it correct? Answer: [MASK]"], return_tensors="pt", padding=True)
choices = [["Beijing", "Shanghai"], ["Yes", "No"]]
inputs = tokenizer.build_inputs_for_multiple_choice(inputs, choices)
inputs = inputs.to('cuda')
outputs = model(**inputs)
logits = outputs.logits
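
To turn the logits into predictions, take the highest-scoring choice per example. A minimal sketch, assuming the returned logits hold one score per candidate answer (shape (batch_size, num_choices)):

# Pick the highest-scoring choice for each example.
predictions = logits.argmax(dim=-1)
for option_list, idx in zip(choices, predictions.tolist()):
    print(option_list[idx])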

The official repository provides scripts for fine-tuning on downstream tasks.

Left-to-Right Generation / Blank Infilling (Interactive)

Change CHECKPOINT_PATH to point to your downloaded checkpoint.

bash scripts/generate_block.sh \
     config_tasks/model_blocklm_10B_chinese.sh

The model uses different mask tokens for different purposes: [MASK] for short blank infilling, [sMASK] for sentence infilling, and [gMASK] for left-to-right generation. The answer to each blank starts with <|startofpiece|>.
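
For example, prompts for the three modes might look like the following (hypothetical inputs, just to illustrate where each mask token goes):

prompts = [
    "Tsinghua University is located in [MASK].",         # short blank infilling
    "There are two reasons. [sMASK] So we went ahead.",  # sentence infilling
    "Once upon a time, [gMASK]",                          # left-to-right generation
]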

SuperGLUE

Modify DATA_ROOT, CHECKPOINT_PATH, and SAVE_PATH, and adjust the batch size and nproc_per_node according to your hardware. Taking the COPA dataset as an example:

bash scripts/ds_finetune_superglue.sh \
     config_tasks/model_blocklm_10B.sh \
     config_tasks/task_copa.sh
  • P-tuning
bash scripts/ds_finetune_superglue_prompt.sh \
     config_tasks/model_blocklm_10B.sh \
     config_tasks/task_copa.sh
  • To apply GLM to a new NLU dataset (cloze-style fine-tuning), implement a DataProcessor in tasks/superglue/dataset.py to load the data and add a PVP in tasks/superglue/pvp.py to define the cloze question

Seq2Seq Tasks

Modify DATA_ROOT, CHECKPOINT_PATH, and SAVE_PATH. Taking the CNN/DailyMail dataset as an example:

bash scripts/ds_finetune_seq2seq.sh \ 
   config_tasks/model_blocklm_10B.sh \ 
   config_tasks/seq_cnndm_org.sh

The generated summaries are written to ./runs/experiment_name/test.jsonl.hyps and the references to test.jsonl.refs. To compute ROUGE, install file2rouge and download Stanford CoreNLP.

bash scripts/evaluate_seq2seq.sh \
 ./runs/experiment_name/test.jsonl.hyps ./runs/experiment_name/test.jsonl.refs

To train on your own data, put it in {split}.source and {split}.target, one example per line, with splits named train, val, and test.
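
For instance, a small script like the following (with made-up document/summary pairs and a hypothetical output directory) produces the expected layout:

import os

# Hypothetical document/summary pairs; replace with your own data.
pairs = {
    "train": [("first document ...", "first summary"),
              ("second document ...", "second summary")],
    "val":   [("third document ...", "third summary")],
    "test":  [("fourth document ...", "fourth summary")],
}
os.makedirs("customization_data", exist_ok=True)   # hypothetical directory
for split, examples in pairs.items():
    with open(f"customization_data/{split}.source", "w") as src, \
         open(f"customization_data/{split}.target", "w") as tgt:
        for source, target in examples:
            # One example per line in both files, aligned by line number.
            src.write(source.replace("\n", " ") + "\n")
            tgt.write(target.replace("\n", " ") + "\n")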

bash scripts/ds_finetune_seq2seq.sh \ 
   config_tasks/model_blocklm_10B.sh \ 
   config_tasks/seq_customization.sh

Check the hyperparameters in config_tasks/seq_customization.sh and config_tasks/config_blocklm_10B_cnndm.json.

Multiple Choice (Zero-Shot)

Modify CHECKPOINT_PATH and DATA_PATH.

bash scripts/evaluate_multichoice.sh config_tasks/model_blocklm_10B.sh

The data file should be formatted as follows:

{"inputs_pretokenized": "Context and question here", "choices_pretokenized": ["Choice 1", "Choice 2", "Choice 3"], "label": int}

Language Modeling

  • LAMBADA cloze accuracy

Download the LAMBADA data and modify DATA_ROOT and CHECKPOINT_PATH in scripts/evaluate_lm.sh.

bash scripts/evaluate_lm.sh \ 
     config_tasks/model_blocklm_large_generation.sh \
     config_tasks/zero_lambada.sh 
  • LM Perplexity

Download the test set of WikiBook or the WikiText-103 dataset, and modify DATA_ROOT and CHECKPOINT_PATH in scripts/evaluate_lm.sh.

bash scripts/evaluate_lm.sh \ 
   config_tasks/model_blocklm_large_generation.sh \
   config_tasks/zero_wikitext.sh 

Text Infilling

  • Download the Yahoo dataset and modify DATA_ROOT, CHECKPOINT_PATH, and SAVE_PATH
bash scripts/finetune_blank.sh \ 
     config_tasks/model_blocklm_large.sh \ 
     config_tasks/seq_blank.sh

Pretraining

bash scripts/ds_pretrain_nvidia.sh config/ds_block_large.sh
  • scripts/ds_pretrain_nvidia.sh

Modify NUM_WORKERS and NUM_GPUS_PER_WORKER, as well as HOST_FILE_PATH (an OpenMPI-style hostfile). A post on using DeepSpeed will follow later (postponed for now).
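
For reference, an OpenMPI-style hostfile lists one worker per line together with its GPU slot count, e.g. (placeholder hostnames):

node-0 slots=8
node-1 slots=8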

  • config/ds_block_large.sh

This defines the hyperparameters. --train-data can take multiple keywords defined in NAMED_CORPORA (data_utils/corpora.py). The optimizer hyperparameters live in the corresponding JSON config file.
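
Conceptually, registering a new corpus means adding an entry to NAMED_CORPORA that points at your data. A rough sketch of the idea only (the actual class interface in data_utils/corpora.py may differ):

# data_utils/corpora.py -- conceptual sketch, not the literal repo code.
class MyCorpus:
    PATH = "/data/my_corpus"        # hypothetical location of the raw text

NAMED_CORPORA = {
    "my-corpus": MyCorpus,          # then pass: --train-data my-corpus
}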

Data Loading (Model Structure, Decoding)

Understanding from the Source Code

From the angle of the official simple blank-infilling generation demo (this part is mainly about inference and does not cover training)

This example does not use DeepSpeed; it is launched with plain DDP. There are two shell scripts to look at first:

  • generate_block.sh sets the task-related parameters:
    • which GPUs to use
    • model checkpoint path
    • MPSIZE
    • MAXSEQLEN
    • MASTER_PORT
    • TEMP
    • TOPK
    • TOPP
    • config_json
  • model_blocklm_10B_chinese.sh sets the model-related parameters

    Basically the number of layers, the tokenization settings, the load path, and so on.

  • The core script is generate_samples.py; the scripts above show how the 10B Chinese model gets loaded and launched.

This is an interactive demo; put simply, there is not much application-level wrapping around it. Let's look at the source. In such a minimal example, the imports stay close to plain, native PyTorch.

Besides the main function, there are six helper functions:

  • def setup_model(args) — standard setup, along the lines of the official PyTorch tutorials
  • def get_batch(context_tokens, device, args): — standard batch preparation, along the lines of the official PyTorch tutorials
  • def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
  • def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None):
  • def read_context(tokenizer, args, output):
  • def generate_samples(model, tokenizer, args, device):

Let's start with the main function:

def main():
    """Main training program."""

    print('Generate Samples')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Arguments.
    args = get_args()
    args.mem_length = args.seq_length + args.mem_length - 1

    # PyTorch distributed initialization.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Get the tokenizer.
    tokenizer = prepare_tokenizer(args)

    # Model, optimizer, and learning rate.
    model = setup_model(args)

    # Set the default batch size to 1 (training would normally use a much larger one).
    args.batch_size = 1

    # Generate samples.
    generate_samples(model, tokenizer, args, torch.cuda.current_device())

The comments in main() explain each step quite clearly.

The key part to look at is the sampling function; the rest is much the same across models.

The Sampling Function (generate_samples)

def generate_samples(model, tokenizer, args, device):
    model.eval()
    output_path = "./samples"
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    with torch.no_grad(), open(output_path, "w") as output:
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())

            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
            if terminate_runs == 1:
                return
            start_time = time.time()
            if args.block_lm:
                mems = []
                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)
                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
                end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
                mask_positions = []
                for token in mask_tokens:
                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
                mask_positions.sort()
                if args.no_block_position:
                    for mask_position in mask_positions:
                        position_ids[0, mask_position + 1:] += args.out_seq_length
                _, *mems = model(tokens, position_ids, attention_mask, *mems)
                for mask_position in mask_positions:
                    if args.no_block_position:
                        position = position_ids[0, mask_position].item()
                    else:
                        position = mask_position
                    tokens, mems = sample_sequence(model, tokenizer, tokens, position,
                                                   args, device, mems=mems, end_tokens=end_tokens)
            else:
                tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device)
            output_tokens_list = tokens.view(-1).contiguous()
            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens
                print("\nGLM:", trim_decode_tokens, flush=True)
                output.write(trim_decode_tokens + "\n")

            torch.distributed.barrier(group=mpu.get_model_parallel_group())

Explanation: the core logic sits after the time.time() call; everything before it is ordinary inference setup. mems is introduced to cache the attention states of previous steps and avoid recomputation. get_batch() produces tokens, attention_mask, and position_ids (much as in the official PyTorch tutorials). The mask tokens and end tokens are then set up, mask_positions are computed and sorted, and sample_sequence is called once per mask position to generate the tokens.
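
The mask-position step on its own is easy to see with toy token ids (50003 below is just a stand-in for the real tokenizer.get_command('MASK').Id):

import torch

MASK_ID = 50003   # stand-in for the real [MASK] command id
context_tokens_tensor = torch.tensor([101, 17, 42, MASK_ID, 9, MASK_ID, 2])
mask_positions = (context_tokens_tensor == MASK_ID).nonzero(as_tuple=True)[0].tolist()
mask_positions.sort()
print(mask_positions)   # [3, 5] -> sample_sequence is then called once per blank, in order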

The Context-Reading Function


def read_context(tokenizer, args, output):
    terminate_runs, skip_run = 0, 0
    if mpu.get_model_parallel_rank() == 0:
        while True:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break
            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
            if args.block_lm and 'MASK]' not in raw_text:
                raw_text += ' ' + generation_mask
            output.write(raw_text)
            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            if args.block_lm:
                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
                if not raw_text.endswith('[gMASK]'):
                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
            context_length = len(context_tokens)

            if context_length >= args.seq_length:
                print("\nContext length", context_length,
                      "\nPlease give smaller context than the window length!")
                continue
            break
    else:
        context_length = 0

    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()

    if terminate_runs == 1:
        return terminate_runs, None, None, None

    context_length_tensor = torch.cuda.LongTensor([context_length])

    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()
    if mpu.get_model_parallel_rank() == 0:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    if mpu.get_model_parallel_rank() != 0:
        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
    return terminate_runs, raw_text, context_tokens_tensor, context_length
  • Inputs: the tokenizer, the args, and the output file
  • Outputs:
    • terminate_runs
      1 means stop
    • raw_text — the final decoded input text
    • context_tokens_tensor — the encoded input tokens
    • context_length — the length of the encoded token sequence

It mainly appends a mask token to the raw input when none is present, and adds a few guard checks (empty prompt, the "stop" command, context longer than the window).

  • Encoding: tokenizer.EncodeAsIds(raw_text).tokenization

  • Decoding: tokenizer.DecodeIds(context_tokens_tensor.tolist())

Sequence Sampling

def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None):
    if not args.block_lm:
        context_tokens, attention_mask, position_ids = get_batch(context_tokens, device, args)
        tokens = torch.empty((args.num_beams, 0), device=context_tokens.device, dtype=torch.long)
    else:
        tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
    counter = 0
    if mems is None:
        mems = []
    if end_tokens is None:
        end_tokens = [args.eod_token]
    if args.num_beams > 1:
        beam_scorer = BeamSearchScorer(
            batch_size=1,
            max_length=args.out_seq_length,
            num_beams=args.num_beams,
            device=context_tokens.device,
            length_penalty=args.length_penalty,
            do_early_stopping=False,
        )
        beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device)
    last_beam_num = 1
    while counter < args.out_seq_length:
        if counter == 0 and not args.block_lm:
            next_token_logits, *mems = model(context_tokens, position_ids, attention_mask, *mems)
        else:
            if args.block_lm:
                if args.no_block_position:
                    position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
                else:
                    position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
                    position_ids[:, 0] = context_length
                    position_ids[:, 1] = counter + 1
                attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
            else:
                position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
                attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
                                                         device=context_tokens.device, dtype=torch.float)
            last_token = tokens[:, -1:]
            next_token_logits, *mems = model(last_token, position_ids, attention_mask, *mems)
        next_token_logits = next_token_logits[:, -1]
        if args.num_beams > 1:
            next_token_scores = F.log_softmax(next_token_logits, dim=-1)
            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)

            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)

            next_indices = next_tokens // vocab_size
            next_tokens = next_tokens % vocab_size
            # stateless
            tokens = tokens.expand((args.num_beams, -1))
            beam_outputs = beam_scorer.process(
                tokens,
                next_token_scores,
                next_tokens,
                next_indices,
                eos_token_id=end_tokens,
                mems=mems
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]
            beam_next_tokens = beam_next_tokens.unsqueeze(-1)
            tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
            mems = [mem[beam_idx] for mem in mems] if mems else None
            if beam_scorer.is_done:
                break
            last_beam_num = args.num_beams
        else:
            next_token_logits /= args.temperature
            next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
            log_probs = F.softmax(next_token_logits, dim=-1)
            prev = torch.multinomial(log_probs, num_samples=1)[0]
            is_end = prev.item() in end_tokens
            if is_end:
                break
            prev = prev.view(1, 1)
            tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
        counter += 1
        if not args.block_lm and mpu.get_model_parallel_rank() == 0 and counter % 16 == 0:
            output_tokens_list = tokens.view(-1).contiguous()
            decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
            if mpu.get_model_parallel_rank() == 0 and (counter % 128 == 0 or is_end):
                os.system('clear')
                trim_decode_tokens = decode_tokens
                print(trim_decode_tokens, flush=True)
    if args.num_beams > 1:
        tokens, mems, _ = beam_scorer.finalize(tokens, beam_scores, next_tokens, next_indices, eos_token_id=args.eod_token,
                                            mems=mems)
    return torch.cat((context_tokens, tokens), dim=1), mems

Explanation

  • The core computation produces next_token_logits
  • next_token_logits, *mems = model(context_tokens or last_token, position_ids, attention_mask, *mems)
  • position_ids and attention_mask depend on last_beam_num from the beam search (see the sketch after this list for the 2D position ids used in the block-LM case)
  • When num_beams == 1 (no beam search), the logits are first divided by the temperature and then filtered with top_k_logits before sampling
  • With beam search, the scores go through log_softmax, are combined with the running beam scores, and are fed to beam_scorer.process, which returns beam_outputs
  • tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
  • The return value is the context tokens concatenated with the generated tokens, plus the updated mems
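
The 2D position ids used in the block-LM branch can be illustrated in isolation (a toy sketch of the position_ids[:, 0] / position_ids[:, 1] assignments above; the concrete numbers are made up):

import torch

# Channel 0 stays at the position of the blank being filled (passed in via the
# context_length argument), channel 1 counts the offset inside the generated span.
last_beam_num, blank_position, counter = 1, 7, 3
position_ids = torch.ones(last_beam_num, 2, 1, dtype=torch.long)
position_ids[:, 0] = blank_position
position_ids[:, 1] = counter + 1
print(position_ids.squeeze(-1))   # tensor([[7, 4]])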

top_k_logits

def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # convert to 1D
        logits = logits.view(logits.size()[1]).contiguous()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
        # going back to 2D
        logits = logits.view(1, -1).contiguous()

    return logits

Explanation:

  • top_k: keep only the k highest-scoring tokens; every logit below the k-th largest is set to -Inf
  • top_p: flatten the 2D logits to 1D, drop the tokens whose cumulative probability exceeds the threshold (shifted by one position so the first token crossing it is kept), then reshape back to 2D (see the toy run below)
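
A toy run of top_k_logits, as defined above, on a six-token vocabulary makes the filtering concrete:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -3.0]])   # batch of 1, toy vocab of 6

# top_k=3 keeps only the three largest logits; the rest become -inf.
print(top_k_logits(logits.clone(), top_k=3))
# -> roughly tensor([[2.0, 1.0, 0.5, -inf, -inf, -inf]])

# top_p=0.9 keeps the smallest prefix of sorted tokens whose cumulative
# probability stays within 0.9 (plus the first token that crosses the line).
filtered = top_k_logits(logits.clone(), top_p=0.9)
probs = F.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)   # sample from the surviving tokens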

Next, from the angle of four projects I have already done (write-up in progress, postponed for now):

  • GLM-poem (classical poetry generation, iPrompt)
  • LongText-GLM (long-text training)
  • GLM-lycr (modern poetry prompts)
  • GLM-SEO (seq2seq fine-tuning for SEO)