@dataclass
class Template:
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
force_system: bool
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
system = system or self.default_system
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
这段代码定义了一个名为 Template
的数据类,并包含了多个方法用于将消息编码为一系列的 token id。我们逐行解释这段代码:
数据类定义
python
复制
from dataclasses import dataclass
from typing import List, Dict, Optional, Sequence, Tuple
@dataclass
class Template:
- 使用
@dataclass
装饰器定义了一个Template
类。dataclass
会自动生成初始化方法和其他方法。
类属性定义
python
复制
format_user: "Formatter"
format_assistant: "Formatter"
format_system: "Formatter"
format_function: "Formatter"
format_observation: "Formatter"
format_tools: "Formatter"
format_separator: "Formatter"
default_system: str
stop_words: List[str]
efficient_eos: bool
replace_eos: bool
force_system: bool
- 定义了多个类属性,这些属性主要是不同角色的格式化器(
Formatter
),以及一些控制编码行为的布尔值和字符串。
encode_oneturn
方法
python
复制
def encode_oneturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
- 定义了一个用于编码单轮对话的
encode_oneturn
方法。
python
复制
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
- 调用私有方法
_encode
进行实际的编码工作,返回编码后的消息对。
python
复制
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids += query_ids + resp_ids
prompt_ids = prompt_ids + encoded_pairs[-1][0]
answer_ids = encoded_pairs[-1][1]
return prompt_ids, answer_ids
- 将所有轮次的
query_ids
和resp_ids
拼接成一个完整的prompt_ids
,并返回最后一轮的response_ids
作为answer_ids
。
encode_multiturn
方法
python
复制
def encode_multiturn(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: int = 1_000_000,
reserved_label_len: int = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
- 定义了一个用于编码多轮对话的
encode_multiturn
方法。
python
复制
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
- 直接调用
_encode
方法并返回其结果。
_encode
私有方法
python
复制
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
Turn 0: system + query resp
Turn t: sep + query resp
"""
- 定义了一个私有方法
_encode
,用于实际的编码过程。
python
复制
system = system or self.default_system
- 如果
system
参数为空,使用类属性default_system
。
python
复制
encoded_messages = []
for i, message in enumerate(messages):
elements = []
if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
elements += self.format_separator.apply()
- 循环遍历消息列表,根据消息位置和内容组装元素列表。
- 对于第一条消息,如果有
system
或tools
或者force_system
为真,则应用format_system
和(如果有的话)format_tools
。
从 _encode
方法的消息处理循环部分:
python
复制
if message["role"] == Role.USER.value:
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT.value:
elements += self.format_assistant.apply(content=message["content"])
elif message["role"] == Role.OBSERVATION.value:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION.value:
elements += self.format_function.apply(content=message["content"])
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
- 根据消息的角色(
role
)应用不同的格式化器:- 如果角色是用户(
Role.USER
),使用format_user
。 - 如果角色是助手(
Role.ASSISTANT
),使用format_assistant
。 - 如果角色是观察(
Role.OBSERVATION
),使用format_observation
。 - 如果角色是功能(
Role.FUNCTION
),使用format_function
。 - 如果遇到未实现的角色,抛出
NotImplementedError
。
- 如果角色是用户(
python
复制
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
- 将格式化后的元素转换为 token id 并添加到
encoded_messages
列表中。
python
复制
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
- 调用
_make_pairs
方法,将编码后的消息组装成 token id 对并返回。
其他私有方法(假设)
以下是一些假设的私有方法,用于完成编码和转换的具体操作:
python
复制
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: List[str]) -> List[int]:
"""
Converts a list of formatted elements into token ids using the tokenizer.
"""
token_ids = []
for element in elements:
token_ids.extend(tokenizer.encode(element, add_special_tokens=False))
return token_ids
def _make_pairs(self, encoded_messages: List[List[int]], cutoff_len: int, reserved_label_len: int) -> Sequence[Tuple[List[int], List[int]]]:
"""
Creates pairs of prompt and response token ids with cutoff and reserved lengths.
"""
pairs = []
for i in range(0, len(encoded_messages), 2):
query_ids = encoded_messages[i]
resp_ids = encoded_messages[i + 1] if i + 1 < len(encoded_messages) else []
if len(query_ids) + len(resp_ids) > cutoff_len:
query_ids = query_ids[:cutoff_len - len(resp_ids)]
resp_ids = resp_ids[:reserved_label_len]
pairs.append((query_ids, resp_ids))
return pairs
_convert_elements_to_ids
方法:将格式化的元素列表转换为 token id 列表。_make_pairs
方法:根据cutoff_len
和reserved_label_len
将编码后的消息组合成 token id 对。
总结
这段代码定义了一个 Template
类,通过不同的格式化器和编码方法,将一系列消息转换为 token id。主要方法包括 encode_oneturn
和 encode_multiturn
,分别用于编码单轮和多轮对话。私有方法 _encode
负责实际的编码过程,利用 _convert_elements_to_ids
和 _make_pairs
方法完成元素转换和对话对的创建。
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
) -> Template:
if name is None:
template = templates["vanilla"] # placeholder
else:
template = templates.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Add pad token: {}".format(tokenizer.pad_token))
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError:
logger.info("Cannot add this chat template to tokenizer.")
return template
_register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
default_system=(
"Below is an instruction that describes a task. " "Write a response that appropriately completes the request."
),
)
逐行解释这段代码。
get_template_and_fix_tokenizer
函数
python
复制
def get_template_and_fix_tokenizer(
tokenizer: "PreTrainedTokenizer",
name: Optional[str] = None,
) -> Template:
- 定义一个函数
get_template_and_fix_tokenizer
,用于获取模板并修正 tokenizer。参数包括tokenizer
和可选的模板名称name
。
python
复制
if name is None:
template = templates["vanilla"] # placeholder
else:
template = templates.get(name, None)
if template is None:
raise ValueError("Template {} does not exist.".format(name))
- 如果
name
参数为空,默认使用vanilla
模板。 - 否则,尝试获取指定名称的模板。如果模板不存在,抛出
ValueError
。
python
复制
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
- 获取模板中的
stop_words
。 - 如果模板要求替换 EOS(End Of Sentence)标记,但
stop_words
为空,抛出ValueError
。 - 否则,用
stop_words
中的第一个词替换 EOS 标记,并移除已使用的词。
解释 get_template_and_fix_tokenizer
函数:
python
复制
if tokenizer.eos_token_id is None:
_add_or_replace_eos_token(tokenizer, eos_token="