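希望dump出PyTorch中每个算子的输入和输出，但PyTorch常规的hook机制（如register_forward_hook）只能作用于nn.Module，无法拦截单个算子。下面提供一种方法，可以拦截torch.add、torch.Tensor.add这类算子。原理：前向部分，把torch和torch.Tensor中的算子替换成包装函数，在调用原算子前后分别保存输入和输出；反向部分，从loss.grad_fn出发遍历next_functions，对autograd图中的每个节点调用register_hook，从而拦截backward。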
一、代码
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading
import base64
from jinja2 import Template
# Device for the model, the input tensors and the prefix of every dump file
# name. Defaults to "cuda"; set the DUMP_DEVICE environment variable
# (e.g. "cuda:1") to override without editing the script.
device = os.environ.get("DUMP_DEVICE", "cuda")
class Attention(nn.Module):
    """Causal scaled dot-product attention over already-projected q/k/v.

    Args:
        max_seq_len: side length of the pre-built causal mask. Only the
            non-flash path uses it, and that path rejects longer inputs.
        head_dim: per-head feature size; scores are scaled by 1/sqrt(head_dim).
        flash: if True, use ``scaled_dot_product_attention(is_causal=True)``;
            otherwise run the explicit matmul -> mask -> softmax -> matmul
            path, whose individual ops go through the patched torch functions.
    """

    def __init__(self, max_seq_len, head_dim, flash):
        super().__init__()
        # Could instead be auto-detected with
        # hasattr(torch.nn.functional, 'scaled_dot_product_attention').
        self.flash = flash
        self.dropout = 0
        self.attn_dropout = nn.Dropout(self.dropout)
        self.head_dim = head_dim
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # -inf strictly above the diagonal, 0 on and below it: once added
            # to the scores, position i can only attend to positions <= i.
            mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)
            mask = torch.triu(mask, diagonal=1).half().to(device)
            self.register_buffer("mask", mask)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor):
        """Return softmax(xq @ xk^T / sqrt(head_dim) + causal mask) @ xv.

        Dims 2 and 3 of ``xk`` are swapped for the transpose, so inputs are
        presumably 4-D (batch, heads, seq, head_dim) — TODO confirm for other
        callers.

        Raises:
            ValueError: on the non-flash path, if the query or key length is
                larger than the ``max_seq_len`` the mask was built for.
        """
        if self.flash:
            return torch.nn.functional.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        # Sequence lengths come from the inputs themselves, not from a
        # module-level variable that may or may not exist.
        q_len = xq.shape[-2]
        k_len = xk.shape[-2]
        if q_len > self.mask.shape[-2] or k_len > self.mask.shape[-1]:
            raise ValueError(
                f"sequence length (q={q_len}, k={k_len}) exceeds "
                f"max_seq_len={self.mask.shape[-1]} of the causal mask"
            )
        _xk = xk.clone()
        t = _xk.transpose(2, 3)
        scores = torch.matmul(xq, t)
        scores = scores / math.sqrt(self.head_dim)
        a = self.mask[:, :, :q_len, :k_len]
        scores = torch.add(scores, a)
        # Softmax runs in float32, then the result is cast back to xq's dtype.
        scores = F.softmax(scores.float(), dim=-1)
        scores = scores.type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)
        return output
# Serializes the dump-file writes and the shared counter between threads.
lock = threading.Lock()
# Global counter embedded in every dump file name; it records the order in
# which tensors were dumped.
gindex = 0


def save_tensor(name, args, index=0):
    """Print and ``torch.save`` every tensor found in *args*.

    Each tensor is written to ``"{device}_{gindex}_{name}_{index}.pt"`` and the
    global ``gindex`` is incremented. Tuples and lists are walked recursively,
    element ``i`` getting index ``index + i``. Any other value (None, Python
    scalars, ...) is ignored.

    Args:
        name: label for the dump, e.g. ``"add-input"``.
        args: a tensor, or a (possibly nested) tuple/list containing tensors.
        index: position offset used in the printed line and the file name.
    """
    global gindex
    if isinstance(args, torch.Tensor):
        print(name, index, args.shape)
        # ``with`` releases the lock even if torch.save raises; a leaked lock
        # would block every later dump forever.
        with lock:
            torch.save(args, "{}_{}_{}_{}.pt".format(device, gindex, name, index))
            gindex += 1
    elif isinstance(args, (tuple, list)):
        # Lists too, so list arguments (e.g. torch.cat([...])) are dumped.
        for idx, x in enumerate(args):
            save_tensor(name, x, index + idx)
# torch.Tensor methods that are never wrapped. Several of them (size, stride,
# data_ptr, untyped_storage, storage_offset, save, ...) look like ones used
# while printing / torch.save-ing a tensor, so wrapping them would presumably
# re-enter save_tensor. NOTE(review): confirm before trimming this list.
_TENSOR_OP_SKIP = frozenset([
    "__iter__", "shape", "dim", "unbind", "normal_", "data",
    "item", "numel", "save", "has_names", "data_ptr", "untyped_storage",
    "storage_offset", "size", "stride", "triu", "half", "is_floating_point",
    "to", "ones", "randint", "ones_like",
])
# Original (unwrapped) torch.Tensor methods by name, kept so the patch can be
# inspected or undone with setattr(torch.Tensor, name, native).
_native_tensor_ops = {}


def _make_tensor_op_hook(op, native):
    """Return a wrapper around ``native`` that dumps its inputs and output.

    A factory (rather than a def inside the loop) so each wrapper captures its
    own ``op``/``native`` instead of the loop's last values.
    """

    def hooked(*args, **kwargs):
        save_tensor(f"{op}-input", args)
        ret = native(*args, **kwargs)
        save_tensor(f"{op}-output", ret)
        return ret

    hooked.__name__ = op
    hooked.__qualname__ = f"Tensor.{op}"
    return hooked


for op in dir(torch.Tensor):
    if op in _TENSOR_OP_SKIP:
        continue
    native = getattr(torch.Tensor, op)
    # Only C-implemented methods (method_descriptor) are wrapped; Python-level
    # methods and properties are left untouched.
    if native.__class__.__name__ != "method_descriptor":
        continue
    _native_tensor_ops[op] = native
    setattr(torch.Tensor, op, _make_tensor_op_hook(op, native))
# torch.* functions that are never wrapped. Several (save, size, stride, full,
# ones, reshape, is_grad_enabled, ...) look like ones used by the dump helper
# or setup code, so wrapping them would presumably recurse or flood the dump.
# NOTE(review): confirm before trimming this list.
_TORCH_OP_SKIP = frozenset([
    "is_grad_enabled", "__iter__", "save", "has_names", "data_ptr",
    "untyped_storage", "storage_offset", "size", "stride", "triu",
    "is_floating_point", "to", "ones", "randint", "full", "reshape", "ones_like",
])
# Original (unwrapped) torch.* functions by name, kept so the patch can be
# inspected or undone with setattr(torch, name, native).
_native_torch_ops = {}


def _make_torch_op_hook(op, native):
    """Return a wrapper around the torch function ``native`` that dumps I/O.

    A factory so each wrapper captures its own ``op``/``native`` instead of
    the loop's last values.
    """

    def hooked(*args, **kwargs):
        save_tensor(f"{op}-input", args)
        ret = native(*args, **kwargs)
        save_tensor(f"{op}-output", ret)
        return ret

    hooked.__name__ = op
    hooked.__qualname__ = op
    return hooked


for op in dir(torch):
    if op in _TORCH_OP_SKIP:
        continue
    native = getattr(torch, op)
    # Only C builtins are wrapped; modules, classes and Python functions in
    # the torch namespace are left untouched.
    if native.__class__.__name__ != "builtin_function_or_method":
        continue
    _native_torch_ops[op] = native
    setattr(torch, op, _make_torch_op_hook(op, native))
def _make_grad_dump_hook(node_name):
    """Return a Node post-hook that dumps the node's gradients under *node_name*."""

    def posthook(grad_inputs, grad_outputs):
        # Separate labels: with a single label and running indices, entries
        # from grad_inputs and grad_outputs could share the same index.
        save_tensor(f"{node_name}-grad_input", grad_inputs)
        save_tensor(f"{node_name}-grad_output", grad_outputs)

    return posthook


def hook_backwards(loss, cached=None):
    """Register a gradient-dumping hook on every autograd node reachable from *loss*.

    Starting at *loss* (e.g. ``tensor.grad_fn``), follows ``next_functions``
    and registers, once per node, a hook that passes the node's gradients to
    ``save_tensor`` when backward runs through it.

    The walk uses an explicit stack instead of recursion, so graphs deeper
    than the interpreter's recursion limit are handled.

    Args:
        loss: autograd node to start from; ``None`` is a no-op.
        cached: set of nodes already hooked, updated in place so a node is
            never hooked twice. A fresh set is used when omitted.
    """
    if cached is None:
        cached = set()
    stack = [loss]
    while stack:
        node = stack.pop()
        # None marks an input that needs no gradient; nodes reachable via
        # several paths are hooked only once.
        if node is None or node in cached:
            continue
        node.register_hook(_make_grad_dump_hook(type(node).__name__))
        cached.add(node)
        stack.extend(child for child, _ in node.next_functions)
def main(flash, bs, n_local_heads, seqlen, head_dim, steps=1):
    """Run ``steps`` SGD steps of Attention + CrossEntropyLoss, dumping every op.

    q, k and v are fp16 leaf tensors of shape (bs, n_local_heads, seqlen,
    head_dim) on ``device``, filled from N(0, 0.1) and optimized directly.
    After each step the sums of q, k, v and the loss are printed.

    Args:
        flash: passed to ``Attention`` (fused vs. explicit attention path).
        bs, n_local_heads, seqlen, head_dim: tensor dimensions.
        steps: number of optimization steps (default 1).
    """
    torch.random.manual_seed(1)
    shape = (bs, n_local_heads, seqlen, head_dim)

    def make_leaf():
        # fp16 leaf on `device`; filled in place through .data so the
        # initialisation is not recorded by autograd.
        t = torch.ones(shape, dtype=torch.float32).half().to(device)
        t.data.normal_(0, 0.1)
        # Plain attribute assignment replaces the deprecated
        # torch.autograd.Variable wrapper.
        t.requires_grad = True
        return t

    # Created (and randomly filled) in q, k, v order.
    q, k, v = make_leaf(), make_leaf(), make_leaf()
    gt = torch.randint(0, head_dim, (bs * n_local_heads * seqlen, 1)).reshape(-1).to(device)
    loss_func = nn.CrossEntropyLoss().to(device)
    model = Attention(seqlen, head_dim, flash).half().to(device)
    optim = torch.optim.SGD([q, k, v], lr=1.1)
    for _ in range(steps):
        # Clear gradients of the previous step; otherwise they accumulate
        # when steps > 1.
        optim.zero_grad()
        output = model(q, k, v)
        loss = loss_func(output.reshape(-1, head_dim), gt)
        # Every step builds a new autograd graph, so start from an empty set.
        hook_backwards(loss.grad_fn, cached=set())
        loss.backward()
        optim.step()
        print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(), k.sum().item(), v.sum().item(), loss.item()))
# Demo configuration: batch size, number of heads, sequence length, head dim.
# Kept as module-level names because the slow Attention path may look up
# `seqlen` in module scope.
(bs, n_local_heads, seqlen, head_dim) = (8, 8, 512, 64)
main(flash=False, bs=bs, n_local_heads=n_local_heads, seqlen=seqlen, head_dim=head_dim)
二、输出
reshape-input 0 torch.Size([32768, 1])
reshape-output 0 torch.Size([32768])
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 64, 512])
matmul-input 0 torch.Size([8, 8, 512, 64])
matmul-input 1 torch.Size([8, 8, 64, 512])
matmul-output 0 torch.Size([8, 8, 512, 512])
__truediv__-input 0 torch.Size([8, 8, 512, 512])
__truediv__-output 0 torch.Size([8, 8, 512, 512])
add-input 0 torch.Size([8, 8, 512, 512])
add-input 1 torch.Size([1, 1, 512, 512])
add-output 0 torch.Size([8, 8, 512, 512])
float-input 0 torch.Size([8, 8, 512, 512])
float-output 0 torch.Size([8, 8, 512, 512])
softmax-input 0 torch.Size([8, 8, 512, 512])
softmax-output 0 torch.Size([8, 8, 512, 512])
type_as-input 0 torch.Size([8, 8, 512, 512])
type_as-input 1 torch.Size([8, 8, 512, 64])
type_as-output 0 torch.Size([8, 8, 512, 512])
matmul-input 0 torch.Size([8, 8, 512, 512])
matmul-input 1 torch.Size([8, 8, 512, 64])
matmul-output 0 torch.Size([8, 8, 512, 64])
reshape-input 0 torch.Size([8, 8, 512, 64])
reshape-output 0 torch.Size([32768, 64])
NllLossBackward0 0 torch.Size([32768, 64])
NllLossBackward0 1 torch.Size([])
LogSoftmaxBackward0 0 torch.Size([32768, 64])
LogSoftmaxBackward0 1 torch.Size([32768, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([32768, 64])
UnsafeViewBackward0 0 torch.Size([64, 512, 64])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 64])
BmmBackward0 0 torch.Size([64, 512, 512])
BmmBackward0 1 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
ViewBackward0 0 torch.Size([8, 8, 512, 512])
ViewBackward0 1 torch.Size([64, 512, 512])
ExpandBackward0 0 torch.Size([8, 8, 512, 512])
ExpandBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 0 torch.Size([8, 8, 512, 512])
SoftmaxBackward0 1 torch.Size([8, 8, 512, 512])
ToCopyBackward0 0 torch.Size([8, 8, 512, 512])
ToCopyBackward0 1 torch.Size([8, 8, 512, 512])
AddBackward0 0 torch.Size([8, 8, 512, 512])
AddBackward0 1 torch.Size([8, 8, 512, 512])
DivBackward0 0 torch.Size([8, 8, 512, 512])
DivBackward0 1 torch.Size([8, 8, 512, 512])
UnsafeViewBackward0 0 torch.Size([64, 512, 512])
UnsafeViewBackward0 1 torch.Size([8, 8, 512, 512])
BmmBackward0 0 torch.Size([64, 512, 64])
BmmBackward0 1 torch.Size([64, 64, 512])
BmmBackward0 1 torch.Size([64, 512, 512])
ReshapeAliasBackward0 0 torch.Size([8, 8, 64, 512])
ReshapeAliasBackward0 1 torch.Size([64, 64, 512])
ExpandBackward0 0 torch.Size([8, 8, 64, 512])
ExpandBackward0 1 torch.Size([8, 8, 64, 512])
ViewBackward0 0 torch.Size([8, 8, 512, 64])
ViewBackward0 1 torch.Size([64, 512, 64])
ExpandBackward0 0 torch.Size([8, 8, 512, 64])
ExpandBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
TransposeBackward0 0 torch.Size([8, 8, 512, 64])
TransposeBackward0 1 torch.Size([8, 8, 64, 512])
CloneBackward0 0 torch.Size([8, 8, 512, 64])
CloneBackward0 1 torch.Size([8, 8, 512, 64])
AccumulateGrad 1 torch.Size([8, 8, 512, 64])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
sum-input 0 torch.Size([8, 8, 512, 64])
sum-output 0 torch.Size([])
45.56250,-12.76562,121.68750,4.16016