feat: add sentence-similarity support to the ArgillaTrainer #3739
Conversation
if "model" in self.model_kwargs:
    self._model = self.model_kwargs["model"]
    # self._model = self.model_kwargs["model_name_or_path"]
Why is this commented out? Do we use `model_name_or_path` or `model`?
`model` is the one used; I forgot to remove the `model_name_or_path`. Internally, `SentenceTransformer` uses `model_name_or_path` as a keyword argument, while `CrossEncoder` uses `model_name`.
I will take a closer look; the classes aren't completely equal and maybe I lost more arguments.
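For reference, the naming difference can be checked programmatically by inspecting the constructor signatures. This is a minimal stdlib sketch using stand-in classes (the real `SentenceTransformer` and `CrossEncoder` signatures are much longer, so these are simplified assumptions):

```python
import inspect

# Stand-ins mirroring only the first parameter of each real class (simplified):
class FakeSentenceTransformer:
    def __init__(self, model_name_or_path=None, device=None):
        self.model_name_or_path = model_name_or_path

class FakeCrossEncoder:
    def __init__(self, model_name, num_labels=None):
        self.model_name = model_name

def first_param(cls):
    # Name of the first constructor parameter after `self`
    params = list(inspect.signature(cls.__init__).parameters)
    return params[1]

print(first_param(FakeSentenceTransformer))  # model_name_or_path
print(first_param(FakeCrossEncoder))         # model_name
```

A check like this is one way to decide generically which keyword a trainer class expects.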
Thanks for the PR @plaguss, it looks good already, but I left some small initial remarks. Also, I had some questions:
- Do you see this integration potentially also working with an out-of-the-box solution using Fields and Questions as defined for TextClassification? If so, why; if not, why not?
- What is, in your view, the most interesting Hugging Face NLP task to add to the trainer, and why?
src/argilla/client/feedback/training/frameworks/sentence_transformers.py
self._cross_encoder = cross_encoder

if model is None:
    if not self._cross_encoder:
Perhaps we could add a warning here about the usage of a pre-defined model
Does something like this work?
self._logger.warning(f"No model was selected, using pre-defined `{'Cross-Encoder' if self._cross_encoder else 'Bi-Encoder'}` with `{self._model}`.")
src/argilla/client/feedback/training/frameworks/sentence_transformers.py
https://huggingface.co/blog/how-to-train-sentence-transformers#how-to-prepare-your-dataset-for-training-a-sentence-transformers-model
"""

format: Union[Dict[str, Union[float, int]], Dict[str, str]]
We also want to allow for returning a list of samples because each record might contain multiple comparisons.
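To illustrate the idea, here is a hypothetical `formatting_func` that returns either a single sample or a list of samples depending on how many comparisons the record holds. The record keys (`premise`, `hypothesis`, `responses`, `score`) are illustrative assumptions, not the actual Argilla schema:

```python
from typing import Any, Dict, List, Tuple, Union

def formatting_func(
    sample: Dict[str, Any]
) -> Union[Tuple[str, str, float], List[Tuple[str, str, float]]]:
    """Return one (sentence_1, sentence_2, score) triplet per response in the record."""
    triplets = [
        (sample["premise"], sample["hypothesis"], float(response["score"]))
        for response in sample["responses"]
    ]
    # A single tuple when the record holds one comparison, a list otherwise
    return triplets if len(triplets) > 1 else triplets[0]

record = {
    "premise": "A man is playing guitar.",
    "hypothesis": "Someone is making music.",
    "responses": [{"score": 1}, {"score": 0.8}],
}
print(formatting_func(record))  # two triplets, one per response
```

The consuming `_format_data` would then need to flatten such lists before building the training samples.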
_formatting_func_return_types = TrainingTaskForSentenceSimilarityFormat
formatting_func: Callable[[Dict[str, Any]], Union[None, Tuple[str, str, str], Iterator[Tuple[str, str, str]]]]

def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, Any]]:
This should also be changed a bit to allow for returning a list of samples too from the formatting_func
This commit should address both this point and the previous one. The functions from the tests have been modified to generate data following `test_trl.py`.
def _train_test_split(self, data: List[dict], train_size: float, seed: int) -> Tuple[List[dict], List[dict]]:
    from sklearn.model_selection import train_test_split

    return train_test_split(
Should we consider stratifying this split too? I can imagine the model might over- or under-estimate when not doing this.
Sure, I will add it for the datasets that include integer labels.
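As a sketch, stratification only requires passing the labels to the `stratify` argument of scikit-learn's `train_test_split`. The sample structure below (a `label` key holding integer labels) is an assumption for illustration; the real key may differ:

```python
from collections import Counter
from sklearn.model_selection import train_test_split

# Toy dataset: 10 samples, 5 of each integer label
data = [{"texts": [f"a{i}", f"b{i}"], "label": i % 2} for i in range(10)]
labels = [sample["label"] for sample in data]

train, test = train_test_split(
    data,
    train_size=0.8,
    random_state=42,
    stratify=labels,  # preserve the label distribution in both splits
)
print(Counter(s["label"] for s in train))  # 4 of each label
print(Counter(s["label"] for s in test))   # 1 of each label
```

With `stratify=None` the label balance of each split would depend on the random seed; stratifying makes it deterministic whenever the class counts divide evenly.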
if "model_name_or_path" in self.model_kwargs:
    self.model_kwargs.pop("model_name_or_path")
if "model_name" in self.model_kwargs:
    self.model_kwargs.pop("model_name")
Why are we also using `model_name` and `model_name_or_path` here?
When we grab the default args it is done for `SentenceTransformer` and `CrossEncoder` in a generic way:

self.model_kwargs = get_default_args(self._trainer_cls.__init__)
self.trainer_kwargs = get_default_args(self._trainer_cls.fit)

If we don't remove those argument names, we have problems during the init call:

> self._trainer = self._trainer_cls(self._model, **self.model_kwargs)
E TypeError: CrossEncoder.__init__() got multiple values for argument 'model_name'
...
> self._trainer = self._trainer_cls(self._model, **self.model_kwargs)
E TypeError: SentenceTransformer.__init__() got multiple values for argument 'model_name_or_path'

In the case of `CrossEncoder` the first argument on init is positional (`model_name`), while `SentenceTransformer` uses a keyword argument (`model_name_or_path`). As we are passing this argument by position in both cases, I think it's easier to remove it from the general arguments. The same happens with the `fit` method:

if "train_objectives" in self.trainer_kwargs:
    self.trainer_kwargs.pop("train_objectives")
if "train_dataloader" in self.trainer_kwargs:
    self.trainer_kwargs.pop("train_dataloader")

E TypeError: SentenceTransformer.fit() got multiple values for argument 'train_objectives'
...
E TypeError: CrossEncoder.fit() got multiple values for argument 'train_dataloader'

Maybe there is a better way of doing this that I am not seeing?
That is the best way to do it. I believe I misunderstood the difference between these variables per task type.
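The clash described above can be reproduced independently of sentence-transformers. This minimal sketch uses a stand-in `Trainer` class and a simplified `get_default_args` (the real helper in the codebase may differ):

```python
import inspect

class Trainer:
    def __init__(self, model_name, device="cpu"):
        self.model_name, self.device = model_name, device

def get_default_args(func):
    # Collect every parameter that has a default value (simplified stand-in)
    return {
        name: param.default
        for name, param in inspect.signature(func).parameters.items()
        if param.default is not inspect.Parameter.empty
    }

kwargs = get_default_args(Trainer.__init__)
kwargs["model_name"] = None  # pretend the generic collection picked it up too
try:
    Trainer("my-model", **kwargs)
except TypeError as exc:
    print(exc)  # got multiple values for argument 'model_name'

# Popping the positional name before the call fixes it:
kwargs.pop("model_name", None)
trainer = Trainer("my-model", **kwargs)
print(trainer.model_name)  # my-model
```

Since the first argument is always passed by position, any identically named entry left in the kwargs dict necessarily collides, which is why popping it is the straightforward fix.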
src/argilla/client/feedback/training/frameworks/sentence_transformers.py
I think it could work, but I didn't really think about it. It's working with this commit, but I would like to see if it can be simplified a bit; I think I could simplify some loops that may be slow for big datasets.
Regarding this one, I don't know that much about table QA (whether the models are mature or powerful enough, for example), but where I come from, where the world is built around CSV files and Excel sheets, I can assure you that a lot of people would freak out and be looking forward to using it 😃.
Hey! This is looking good. I haven't run it myself locally, but it seems complete. Have you been able to test it out with any cool sentence similarity dataset on the Hub?
I found a few small nitpicks, mostly regarding formatting.
from sentence_transformers import InputExample

# Use the first sample to decide what type of dataset to generate:
sample_keys = data[0].keys()
Potential edge case: empty list
I'll make a check first. Shouldn't this be checked in `FeedbackDatasetBase.prepare_for_training` so that all the frameworks can be sure they start with at least one example?

if not len(data) > 0:
    raise ValueError("The dataset must contain at least one sample to be able to train.")
tests/integration/client/feedback/training/test_sentence_transformers.py
You can see an example in the docs:

from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface("plaguss/snli-small")
task = TrainingTask.for_sentence_similarity(
    texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
    label=dataset.question_by_name("label")
)
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="sentence-transformers",
    framework_kwargs={"cross_encoder": False}
)
trainer.train(output_dir="my_sentence_transformer_model")
trainer.predict(
    [
        "Machine learning is so easy.",
        ["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
    ]
)
# [0.77857256, 0.4587626, 0.29062212]

After training for an epoch with just a small number of samples in the dataset, the model yields results quite similar to the sentence examples shown in this Hugging Face blog post.
Codecov Report: Patch has no changes to coverable lines.
if as_argilla_records:
    raise NotImplementedError("To be done")
Do you think that `as_argilla_records` makes sense for this type of model?
No, this doesn't make sense. Before, this was logical, but within the `FeedbackDataset` it is less relevant to have this support. Good catch not to implement it.
Sorry @tomaarsen, I thought that I had answered all your points when I was closing them, but I didn't submit the reviews...
Merged 0cbda25 into argilla-io:feat/task-sentence-similarity
Hi @plaguss, thanks a lot for the PR. It is looking good. I transferred ownership to a new branch to work on the fine-tuning of the docs as a final wrap-up. Regards,
Description
Includes integration with `sentence-transformers` to allow the `ArgillaTrainer` to deal with `sentence-similarity` tasks. Closes #3316
Type of change
(Please delete options that are not relevant. Remember to title the PR according to the type of change.)
How Has This Been Tested
(Please describe the tests that you ran to verify your changes. And ideally, reference tests.)
The corresponding tests can be found at:
Checklist
CHANGELOG.md file (See https://keepachangelog.com/)