Skip to content

Fixed the bug described in Issue #30 by rewriting the backward kernels (#33)

Open
kawabata-tomoko wants to merge 1 commit into HazyResearch:main from
kawabata-tomoko:even_padding
Open

Fixed the bug described in Issue #30 by rewriting the backward kernels (#33)
kawabata-tomoko wants to merge 1 commit into HazyResearch:main from
kawabata-tomoko:even_padding

Conversation

@kawabata-tomoko
Copy link

modified: csrc/flashfftconv/conv1d/conv1d_bwd_cuda_bhl.cu
modified: csrc/flashfftconv/conv1d/conv1d_bwd_cuda_blh.cu

With this change, the following sample script from Issue #30 passes:

"""Reproduction script for Issue #30.

Builds a depthwise nn.Conv1d as the reference, mirrors its weights into
FlashDepthWiseConv1d, runs forward + backward through both, and prints the
summed difference of the input gradients (expected ~0 after the fix).
"""
import torch
import torch.nn as nn
import torch.optim as optim
from flashfftconv import FlashDepthWiseConv1d

# Problem sizes from the issue report.
batch, seq_len, channels, kernel = 4, 26000, 512, 3
pad = kernel - 1  # "even padding" case that triggered the bug
dtype = torch.bfloat16
device = "cuda:4"

# Reference PyTorch depthwise convolution; its weights/bias seed the flash
# version below. in_channels = out_channels, and kernel size must be odd.
sample = torch.randn((batch, seq_len, channels), device=device, dtype=dtype)
trans = nn.Linear(channels, channels).to(device=device, dtype=dtype)
conv1d_torch = nn.Conv1d(
    in_channels=channels,
    out_channels=channels,
    kernel_size=kernel,
    groups=channels,
    padding=pad,
    dtype=dtype,
    device=device,
)

# Flash kernel under test, initialised from the reference weights.
flash_conv1d = FlashDepthWiseConv1d(
    channels=channels,
    kernel_size=kernel,
    padding=pad,
    weights=conv1d_torch.weight,
    bias=conv1d_torch.bias,
    is_bhl=True,  # or False
    dtype=dtype,  # this should be the dtype of the weights
).to(device=device)

# Same (B, H, L) input fed to both paths, each with its own grad tape.
projected = trans(sample).transpose(-1, -2).contiguous()
x_input_flash = projected.detach().clone().requires_grad_(True)
x_input_torch = projected.detach().clone().requires_grad_(True)

out_torch = conv1d_torch(x_input_torch)
out_flash = flash_conv1d(x_input_flash)
# out_flash = flash_conv1d(x_input_flash.transpose(-1, -2).contiguous())

out_flash.sum().backward()
out_torch.sum().backward()

# Should be ~0 if the flash backward matches the PyTorch reference.
print((x_input_flash.grad[0] - x_input_torch.grad[0]).sum())

	modified:   csrc/flashfftconv/conv1d/conv1d_bwd_cuda_bhl.cu
	modified:   csrc/flashfftconv/conv1d/conv1d_bwd_cuda_blh.cu
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant