Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFIX] argilla: Fix some from_hub method errors #5523

Merged
merged 7 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions argilla/src/argilla/_exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from argilla._exceptions._serialization import * # noqa: F403
from argilla._exceptions._settings import * # noqa: F403
from argilla._exceptions._records import * # noqa: F403
from argilla._exceptions._hub import * # noqa: F403
16 changes: 14 additions & 2 deletions argilla/src/argilla/_exceptions/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argilla._exceptions import ArgillaError

__all__ = [
"ImportDatasetError",
"DatasetsServerException",
]

class DatasetsServerException(Exception):
message: str = "Error connecting to Hugging Face Hub datasets-server API"

class ImportDatasetError(ArgillaError):
def __init__(self, message: str = "Error importing dataset") -> None:
super().__init__(message)


class DatasetsServerException(ArgillaError):
def __init__(self, message: str = "Error connecting to Hugging Face Hub datasets-server API") -> None:
super().__init__(message)
23 changes: 16 additions & 7 deletions argilla/src/argilla/datasets/_io/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple, Type, Union

from argilla._exceptions import RecordsIngestionError, ArgillaError
from argilla._exceptions import RecordsIngestionError, ArgillaError, ImportDatasetError
from argilla._models import DatasetModel
from argilla.client import Argilla
from argilla.settings import Settings
Expand Down Expand Up @@ -82,11 +82,15 @@ def from_disk(

client = client or Argilla._get_default()

dataset_path, settings_path, records_path = cls._define_child_paths(path=path)
logging.info(f"Loading dataset from {dataset_path}")
logging.info(f"Loading settings from {settings_path}")
logging.info(f"Loading records from {records_path}")
dataset_model = cls._load_dataset_model(path=dataset_path)
try:
dataset_path, settings_path, records_path = cls._define_child_paths(path=path)
logging.info(f"Loading dataset from {dataset_path}")
logging.info(f"Loading settings from {settings_path}")
logging.info(f"Loading records from {records_path}")

dataset_model = cls._load_dataset_model(path=dataset_path)
except (NotADirectoryError, FileNotFoundError) as e:
raise ImportDatasetError(f"Error loading dataset from disk. {e}") from e

# Get the relevant workspace_id of the incoming dataset
if isinstance(workspace, str):
Expand All @@ -112,6 +116,9 @@ def from_disk(
dataset = cls.from_model(model=dataset_model, client=client)
else:
# Create a new dataset and load the settings and records
if not os.path.exists(settings_path):
raise ImportDatasetError(f"Settings file not found at {settings_path}")

dataset = cls.from_model(model=dataset_model, client=client)
dataset.settings = Settings.from_json(path=settings_path)
dataset.create()
Expand All @@ -121,8 +128,10 @@ def from_disk(
dataset.records.from_json(path=records_path)
except RecordsIngestionError as e:
raise RecordsIngestionError(
message="Error importing dataset records from disk. Records and datasets settings are not compatible."
message="Error importing dataset records from disk. "
"Records and datasets settings are not compatible."
) from e

return dataset

############################
Expand Down
8 changes: 6 additions & 2 deletions argilla/src/argilla/datasets/_io/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from datasets.data_files import EmptyDatasetError
from PIL import Image

from argilla._exceptions import ImportDatasetError
from argilla._exceptions._api import UnprocessableEntityError
from argilla._exceptions._records import RecordsIngestionError
from argilla._exceptions._settings import SettingsError
Expand Down Expand Up @@ -156,7 +157,7 @@ def from_hub(
dataset = cls.from_disk(
path=folder_path, workspace=workspace, name=name, client=client, with_records=with_records
)
except NotADirectoryError:
except ImportDatasetError:
from argilla import Settings

settings = Settings.from_hub(repo_id=repo_id)
Expand Down Expand Up @@ -192,8 +193,11 @@ def from_hub(
@staticmethod
def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
"""This method extracts the responses from a Hugging Face dataset and returns a list of `Record` objects"""
# THIS IS REQUIRED SINCE NAMES IN ARGILLA ARE LOWERCASE. The Settings BaseModel models apply this logic
# Ideally, the restrictions in Argilla should be removed and the names should be case-insensitive
hf_dataset = hf_dataset.rename_columns({col: col.lower() for col in hf_dataset.column_names})
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we apply the same logic as for the settings names. I only included the change here, but we could have problems importing case-insensitive records.

But maybe the solution would be relaxing this constraint on the server backend cc @jfcalvo . Anyway, I think we currently have a risk applying the name.lower() for all the settings models and we should change it.


# Identify columns that colunms that contain responses
# Identify columns that contain responses
responses_columns = [col for col in hf_dataset.column_names if ".responses" in col]
response_questions = defaultdict(dict)
user_ids = {}
Expand Down
38 changes: 27 additions & 11 deletions argilla/src/argilla/records/_io/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import Dataset as HFDataset
from datasets import IterableDataset, Image, ClassLabel, Value
from datasets import Image, ClassLabel, Value

from argilla.records._io._generic import GenericIO
from argilla._helpers._media import pil_to_data_uri, uncast_image
Expand All @@ -32,7 +32,7 @@ def _cast_images_as_urls(hf_dataset: "HFDataset", columns: List[str]) -> "HFData

Parameters:
hf_dataset (HFDataset): The Hugging Face dataset to cast.
repo_id (str): The ID of the Hugging Face Hub repo.
columns (List[str]): The names of the columns containing the image features.

Returns:
HFDataset: The Hugging Face dataset with image features cast as URLs.
Expand Down Expand Up @@ -65,12 +65,23 @@ def _cast_classlabels_as_strings(hf_dataset: "HFDataset", columns: List[str]) ->
Returns:
HFDataset: The Hugging Face dataset with class label features cast as strings.
"""

def label_column2str(x: dict, column: str, features: dict) -> Dict[str, Union[str, None]]:
try:
value = features[column].int2str(x[column])
except Exception as ex:
warnings.warn(f"Could not cast {x[column]} to string. Error: {ex}")
value = None

return {column: value}

for column in columns:
features = hf_dataset.features.copy()
features[column] = Value("string")
hf_dataset = hf_dataset.map(
lambda x: {column: hf_dataset.features[column].int2str(x[column])}, features=features
label_column2str, fn_kwargs={"column": column, "features": hf_dataset.features}, features=features
)

return hf_dataset


Expand All @@ -91,6 +102,8 @@ def _uncast_uris_as_images(hf_dataset: "HFDataset", columns: List[str]) -> "HFDa
HFDataset: The Hugging Face dataset with image features cast as PIL images.
"""

casted_hf_dataset = hf_dataset

for column in columns:
features = hf_dataset.features.copy()
features[column] = Image()
Expand All @@ -109,6 +122,7 @@ def _uncast_uris_as_images(hf_dataset: "HFDataset", columns: List[str]) -> "HFDa
f"Image file not found for column {column}. Image will be persisted as string (URL, path, or base64)."
)
casted_hf_dataset = hf_dataset

return casted_hf_dataset


Expand Down Expand Up @@ -173,32 +187,34 @@ def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset:
return hf_dataset

@staticmethod
def _record_dicts_from_datasets(hf_dataset: HFDataset) -> List[Dict[str, Union[str, float, int, list]]]:
"""Creates a dictionaries from a HF dataset that can be passed to DatasetRecords.add or DatasetRecords.update.
def _record_dicts_from_datasets(hf_dataset: "HFDataset") -> List[Dict[str, Union[str, float, int, list]]]:
"""Creates dictionaries from an HF dataset that can be passed to DatasetRecords.add or DatasetRecords.update.

Parameters:
dataset (Dataset): The dataset containing the records.
hf_dataset (HFDataset): The dataset containing the records.

Returns:
Generator[Dict[str, Union[str, float, int, list]], None, None]: A generator of dictionaries to be passed to DatasetRecords.add or DatasetRecords.update.
"""
hf_dataset = HFDatasetsIO.to_argilla(hf_dataset=hf_dataset)

try:
hf_dataset: IterableDataset = hf_dataset.to_iterable_dataset()
hf_dataset = hf_dataset.to_iterable_dataset()
except AttributeError:
pass

record_dicts = [example for example in hf_dataset]
return record_dicts

@staticmethod
def _uncast_argilla_attributes_to_datasets(hf_dataset: "HFDataset", schema: Dict) -> List[str]:
def _uncast_argilla_attributes_to_datasets(hf_dataset: "HFDataset", schema: Dict) -> "HFDataset":
"""Get the names of the Argilla fields that contain image data.

Parameters:
dataset (Dataset): The dataset to check.
hf_dataset (Dataset): The dataset to check.

Returns:
List[str]: The names of the Argilla fields that contain image data.
HFDataset: The dataset with argilla attributes uncasted.
"""

for attribute_type, uncaster in ATTRIBUTE_UNCASTERS.items():
Expand All @@ -211,7 +227,7 @@ def _uncast_argilla_attributes_to_datasets(hf_dataset: "HFDataset", schema: Dict
return hf_dataset

@staticmethod
def to_argilla(hf_dataset: "HFDataset") -> List[str]:
def to_argilla(hf_dataset: "HFDataset") -> "HFDataset":
"""Check if the Hugging Face dataset contains image features.

Parameters:
Expand Down
1 change: 1 addition & 0 deletions argilla/tests/integration/test_import_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def token():
return os.getenv("HF_TOKEN_ARGILLA_INTERNAL_TESTING")


@pytest.mark.skipif(not os.getenv("HF_TOKEN_ARGILLA_INTERNAL_TESTING"), reason="No HF token provided")
class TestImportFeaturesFromHub:
def test_import_records_from_datasets_with_classlabel(
self, token: str, dataset: rg.Dataset, client, mock_data: List[dict[str, Any]]
Expand Down