Add support for context parallelism #174
base: v2
Conversation
I don't see any reason why this wouldn't work, but some parts might be less simple/natural than they could be. Approving so you can be unblocked for now.
for b in attn_buffers.values():
    mark_dynamic(b, 0)

with self._model_forward_context():
It looks like PP doesn't support CP yet. Maybe raise a config error in that case.
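A minimal sketch of that kind of guard, assuming hypothetical `pp_config` / `cp_config` fields on the train module config (the field names and exception type are illustrative, not the project's actual API):

```python
# Illustrative config-time guard: reject PP + CP until the combination is supported.
def validate_parallelism_config(train_module_config) -> None:
    pp = getattr(train_module_config, "pp_config", None)  # hypothetical field
    cp = getattr(train_module_config, "cp_config", None)  # hypothetical field
    if pp is not None and cp is not None:
        raise ValueError(
            "Pipeline parallelism (PP) does not support context parallelism (CP) yet; "
            "disable one of them."
        )
```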
@@ -560,9 +592,14 @@ def train_batch(self, batch: Dict[str, Any], dry_run: bool = False):

# Train one micro-batch at a time.
for micro_batch_idx, micro_batch in enumerate(micro_batches):
    with self._train_microbatch_context(micro_batch_idx, num_micro_batches):
        attn_buffers = self.model.get_attn_buffers(
Is there a need to pass the attention buffers back through `forward` in this manner? I thought the context parallelism would modify them in-place. I would prefer not doing this if it's not needed.
This API needs some improvement, but there are a couple of reasons I chose to pass these buffers through the `forward` method:
- It allows the `TransformerTrainModule` to mark their sizes as dynamic for `torch.compile()` (a rough sketch of this follows below). This is irrelevant to context parallelism, but it's a good thing to do in general. Without this I think we were getting some unnecessary re-compilations during in-loop evals.
- It guarantees we don't create duplicate versions of these buffers, which could happen if the `BufferCache` owned by each `Attention` module were a different instance. In our usual constructors we make sure the cache is shared, but if someone instantiated their model differently, the cache might not be shared.
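A minimal sketch of that first point, assuming the shared buffers come back as a dict of tensors (`get_attn_buffers` appears in the diff above, but its signature and the exact layout here are assumptions):

```python
import torch
from torch._dynamo import mark_dynamic

# Rough sketch, not the PR's exact code: fetch the model's shared attention buffers
# and mark their first dimension as dynamic so torch.compile() does not re-specialize
# (and recompile) when the buffer size changes, e.g. between training and in-loop evals.
def prepare_attn_buffers(model, *args) -> dict[str, torch.Tensor]:
    attn_buffers = model.get_attn_buffers(*args)  # hypothetical signature
    for buf in attn_buffers.values():
        mark_dynamic(buf, 0)  # dim 0 assumed to be the size that varies across batches
    return attn_buffers
```

Passing the same dict through `forward` then also means every `Attention` module sees the same tensor instances rather than per-module copies.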
Alright, refactored this so that the buffers are managed completely by the model.
Co-authored-by: Shane A <[email protected]>
To enable, run with `--train_module.cp_config='{degree: 2}'` (adjust the degree as needed). Probably don't try this with TP at the same time for now.
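For intuition, a small back-of-the-envelope sketch (not from this PR) of what the degree means for rank layout, assuming CP is the only parallelism besides data parallelism:

```python
# Illustrative arithmetic only: with CP degree 2, ranks are grouped in pairs that
# split each sequence between them, so the number of data-parallel replicas shrinks
# by the CP degree.
world_size = 8            # total GPUs (example value)
cp_degree = 2             # from --train_module.cp_config='{degree: 2}'
seq_len = 4096            # example sequence length

dp_size = world_size // cp_degree          # 4 data-parallel groups
tokens_per_cp_rank = seq_len // cp_degree  # each CP rank handles 2048 tokens

print(f"{dp_size=} {tokens_per_cp_rank=}")
```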