
[Ring Attention] fix the 2d ring attn when using multiple machine #6071

Merged: 26 commits into hpcaitech:main from ring_attention on Oct 15, 2024

Conversation

@wangbluo (Contributor) commented on Sep 25, 2024

🚨 Issue number

fixed #6017

📝 What does this PR do?

The double_ring_groups need to take the tp groups into account, since the tp axis is the first axis, and the ranks in double_ring_groups need to be transformed into global ranks.
Reference: https://arxiv.org/pdf/2406.18485
(Figures from the paper illustrating the 2D ring attention topology omitted.)
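
A minimal sketch of the rank layout this implies, assuming the tp axis is the innermost (fastest-varying) mesh axis, so that the ranks of one sp group are spaced tp_size apart in the global numbering:

    # Illustrative values only (example 1 below uses tp_size=2, sp_size=4).
    tp_size, sp_size = 2, 4
    world_size = tp_size * sp_size
    # With tp as the first axis, the sp group containing tp-rank t is every
    # tp_size-th global rank starting from t.
    sp_groups = [[t + s * tp_size for s in range(sp_size)] for t in range(tp_size)]
    # sp_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]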

Problem Transformation:

Given the consecutive array 0 to world_size - 1 and a number num_rings, divide the array evenly into num_rings subarrays, each of length world_size / num_rings.
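
As a minimal sketch of this split (illustrative values only):

    # Split ranks 0..world_size-1 into num_rings equal subarrays.
    world_size, num_rings = 8, 2
    num_ring_size = world_size // num_rings
    rings = [list(range(i * num_ring_size, (i + 1) * num_ring_size))
             for i in range(num_rings)]
    # rings == [[0, 1, 2, 3], [4, 5, 6, 7]]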

Inner Ring Group:

For each subarray (that is, a single ring), take one number every tp_size elements. Group the extracted numbers into chunks of inner_ring_size; each chunk forms a new element. Continue this process until world_size is reached. The collection of all these new elements forms the inner ring group array.

Inter Ring Group:

From each subarray, take the numbers at the same index to form a new element. These new elements are combined into a new array, which is the inter ring group.

Test code:

    def build_ring_groups(world_size, tp_size, num_rings, inner_ring_size):
        num_ring_size = world_size // num_rings

        # Inner ring groups: within each ring, take every tp_size-th rank
        # and collect inner_ring_size of them into one group.
        ranks = []
        for i in range(num_rings):
            start = i * num_ring_size
            end = (i + 1) * num_ring_size
            for idx in range(start, end):
                inner_rank = []
                for k in range(inner_ring_size):
                    current_num = idx + k * tp_size
                    if current_num >= end:
                        break
                    inner_rank.append(current_num)
                if len(inner_rank) == inner_ring_size and inner_rank not in ranks:
                    ranks.append(inner_rank)
        print("inner ranks:", ranks)

        # Inter ring groups: ranks at the same offset within each ring,
        # taken across all rings.
        inter_ranks = []
        for i in range(num_ring_size):
            inter_rank = [i + j * num_ring_size for j in range(num_rings)]
            inter_ranks.append(inter_rank)
        print("inter ranks:", inter_ranks)

Example 1: world_size = 8, tp_size = 2, sp_size = 4, inner_ring_size = 2

Results: (screenshot of the printed inner and inter ring groups omitted)

Example 2: world_size = 32, tp_size = 8, sp_size = 4, inner_ring_size = 2 (2 nodes, 16 GPUs per node)

Results: (screenshot of the printed inner and inter ring groups omitted)

Example 3: world_size = 32, tp_size = 4, sp_size = 8, inner_ring_size = 4

Results: (screenshot of the printed inner and inter ring groups omitted)

Example 4: world_size = 24, tp_size = 2, sp_size = 6, inner_ring_size = 3

Results: (screenshot of the printed inner and inter ring groups omitted)

@wangbluo wangbluo requested a review from a team as a code owner September 25, 2024 10:57
@wangbluo wangbluo closed this Sep 26, 2024
@wangbluo wangbluo deleted the ring_attention branch September 26, 2024 10:06
@wangbluo wangbluo restored the ring_attention branch September 26, 2024 10:07
@wangbluo wangbluo reopened this Sep 26, 2024
@Edenzzzz Edenzzzz changed the title [Fix] fix the 2d ring attn when using multiple machine [Fix] fix the 2d ring attention when using multiple machine Oct 9, 2024
@Edenzzzz Edenzzzz changed the title [Fix] fix the 2d ring attention when using multiple machine [Ring Attention] fix the 2d ring attn when using multiple machine Oct 9, 2024
@wangbluo wangbluo merged commit dcd41d0 into hpcaitech:main Oct 15, 2024
4 checks passed
@wangbluo wangbluo deleted the ring_attention branch October 15, 2024 07:20