from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

@dataclass
class Template:
    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"
    default_system: str
    stop_words: List[str]
    efficient_eos: bool
    replace_eos: bool
    force_system: bool

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids += query_ids + resp_ids
        prompt_ids = prompt_ids + encoded_pairs[-1][0]
        answer_ids = encoded_pairs[-1][1]
        return prompt_ids, answer_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query        resp
        Turn t: sep + query           resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []
            if i == 0 and (system or tools or self.force_system):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                elements += self.format_system.apply(content=(system + tool_text))
            elif i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)

This code defines a dataclass named Template, together with several methods that encode chat messages into sequences of token ids. Let's go through it section by section:

Dataclass definition

python

from dataclasses import dataclass
from typing import List, Dict, Optional, Sequence, Tuple

@dataclass
class Template:
  • The Template class is declared with the @dataclass decorator; dataclass automatically generates __init__ and other special methods from the annotated fields.

Class attributes

python

    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"
    default_system: str
    stop_words: List[str]
    efficient_eos: bool
    replace_eos: bool
    force_system: bool
  • Declares the class attributes: one Formatter per message role (user, assistant, system, function, observation), plus formatters for the tool description and the turn separator, a default system prompt, a stop-word list, and boolean flags that control encoding behavior. A minimal sketch of what a Formatter might look like follows below.
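The Formatter type is only referenced by name in this snippet. As a mental model, here is a minimal sketch (an assumption, not the project's actual implementation) in which a formatter holds a list of string slots and apply() substitutes keyword arguments into {{placeholder}} markers:

python

from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class Formatter:
    """Minimal sketch: render each slot, substituting {{name}} placeholders."""
    slots: List[str] = field(default_factory=list)

    def apply(self, **kwargs: Any) -> List[str]:
        elements = []
        for slot in self.slots:
            # Replace every {{name}} placeholder with the supplied value.
            for name, value in kwargs.items():
                slot = slot.replace("{{" + name + "}}", str(value))
            elements.append(slot)
        return elements

Under this model, the StringFormatter and EmptyFormatter used in the registration example at the end are just specializations: one substitutes content into its slots, the other renders fixed text such as separators.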

The encode_oneturn method

python

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
  • Defines encode_oneturn, which flattens the conversation into a single (prompt, answer) pair of token-id lists.

python

        encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
  • Calls the private _encode method to do the actual encoding; it returns one (query_ids, resp_ids) pair per turn.

python

        prompt_ids = []
        for query_ids, resp_ids in encoded_pairs[:-1]:
            prompt_ids += query_ids + resp_ids
        prompt_ids = prompt_ids + encoded_pairs[-1][0]
        answer_ids = encoded_pairs[-1][1]
        return prompt_ids, answer_ids
  • Concatenates the query and response ids of every turn except the last, appends the last turn's query to form prompt_ids, and returns the last turn's response ids as answer_ids. A toy trace of this flattening follows below.
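To make the flattening concrete, here is a toy trace (with made-up token ids, not real tokenizer output):

python

# Suppose _encode returned three (query_ids, resp_ids) pairs:
encoded_pairs = [
    ([1, 2], [3]),      # turn 0: system + query, response
    ([4, 5], [6]),      # turn 1: sep + query, response
    ([7, 8], [9, 10]),  # turn 2: sep + query, response
]

# encode_oneturn flattens all pairs except the last, then appends the
# last query; only the last response is kept as the answer:
prompt_ids = [1, 2, 3, 4, 5, 6] + [7, 8]  # history + final query
answer_ids = [9, 10]                      # final response only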

The encode_multiturn method

python

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        cutoff_len: int = 1_000_000,
        reserved_label_len: int = 1,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
  • Defines encode_multiturn, which encodes a multi-turn conversation.

python

        return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
  • Simply delegates to _encode and returns its result: one (query, response) pair per turn, suitable for multi-turn training.

The _encode private method

python

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: List[Dict[str, str]],
        system: str,
        tools: str,
        cutoff_len: int,
        reserved_label_len: int,
    ) -> Sequence[Tuple[List[int], List[int]]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: system + query        resp
        Turn t: sep + query           resp
        """
  • Defines the private _encode method, which performs the actual encoding. As the docstring notes, turn 0 is prefixed with the system block, and every subsequent turn is prefixed with the separator.

python

        system = system or self.default_system
  • Falls back to the class attribute default_system when the system argument is empty.

python

        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []
            if i == 0 and (system or tools or self.force_system):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                elements += self.format_system.apply(content=(system + tool_text))
            elif i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()
  • Iterates over the messages, assembling a list of formatted elements for each one.
  • For the first message, if there is a system prompt, tools, or force_system is set, format_system is applied, with any tool text produced by format_tools appended to the system prompt; for every later even-indexed message (the user message that starts each new turn), format_separator is prepended. A short trace of which branch fires per message follows below.
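For a two-turn conversation, the branches fire like this (a trace with hypothetical messages):

python

# messages = [user0, assistant0, user1, assistant1]
# i == 0: the system block is prepended (system prompt, plus tool text if any)
# i == 1: assistant0 gets no prefix
# i == 2: user1 is preceded by format_separator (start of turn 1)
# i == 3: assistant1 gets no prefix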

The role dispatch in _encode's message loop

python

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))
  • Applies a different formatter depending on the message role (the Role enum is sketched below):
    • Role.USER → format_user, which also receives the zero-based turn index (i // 2)
    • Role.ASSISTANT → format_assistant
    • Role.OBSERVATION → format_observation
    • Role.FUNCTION → format_function
    • Any other role raises a NotImplementedError
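The Role enum is not shown in this snippet; a minimal definition consistent with the comparisons above (an assumption about the surrounding project code) would be:

python

from enum import Enum, unique

@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"

Inheriting from str keeps comparisons like message["role"] == Role.USER.value cheap while the message dicts themselves store plain strings.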

python

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
  • Converts the formatted elements into token ids and appends them to encoded_messages.

python

        return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
  • Calls _make_pairs to group the encoded messages into (query, response) token-id pairs and returns them.

Other private methods (hypothetical)

Below are hypothetical sketches of the two private helpers these methods rely on; the real implementations may differ, but these capture the intent:

python

    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: List[str]) -> List[int]:
        """
        Converts a list of formatted elements into token ids using the tokenizer.
        """
        token_ids = []
        for element in elements:
            token_ids.extend(tokenizer.encode(element, add_special_tokens=False))
        return token_ids

    def _make_pairs(
        self, encoded_messages: List[List[int]], cutoff_len: int, reserved_label_len: int
    ) -> Sequence[Tuple[List[int], List[int]]]:
        """
        Groups encoded messages into (query, response) pairs, truncating each
        pair to cutoff_len while reserving at least reserved_label_len tokens
        of room for the response (the labels).
        """
        pairs = []
        for i in range(0, len(encoded_messages), 2):
            query_ids = encoded_messages[i]
            resp_ids = encoded_messages[i + 1] if i + 1 < len(encoded_messages) else []
            if len(query_ids) + len(resp_ids) > cutoff_len:
                # Truncate the query first so the response keeps at least
                # reserved_label_len tokens; never slice with a negative index.
                max_query_len = max(cutoff_len - max(len(resp_ids), reserved_label_len), 0)
                query_ids = query_ids[:max_query_len]
                resp_ids = resp_ids[: max(cutoff_len - len(query_ids), 0)]
            pairs.append((query_ids, resp_ids))
        return pairs
  • _convert_elements_to_ids: flattens the formatted element list into one list of token ids, without adding special tokens.
  • _make_pairs: groups the encoded messages into (query, response) pairs, truncating the query so that each pair fits within cutoff_len while the response keeps at least reserved_label_len tokens. An example follows below.
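For instance, with made-up ids and limits, the sketched _make_pairs trims the query rather than the labels:

python

pairs = template._make_pairs(
    [[1, 2, 3, 4, 5], [6, 7, 8]],  # one encoded query, one encoded response
    cutoff_len=6,
    reserved_label_len=2,
)
# The query is cut from 5 ids to 3 so the pair fits in cutoff_len,
# while the response keeps all 3 ids (>= the 2 reserved for labels):
# pairs == [([1, 2, 3], [6, 7, 8])]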

Summary

In summary, this code defines a Template class that converts a sequence of messages into token ids via role-specific formatters. The main entry points are encode_oneturn and encode_multiturn, for single-pair and multi-turn encoding respectively; the private _encode method does the actual work, relying on _convert_elements_to_ids and _make_pairs for element conversion and pair construction. A rough end-to-end usage sketch follows.
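Putting it together, usage would look roughly like this. The formatter slot strings, the system prompt, and the gpt2 tokenizer are all illustrative choices, assuming StringFormatter and EmptyFormatter behave like the Formatter sketch above:

python

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer

template = Template(
    format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]),
    format_assistant=StringFormatter(slots=["{{content}}"]),
    format_system=StringFormatter(slots=["{{content}}\n"]),
    format_function=StringFormatter(slots=["{{content}}"]),
    format_observation=StringFormatter(slots=["{{content}}"]),
    format_tools=StringFormatter(slots=["{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="You are a helpful assistant.",
    stop_words=[],
    efficient_eos=False,
    replace_eos=False,
    force_system=False,
)

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)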

def get_template_and_fix_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
) -> Template:
    if name is None:
        template = templates["vanilla"]  # placeholder
    else:
        template = templates.get(name, None)
        if template is None:
            raise ValueError("Template {} does not exist.".format(name))

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        stop_words = stop_words[1:]

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if stop_words:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
        )
        logger.info("Add {} to stop words.".format(",".join(stop_words)))
        if num_added_tokens > 0:
            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")

    try:
        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
    except ValueError:
        logger.info("Cannot add this chat template to tokenizer.")

    return template


_register_template(
    name="alpaca",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    default_system=(
        "Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
    ),
)

Let's explain this code line by line as well.

The get_template_and_fix_tokenizer function

python

def get_template_and_fix_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
) -> Template:
  • Defines get_template_and_fix_tokenizer, which looks up a template by name and patches the tokenizer so its special tokens match the template. Its parameters are the tokenizer and an optional template name.

python

    if name is None:
        template = templates["vanilla"]  # placeholder
    else:
        template = templates.get(name, None)
        if template is None:
            raise ValueError("Template {} does not exist.".format(name))
  • If name is None, the vanilla template is used as a fallback.
  • Otherwise the named template is looked up in the templates registry, and a ValueError is raised if it does not exist. A plausible sketch of that registry follows below.
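templates and _register_template are defined elsewhere in the module. A plausible minimal version (a guess at the shape, not the verified source) is a module-level dict that _register_template fills in, supplying defaults for anything the caller omits:

python

templates: Dict[str, Template] = {}

def _register_template(name: str, **kwargs) -> None:
    # Start from permissive defaults, override with the caller's kwargs,
    # and store the finished template under its name.
    defaults = dict(
        format_user=StringFormatter(slots=["{{content}}"]),
        format_assistant=StringFormatter(slots=["{{content}}"]),
        format_system=StringFormatter(slots=["{{content}}"]),
        format_function=StringFormatter(slots=["{{content}}"]),
        format_observation=StringFormatter(slots=["{{content}}"]),
        format_tools=StringFormatter(slots=["{{content}}"]),
        format_separator=EmptyFormatter(slots=[]),
        default_system="",
        stop_words=[],
        efficient_eos=False,
        replace_eos=False,
        force_system=False,
    )
    defaults.update(kwargs)
    templates[name] = Template(**defaults)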

python

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        stop_words = stop_words[1:]
  • Reads the template's stop_words.
  • If the template requires replacing the EOS (end-of-sequence) token but stop_words is empty, a ValueError is raised.
  • Otherwise the first stop word becomes the new EOS token via _add_or_replace_eos_token (sketched below), and it is dropped from the remaining stop-word list.
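_add_or_replace_eos_token is also defined elsewhere. A sketch consistent with how it is used here (the real helper may differ): register the token as the tokenizer's eos_token and warn if the vocabulary grew:

python

def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
    had_no_eos = tokenizer.eos_token_id is None
    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
    if had_no_eos:
        logger.info("Add eos token: {}".format(tokenizer.eos_token))
    else:
        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
    if num_added_tokens > 0:
        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")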

Continuing through get_template_and_fix_tokenizer:

python

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
  • If the tokenizer still has no EOS token at this point, <|endoftext|> is added as the EOS token.

The remaining steps of the function:
  • If the tokenizer has no pad token, the EOS token is reused as the pad token.
  • Any remaining stop words are registered as additional special tokens; if this actually grows the vocabulary, a warning reminds you to enable `resize_vocab` so the embedding matrix gets resized.
  • Finally, the function tries to attach a Jinja chat template to the tokenizer (logging, but not failing, if it cannot) and returns the template.

The closing _register_template call registers the alpaca template: user messages are wrapped in the "### Instruction: / ### Response:" layout, turns are separated by blank lines, and the classic Alpaca instruction-following prompt is the default system message.