
[Fix] Fix the Error of q, k, and v states must have the same dtype when using flash attention forward. #15

Merged
2 commits merged into Leeroo-AI:main from jacklanda:fix-use-cache-error on May 30, 2024

Conversation

jacklanda (Contributor) opened this pull request:
Steps for Error Reproduction

Enable the argument use_cache=True in the Hugging Face GenerationConfig, like the following:

generation_config = GenerationConfig(
    bos_token_id=128000,
    eos_token_id=128001,
    pad_token_id=self.tokenizer.pad_token_id,
    use_cache=True,
)

Then simply pass this generation config with use_cache=True to the model and run the forward pass as usual; I believe you will hit the same error I did.
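
For context, here is a minimal reproduction sketch, assuming a plain transformers setup rather than the exact mergoo checkpoint from this report (the model id is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Hypothetical Llama-3-style checkpoint; any model loaded with
# attn_implementation="flash_attention_2" should exercise the same code path.
model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

generation_config = GenerationConfig(
    bos_token_id=128000,
    eos_token_id=128001,
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,  # enabling the KV cache is what triggers the dtype mismatch
)

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, generation_config=generation_config, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))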

Error Message

I get the following error message: "query and key must have the same dtype". I found that this is caused by a dtype mismatch between the query and value states.

Potential Solution

I suspect an implicit cast is applied to some of the q, k, and v states along the way. Hence, the fix casts q, k, and v back to the same target dtype before the flash attention forward, as sketched below.
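
The following is a minimal sketch of that casting approach (the helper name and the fallback to the query's dtype are assumptions for illustration, not the exact committed code):

import torch

def ensure_same_dtype(query_states, key_states, value_states, target_dtype=None):
    """Cast q, k, and v to a common dtype before calling flash attention.

    If no target dtype is given, fall back to the query's dtype. The idea is to
    detect a mismatch and cast back, instead of letting flash attention raise
    "query and key must have the same dtype".
    """
    if target_dtype is None:
        target_dtype = query_states.dtype
    if query_states.dtype != target_dtype:
        query_states = query_states.to(target_dtype)
    if key_states.dtype != target_dtype:
        key_states = key_states.to(target_dtype)
    if value_states.dtype != target_dtype:
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states

# Usage inside the attention forward (sketch):
# query_states, key_states, value_states = ensure_same_dtype(
#     query_states, key_states, value_states, target_dtype=torch.bfloat16
# )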

Reference

The committed code closely follows the existing implementation in mergoo.

@jacklanda jacklanda changed the title [Bug] Query and key must have the same dtype when use flash attention forward [Fix] Fix the Error of q, k, and v states must have the same dtype when using flash attention forward. May 28, 2024
mergoo/models/modeling_llama.py (review thread, outdated, resolved)
@jacklanda (Contributor, Author) left a comment:

Submitted a new commit: add dtype checking for q, k, and v states before auto-casting.

mergoo/models/modeling_llama.py (review thread, outdated, resolved)
@jacklanda jacklanda requested a review from gitsailor5 May 29, 2024 03:07
@gitsailor5 (Contributor) left a comment:

looks good!

@gitsailor5 gitsailor5 merged commit c73a047 into Leeroo-AI:main May 30, 2024
1 check passed
@jacklanda jacklanda deleted the fix-use-cache-error branch May 30, 2024 15:33