
feat: add sentence-similarity support to the ArgillaTrainer #3739

Conversation

@plaguss (Contributor) commented Sep 8, 2023:

Description

Includes an integration with sentence-transformers to allow the ArgillaTrainer to handle sentence-similarity tasks.

Closes #3316

Type of change


  • New feature (non-breaking change which adds functionality)

How Has This Been Tested


The corresponding tests can be found at:

  • tests/integration/client/feedback/training/test_sentence_transformers.py

Checklist

  • I added relevant documentation
  • I followed the style guidelines of this project
  • I did a self-review of my code
  • I made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • I filled out the contributor form (see text above)
  • I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)


if "model" in self.model_kwargs:
self._model = self.model_kwargs["model"]
# self._model = self.model_kwargs["model_name_or_path"]

Member: Why is this commented out? Do we use model_name_or_path or model?

Contributor Author: model is the one used; I forgot to remove model_name_or_path. Internally, SentenceTransformer uses model_name_or_path as a keyword argument, while CrossEncoder uses model_name.
I will take a closer look; the classes aren't completely equivalent and maybe I missed more arguments.
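
For illustration, the two init signatures differ in the name of their first argument (the checkpoint names below are examples, not the PR's defaults):

from sentence_transformers import CrossEncoder, SentenceTransformer

# SentenceTransformer takes the model identifier as `model_name_or_path`:
bi_encoder = SentenceTransformer(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")

# CrossEncoder takes it as `model_name`:
cross_encoder = CrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")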

@davidberenstein1957 (Member) left a comment:

Thanks for the PR @plaguss, it looks good already but I left some small initial remarks.

Also, I had some questions:
Do you see this integration potentially also working with an out-of-the-box solution using Fields and Questions as defined for TextClassification? If so, why; if not, why not?

What is, according to you, the most interesting Hugging Face NLP task to add to the trainer, and why?

self._cross_encoder = cross_encoder

if model is None:
    if not self._cross_encoder:

Member: Perhaps we could add a warning here about the usage of a pre-defined model.

Contributor Author: Does something like this work?

self._logger.warning(f"No model was selected, using pre-defined `{'Cross-Encoder' if self._cross_encoder else 'Bi-Encoder'}` with `{self._model}`.")

https://huggingface.co/blog/how-to-train-sentence-transformers#how-to-prepare-your-dataset-for-training-a-sentence-transformers-model
"""

format: Union[Dict[str, Union[float, int]], Dict[str, str]]

Member: We also want to allow for returning a list of samples because each record might contain multiple comparisons.

_formatting_func_return_types = TrainingTaskForSentenceSimilarityFormat
formatting_func: Callable[[Dict[str, Any]], Union[None, Tuple[str, str, str], Iterator[Tuple[str, str, str]]]]

def _format_data(self, dataset: "FeedbackDataset") -> List[Dict[str, Any]]:

Member: This should also be changed a bit to allow for returning a list of samples too from the formatting_func.

Contributor Author: This commit should address both this point and the previous one. The functions from the tests have been modified to generate data following test_trl.py.
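
As a minimal sketch of the idea (the record fields premise/hypothesis and the output keys are illustrative, not taken verbatim from this PR), a formatting_func can now return a list of samples so that one record contributes several comparisons:

from typing import Any, Dict, List

def formatting_func(sample: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Emit one training sample per annotated label, so a single record
    # can yield multiple sentence pairs:
    return [
        {
            "sentence-1": sample["premise"],
            "sentence-2": sample["hypothesis"],
            "label": int(annotation["value"]),
        }
        for annotation in sample["label"]
    ]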

def _train_test_split(self, data: List[dict], train_size: float, seed: int) -> Tuple[List[dict], List[dict]]:
    from sklearn.model_selection import train_test_split

    return train_test_split(

Member: Should we consider stratifying this split too? I can imagine the model might over- or under-estimate when not doing this.

Contributor Author: Sure, I will add it for the datasets that include integer labels.
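
A minimal sketch of what that could look like, assuming each sample dict carries an integer label under a "label" key (the key name is illustrative):

from typing import List, Tuple

from sklearn.model_selection import train_test_split

def _train_test_split(self, data: List[dict], train_size: float, seed: int) -> Tuple[List[dict], List[dict]]:
    labels = [sample.get("label") for sample in data]
    # Stratify only when every sample carries an integer label:
    stratify = labels if all(isinstance(label, int) for label in labels) else None
    return train_test_split(data, train_size=train_size, shuffle=True, random_state=seed, stratify=stratify)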

if "model_name_or_path" in self.model_kwargs:
self.model_kwargs.pop("model_name_or_path")
if "model_name" in self.model_kwargs:
self.model_kwargs.pop("model_name")

Member: Why are we also using model_name and model_name_or_path here?

@plaguss (Contributor Author) commented Sep 9, 2023:

When we grab the default args it is done for SentenceTransformer and CrossEncoder in a generic way:

self.model_kwargs = get_default_args(self._trainer_cls.__init__)
self.trainer_kwargs = get_default_args(self._trainer_cls.fit)

If we don't remove those argument names we have problems during the init call:

>       self._trainer = self._trainer_cls(self._model, **self.model_kwargs)
E       TypeError: CrossEncoder.__init__() got multiple values for argument 'model_name'
...
>       self._trainer = self._trainer_cls(self._model, **self.model_kwargs)
E       TypeError: SentenceTransformer.__init__() got multiple values for argument 'model_name_or_path'

In the case of CrossEncoder the first argument on init is positional (model_name), while SentenceTransformer uses a keyword argument (model_name_or_path). As we are passing this argument by position in both cases, I think it's easier to remove it from the general arguments.

The same happens with the fit method:

if "train_objectives" in self.trainer_kwargs:
    self.trainer_kwargs.pop("train_objectives")
if "train_dataloader" in self.trainer_kwargs:
    self.trainer_kwargs.pop("train_dataloader")
E       TypeError: SentenceTransformer.fit() got multiple values for argument 'train_objectives'
...
E       TypeError: CrossEncoder.fit() got multiple values for argument 'train_dataloader'

Maybe there is a better way of doing this that I am not seeing?

Reply: That is the best way to do it. I believe I misunderstood the difference of these variables per task type.
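
For reference, a get_default_args helper like the one used above can be written with inspect; this is a sketch rather than necessarily the actual Argilla utility:

import inspect
from typing import Any, Callable, Dict

def get_default_args(func: Callable) -> Dict[str, Any]:
    # Keep every parameter that declares a default value:
    signature = inspect.signature(func)
    return {
        name: param.default
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }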

@plaguss (Contributor Author) commented Sep 8, 2023:

> Thanks for the PR @plaguss, it looks good already but I left some small initial remarks.
>
> Also, I had some questions: Do you see this integration potentially also working with an out-of-the-box solution using Fields and Questions as defined for TextClassification? If so, why; if not, why not?

I think it could work, but I didn't really think about it and just assumed the formatting_func would do the job. Sentence Transformers allows training with different types of datasets, which in turn allow for different training strategies.
I think it just needs some additional checking with respect to the TextClassification case; let me take a look at some example datasets and think of an example.

It's working with this commit, but I would like to see if it can be simplified a bit; I think I could simplify some loops that may be slow for big datasets.

> What is, according to you, the most interesting Hugging Face NLP task to add to the trainer, and why?

Regarding this one, I don't know that much about table QA (whether the models are mature or powerful enough, for example...), but where I come from the world is built around CSV files and Excel sheets, and I can assure you that a lot of people would freak out and be looking forward to using it 😃.

@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from c88b643 to 50afd9b Compare September 9, 2023 16:58
@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from 1629948 to cbcbc6c Compare September 9, 2023 17:14
@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from 3f23840 to be99ae4 Compare September 9, 2023 18:27
@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from ea0fab4 to f680325 Compare September 11, 2023 06:47
@tomaarsen (Contributor) left a comment:

Hey! This is looking good. I haven't run it myself locally, but it seems complete. Have you been able to test it out with any cool sentence similarity dataset on the Hub?
I found a few small nitpicks, mostly regarding formatting.

from sentence_transformers import InputExample

# Use the first sample to decide what type of dataset to generate:
sample_keys = data[0].keys()

Contributor: Potential edge case: empty list.

Contributor Author: I'll add a check first. Shouldn't this be checked in FeedbackDatasetBase.prepare_for_training, so that all the frameworks can be sure they start with at least one example?

if not len(data) > 0:
    raise ValueError("The dataset must contain at least one sample to be able to train.")

@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from 5088799 to 15e5666 Compare September 11, 2023 14:36
@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from 046f50d to d0861bd Compare September 11, 2023 14:51
@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from 09d30ad to 8931644 Compare September 11, 2023 15:21
@plaguss (Contributor Author) commented Sep 11, 2023:

> Hey! This is looking good. I haven't run it myself locally, but it seems complete. Have you been able to test it out with any cool sentence similarity dataset on the Hub? I found a few small nitpicks, mostly regarding formatting.

You can see an example in the docs for both the Bi-Encoder and the Cross-Encoder, trained on a subset of the snli dataset 😃:

from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface("plaguss/snli-small")

task = TrainingTask.for_sentence_similarity(
    texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
    label=dataset.question_by_name("label")
)
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="sentence-transformers",
    framework_kwargs={"cross_encoder": False}
)
trainer.train(output_dir="my_sentence_transformer_model")
trainer.predict(
    [
        "Machine learning is so easy.",
        ["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
    ]
)
# [0.77857256, 0.4587626, 0.29062212]

After training for just one epoch on a small number of samples, the model yields results quite similar to the sentence examples shown in this Hugging Face blog post.

@plaguss plaguss marked this pull request as ready for review September 11, 2023 15:33
@plaguss plaguss force-pushed the feat/task-sentence-similarity branch from 006ec25 to 69ff41c Compare September 11, 2023 16:11
codecov bot commented Sep 11, 2023:

Codecov Report: Patch has no changes to coverable lines.

Comment on lines +211 to +212
if as_argilla_records:
    raise NotImplementedError("To be done")

Contributor Author: Do you think that as_argilla_records makes sense for this type of model?

Reply: No, this doesn't make sense. Before this was logical, but within the FeedbackDataset it is less relevant to have this support. Good catch not to implement it.

@plaguss (Contributor Author) commented Sep 12, 2023:

Sorry @tomaarsen, I thought that I had answered all your points when I was closing them, but I didn't submit the reviews...

@davidberenstein1957 davidberenstein1957 changed the base branch from develop to feat/task-sentence-similarity September 13, 2023 14:05
@davidberenstein1957 davidberenstein1957 merged commit 0cbda25 into argilla-io:feat/task-sentence-similarity Sep 13, 2023
@davidberenstein1957 (Member) commented:
Hi @plaguss,

Thanks a lot for the PR. It is looking good. I transferred ownership to a new branch to work on fine-tuning the docs as a final wrap-up.

Regards,
David
