
causal_conv1d_fn fails when number of channels is 2^16 or more #32

Open
KeAWang opened this issue Sep 5, 2024 · 0 comments

Comments


KeAWang commented Sep 5, 2024

With 2^16 (65536) or more channels, causal_conv1d_fn fails with the following error:

RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Minimal reproducible code:

import torch
from causal_conv1d import causal_conv1d_fn

device = "cuda"
b, l = 1, 2**2  # batch size, sequence length
k = 2           # convolution kernel width


# this works: 2**16 - 1 = 65535 channels
c = 2**16 - 1
x = torch.randn(b, c, l, dtype=torch.float32, device=device)  # (batch, channel, seq_len)
weight = torch.randn(c, k, dtype=torch.float32, device=device)  # (channel, width)
y = causal_conv1d_fn(x, weight)

# this fails with "invalid configuration argument" (c >= 2**16)
c = 2**16
x = torch.randn(b, c, l, dtype=torch.float32, device=device)
weight = torch.randn(c, k, dtype=torch.float32, device=device)
y = causal_conv1d_fn(x, weight)

This is likely because the implementation parallelizes thread blocks across the batch and channel axes, and CUDA caps the grid's y dimension at 2^16 - 1 = 65535, while the x dimension can go up to 2^31 - 1:

const int channel_id = blockIdx.y;

Swapping the `.x` and `.y` grid dimensions (mapping channels to `blockIdx.x`) should fix the issue, at least for the non-channel-last implementation.
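Until the kernel launch is fixed, a possible stopgap is to fall back to PyTorch's own grouped `conv1d`, which has no such channel limit. The sketch below is a hypothetical pure-PyTorch stand-in for `causal_conv1d_fn` (same depthwise causal semantics, assuming no activation), so it also runs on CPU:

```python
import torch
import torch.nn.functional as F

def causal_conv1d_ref(x, weight, bias=None):
    """Pure-PyTorch causal depthwise conv1d (hypothetical stand-in for
    causal_conv1d_fn; not subject to the CUDA gridDim.y limit).

    x:      (batch, channels, seqlen)
    weight: (channels, width)
    """
    _, c, l = x.shape
    width = weight.shape[1]
    # One filter per channel via groups=c; left-pad by width-1 so the output
    # at time t only depends on inputs at times <= t, then trim the tail.
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=c)
    return out[..., :l]

# Works for any channel count, including c >= 2**16
x = torch.randn(1, 2**16, 4)
w = torch.randn(2**16, 2)
y = causal_conv1d_ref(x, w)
print(y.shape)  # torch.Size([1, 65536, 4])
```

At the first timestep only the last filter tap sees real data, so `y[..., 0] == x[..., 0] * weight[:, -1]`, which is a quick way to sanity-check the causal padding.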
