Skip to content

Commit

Permalink
[BUGFIX] argilla: Fix some from_hub method errors (#5523)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

Trying this code:

```python
import argilla as rg

client = rg.Argilla(...)

dataset = rg.Dataset.from_hub("google-research-datasets/circa")
```
I found several errors ([Here the source
dataset](https://huggingface.co/datasets/google-research-datasets/circa/viewer/default/train?f[goldstandard2][value]=0)):
- The mapping for records does not match since column names contain
uppercase characters
- Some values for the ClassValue column are unlabelled, making the
uncast process fail (-1)
- Importing cached datasets without config files fails when trying to
read the dataset JSON file.

So, this PR tackles all the problems described above.



**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored and jfcalvo committed Sep 23, 2024
1 parent e1b2e6e commit 46caafe
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 22 deletions.
6 changes: 6 additions & 0 deletions argilla/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ These are the section headers that we use:

## [Unreleased]()

### Fixed

- Fixed `from_hub` errors when column names contain uppercase letters. ([#5523](https://github.com/argilla-io/argilla/pull/5523))
- Fixed `from_hub` errors when class feature values contain unlabelled values. ([#5523](https://github.com/argilla-io/argilla/pull/5523))
- Fixed `from_hub` errors when loading cached datasets. ([#5523](https://github.com/argilla-io/argilla/pull/5523))

## [2.2.0](https://github.com/argilla-io/argilla/compare/v2.1.0...v2.2.0)

- Added new `ChatField` supporting chat messages. ([#5376](https://github.com/argilla-io/argilla/pull/5376))
Expand Down
1 change: 1 addition & 0 deletions argilla/src/argilla/_exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from argilla._exceptions._serialization import * # noqa: F403
from argilla._exceptions._settings import * # noqa: F403
from argilla._exceptions._records import * # noqa: F403
from argilla._exceptions._hub import * # noqa: F403
16 changes: 14 additions & 2 deletions argilla/src/argilla/_exceptions/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argilla._exceptions import ArgillaError

__all__ = [
"ImportDatasetError",
"DatasetsServerException",
]

class DatasetsServerException(Exception):
message: str = "Error connecting to Hugging Face Hub datasets-server API"

class ImportDatasetError(ArgillaError):
def __init__(self, message: str = "Error importing dataset") -> None:
super().__init__(message)


class DatasetsServerException(ArgillaError):
def __init__(self, message: str = "Error connecting to Hugging Face Hub datasets-server API") -> None:
super().__init__(message)
23 changes: 16 additions & 7 deletions argilla/src/argilla/datasets/_io/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple, Type, Union

from argilla._exceptions import RecordsIngestionError, ArgillaError
from argilla._exceptions import RecordsIngestionError, ArgillaError, ImportDatasetError
from argilla._models import DatasetModel
from argilla.client import Argilla
from argilla.settings import Settings
Expand Down Expand Up @@ -82,11 +82,15 @@ def from_disk(

client = client or Argilla._get_default()

dataset_path, settings_path, records_path = cls._define_child_paths(path=path)
logging.info(f"Loading dataset from {dataset_path}")
logging.info(f"Loading settings from {settings_path}")
logging.info(f"Loading records from {records_path}")
dataset_model = cls._load_dataset_model(path=dataset_path)
try:
dataset_path, settings_path, records_path = cls._define_child_paths(path=path)
logging.info(f"Loading dataset from {dataset_path}")
logging.info(f"Loading settings from {settings_path}")
logging.info(f"Loading records from {records_path}")

dataset_model = cls._load_dataset_model(path=dataset_path)
except (NotADirectoryError, FileNotFoundError) as e:
raise ImportDatasetError(f"Error loading dataset from disk. {e}") from e

# Get the relevant workspace_id of the incoming dataset
if isinstance(workspace, str):
Expand All @@ -112,6 +116,9 @@ def from_disk(
dataset = cls.from_model(model=dataset_model, client=client)
else:
# Create a new dataset and load the settings and records
if not os.path.exists(settings_path):
raise ImportDatasetError(f"Settings file not found at {settings_path}")

dataset = cls.from_model(model=dataset_model, client=client)
dataset.settings = Settings.from_json(path=settings_path)
dataset.create()
Expand All @@ -121,8 +128,10 @@ def from_disk(
dataset.records.from_json(path=records_path)
except RecordsIngestionError as e:
raise RecordsIngestionError(
message="Error importing dataset records from disk. Records and datasets settings are not compatible."
message="Error importing dataset records from disk. "
"Records and datasets settings are not compatible."
) from e

return dataset

############################
Expand Down
8 changes: 6 additions & 2 deletions argilla/src/argilla/datasets/_io/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from datasets.data_files import EmptyDatasetError
from PIL import Image

from argilla._exceptions import ImportDatasetError
from argilla._exceptions._api import UnprocessableEntityError
from argilla._exceptions._records import RecordsIngestionError
from argilla._exceptions._settings import SettingsError
Expand Down Expand Up @@ -156,7 +157,7 @@ def from_hub(
dataset = cls.from_disk(
path=folder_path, workspace=workspace, name=name, client=client, with_records=with_records
)
except NotADirectoryError:
except ImportDatasetError:
from argilla import Settings

settings = Settings.from_hub(repo_id=repo_id)
Expand Down Expand Up @@ -192,8 +193,11 @@ def from_hub(
@staticmethod
def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
"""This method extracts the responses from a Hugging Face dataset and returns a list of `Record` objects"""
# THIS IS REQUIRED SINCE NAME IN ARGILLA ARE LOWERCASE. The Settings BaseModel models apply this logic
# Ideally, the restrictions in Argilla should be removed and the names should be case-insensitive
hf_dataset = hf_dataset.rename_columns({col: col.lower() for col in hf_dataset.column_names})

# Identify columns that colunms that contain responses
# Identify columns that contain responses
responses_columns = [col for col in hf_dataset.column_names if ".responses" in col]
response_questions = defaultdict(dict)
user_ids = {}
Expand Down
38 changes: 27 additions & 11 deletions argilla/src/argilla/records/_io/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import Dataset as HFDataset
from datasets import IterableDataset, Image, ClassLabel, Value
from datasets import Image, ClassLabel, Value

from argilla.records._io._generic import GenericIO
from argilla._helpers._media import pil_to_data_uri, uncast_image
Expand All @@ -32,7 +32,7 @@ def _cast_images_as_urls(hf_dataset: "HFDataset", columns: List[str]) -> "HFData
Parameters:
hf_dataset (HFDataset): The Hugging Face dataset to cast.
repo_id (str): The ID of the Hugging Face Hub repo.
columns (List[str]): The names of the columns containing the image features.
Returns:
HFDataset: The Hugging Face dataset with image features cast as URLs.
Expand Down Expand Up @@ -65,12 +65,23 @@ def _cast_classlabels_as_strings(hf_dataset: "HFDataset", columns: List[str]) ->
Returns:
HFDataset: The Hugging Face dataset with class label features cast as strings.
"""

def label_column2str(x: dict, column: str, features: dict) -> Dict[str, Union[str, None]]:
try:
value = features[column].int2str(x[column])
except Exception as ex:
warnings.warn(f"Could not cast {x[column]} to string. Error: {ex}")
value = None

return {column: value}

for column in columns:
features = hf_dataset.features.copy()
features[column] = Value("string")
hf_dataset = hf_dataset.map(
lambda x: {column: hf_dataset.features[column].int2str(x[column])}, features=features
label_column2str, fn_kwargs={"column": column, "features": hf_dataset.features}, features=features
)

return hf_dataset


Expand All @@ -91,6 +102,8 @@ def _uncast_uris_as_images(hf_dataset: "HFDataset", columns: List[str]) -> "HFDa
HFDataset: The Hugging Face dataset with image features cast as PIL images.
"""

casted_hf_dataset = hf_dataset

for column in columns:
features = hf_dataset.features.copy()
features[column] = Image()
Expand All @@ -109,6 +122,7 @@ def _uncast_uris_as_images(hf_dataset: "HFDataset", columns: List[str]) -> "HFDa
f"Image file not found for column {column}. Image will be persisted as string (URL, path, or base64)."
)
casted_hf_dataset = hf_dataset

return casted_hf_dataset


Expand Down Expand Up @@ -173,32 +187,34 @@ def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset:
return hf_dataset

@staticmethod
def _record_dicts_from_datasets(hf_dataset: HFDataset) -> List[Dict[str, Union[str, float, int, list]]]:
"""Creates a dictionaries from a HF dataset that can be passed to DatasetRecords.add or DatasetRecords.update.
def _record_dicts_from_datasets(hf_dataset: "HFDataset") -> List[Dict[str, Union[str, float, int, list]]]:
"""Creates a dictionaries from an HF dataset that can be passed to DatasetRecords.add or DatasetRecords.update.
Parameters:
dataset (Dataset): The dataset containing the records.
hf_dataset (HFDataset): The dataset containing the records.
Returns:
Generator[Dict[str, Union[str, float, int, list]], None, None]: A generator of dictionaries to be passed to DatasetRecords.add or DatasetRecords.update.
"""
hf_dataset = HFDatasetsIO.to_argilla(hf_dataset=hf_dataset)

try:
hf_dataset: IterableDataset = hf_dataset.to_iterable_dataset()
hf_dataset = hf_dataset.to_iterable_dataset()
except AttributeError:
pass

record_dicts = [example for example in hf_dataset]
return record_dicts

@staticmethod
def _uncast_argilla_attributes_to_datasets(hf_dataset: "HFDataset", schema: Dict) -> List[str]:
def _uncast_argilla_attributes_to_datasets(hf_dataset: "HFDataset", schema: Dict) -> "HFDataset":
"""Get the names of the Argilla fields that contain image data.
Parameters:
dataset (Dataset): The dataset to check.
hf_dataset (Dataset): The dataset to check.
Returns:
List[str]: The names of the Argilla fields that contain image data.
HFDataset: The dataset with argilla attributes uncasted.
"""

for attribute_type, uncaster in ATTRIBUTE_UNCASTERS.items():
Expand All @@ -211,7 +227,7 @@ def _uncast_argilla_attributes_to_datasets(hf_dataset: "HFDataset", schema: Dict
return hf_dataset

@staticmethod
def to_argilla(hf_dataset: "HFDataset") -> List[str]:
def to_argilla(hf_dataset: "HFDataset") -> "HFDataset":
"""Check if the Hugging Face dataset contains image features.
Parameters:
Expand Down
3 changes: 3 additions & 0 deletions argilla/src/argilla/settings/_io/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ def _define_settings_from_features(
feature_type = _map_feature_type(feature)
attribute_definition = _map_attribute_type(feature_mapping.get(name))

# Lowercase the name to avoid case sensitivity issues when mapping the features to the settings.
name = name.lower()

if feature_type == FeatureType.CHAT:
fields.append(ChatField(name=name, required=False))
elif feature_type == FeatureType.TEXT:
Expand Down
19 changes: 19 additions & 0 deletions argilla/tests/integration/test_import_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def token():
return os.getenv("HF_TOKEN_ARGILLA_INTERNAL_TESTING")


@pytest.mark.skipif(not os.getenv("HF_TOKEN_ARGILLA_INTERNAL_TESTING"), reason="No HF token provided")
class TestImportFeaturesFromHub:
def test_import_records_from_datasets_with_classlabel(
self, token: str, dataset: rg.Dataset, client, mock_data: List[dict[str, Any]]
Expand Down Expand Up @@ -111,3 +112,21 @@ def test_import_records_from_datasets_with_classlabel(

assert exported_dataset.features["label.suggestion"].names == ["positive", "negative"]
assert exported_dataset["label.suggestion"] == [0, 1, 0]

def test_import_from_hub_with_upper_case_columns(self, client: rg.Argilla, token: str, dataset_name: str):
created_dataset = rg.Dataset.from_hub(
"argilla-internal-testing/test_import_from_hub_with_upper_case_columns",
token=token,
name=dataset_name,
)

assert created_dataset.settings.fields[0].name == "text"
assert list(created_dataset.records)[0].fields["text"] == "Hello World, how are you?"

def test_import_from_hub_with_unlabelled_classes(self, client: rg.Argilla, token: str, dataset_name: str):
created_dataset = rg.Dataset.from_hub(
"argilla-internal-testing/test_import_from_hub_with_unlabelled_classes", token=token, name=dataset_name
)

assert created_dataset.settings.fields[0].name == "text"
assert list(created_dataset.records)[0].fields["text"] == "Hello World, how are you?"

0 comments on commit 46caafe

Please sign in to comment.