Add support for context parallelism #174

Open
wants to merge 21 commits into base: v2

Conversation

@epwalsh (Member) commented Feb 21, 2025

To enable, run with --train_module.cp_config='{degree: 2}' (adjust the degree as needed). Probably don't try this with TP at the same time for now.
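
For intuition (not part of this PR's diff): context parallelism with degree N splits each sequence across N ranks along the sequence dimension before attention. A rough sketch of that sharding, with illustrative names only:

```python
import torch


def shard_for_context_parallel(
    input_ids: torch.Tensor, cp_rank: int, cp_degree: int
) -> torch.Tensor:
    """Return this rank's contiguous slice along the sequence dimension.

    Assumes (batch, seq_len) inputs and a sequence length divisible by the
    CP degree; real implementations also handle padding and load balancing.
    """
    seq_len = input_ids.shape[1]
    shard_len = seq_len // cp_degree
    return input_ids[:, cp_rank * shard_len : (cp_rank + 1) * shard_len]
```

With degree 2, each rank holds half of every sequence; the attention implementation is then responsible for exchanging key/value shards between the CP ranks.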

@2015aroras (Contributor) left a comment:


I don't see any reason why this wouldn't work, but some parts might be less simple/natural than they could be. Approving so you can be unblocked for now.

```python
for b in attn_buffers.values():
    mark_dynamic(b, 0)

with self._model_forward_context():
```
Contributor:

It looks like PP doesn't support CP yet. Maybe raise a configuration error in that case.
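
A minimal sketch of what such a check could look like, assuming hypothetical `pp_config` and `cp_config` attributes (illustrative names, not the actual OLMo-core API):

```python
class ConfigurationError(Exception):
    """Illustrative error type for invalid training configurations."""


def validate_parallelism_config(pp_config, cp_config) -> None:
    # Pipeline parallelism (PP) and context parallelism (CP) can't be combined
    # yet, so fail fast at configuration time instead of at runtime.
    if pp_config is not None and cp_config is not None:
        raise ConfigurationError(
            "Context parallelism is not currently supported together with "
            "pipeline parallelism; disable one of the two."
        )
```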

```
@@ -560,9 +592,14 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False):

        # Train one micro-batch at a time.
        for micro_batch_idx, micro_batch in enumerate(micro_batches):
            with self._train_microbatch_context(micro_batch_idx, num_micro_batches):
                attn_buffers = self.model.get_attn_buffers(
```
Contributor:

Is there a need to pass the attention buffers back through forward in this manner? I thought context parallelism would modify them in place. I would prefer not to do this if it's not needed.

Member Author (epwalsh):

This API needs some improvement, but there are a couple of reasons I chose to pass these buffers through the forward method:

  1. It allows the TransformerTrainModule to mark their sizes as dynamic for torch.compile() (a sketch of this follows below). This is irrelevant to context parallelism, but it's a good thing to do in general; without it I think we were getting some unnecessary re-compilations during in-loop evals.
  2. It guarantees we don't create duplicate versions of these buffers, which could happen if the BufferCache owned by each Attention module were a different instance. Our usual constructors make sure the cache is shared, but if someone instantiated their model differently, it might not be.
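
For reference, a minimal, self-contained sketch of point 1, using `torch._dynamo.mark_dynamic` to keep `torch.compile()` from specializing on a buffer's leading dimension (the buffer names and shapes here are made up for illustration):

```python
import torch
from torch._dynamo import mark_dynamic

# Illustrative attention buffers keyed by name; in the real model these would
# be things like RoPE sin/cos caches whose length tracks the sequence length.
attn_buffers = {
    "rope_sin": torch.randn(1024, 64),
    "rope_cos": torch.randn(1024, 64),
}

# Mark dim 0 as dynamic so a compiled forward pass isn't re-traced whenever
# the buffers are rebuilt at a different length (e.g. for in-loop evals).
for buf in attn_buffers.values():
    mark_dynamic(buf, 0)
```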

Member Author (epwalsh):

Alright, refactored this so that the buffers are managed completely by the model.
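
A rough sketch of the direction described here, with the model owning a single shared buffer cache and handling `mark_dynamic` itself (all names below are illustrative, not the actual refactor):

```python
import torch
import torch.nn as nn
from torch._dynamo import mark_dynamic


class BufferCache(dict):
    """Illustrative stand-in for a cache of shared, non-persistent buffers."""


class ToyAttention(nn.Module):
    def __init__(self, cache: BufferCache):
        super().__init__()
        # Every attention block receives the same cache instance, so buffers
        # are created once and never duplicated across modules.
        self.cache = cache


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 2):
        super().__init__()
        self.cache = BufferCache()
        self.blocks = nn.ModuleList(ToyAttention(self.cache) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The model, not the train module, manages its attention buffers:
        # it marks their leading dim as dynamic before they are used.
        for buf in self.cache.values():
            mark_dynamic(buf, 0)
        return x
```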
