Skip to content

Commit

Permalink
Fix/restrict inputs types for client models (#2)
Browse files Browse the repository at this point in the history
* restrict input types to str or list of str

* include pydantic in requirements

* update doc string

* Reflect model change on server side

* add simple test when None is in inputs

* remove pydantic from server requirements (already included in general requirements)

* remove test for nested inputs, for now we do not support these
  • Loading branch information
David Fidalgo authored May 3, 2021
1 parent 5adc113 commit 2f71bd9
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 16 deletions.
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ install_requires =
requests ~= 2.25.0 # TODO We should use httpx instead to avoid a new dependency
# Client
pandas >=1.0.0,<2.0.0 # For data loading

pydantic ~= 1.8.1

[options.packages.find]
where = src
Expand All @@ -36,7 +36,6 @@ where = src
server =
# Basic dependencies
fastapi ~= 0.63.0
pydantic ~= 1.7.0
uvicorn[standard] ~= 0.13.4
elasticsearch >= 7.1.0,<8.0.0
smart-open
Expand Down
4 changes: 2 additions & 2 deletions src/rubrix/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class TextClassificationRecord(BaseModel):
Attributes
----------
inputs : `Dict[str, Any]`
inputs : `Dict[str, Union[str, List[str]]]`
The inputs of the record
prediction : `List[Tuple[str, float]]`, optional
A list of tuples containing the predictions for the record. The first entry of the tuple is the predicted label,
Expand Down Expand Up @@ -80,7 +80,7 @@ class TextClassificationRecord(BaseModel):
The timestamp of the record. Default: None
"""

inputs: Dict[str, Any]
inputs: Dict[str, Union[str, List[str]]]

prediction: Optional[List[Tuple[str, float]]] = None
annotation: Optional[Union[str, List[str]]] = None
Expand Down
4 changes: 2 additions & 2 deletions src/rubrix/server/tasks/text_classification/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class CreationTextClassificationRecord(BaseRecord[TextClassificationAnnotation])
Attributes:
-----------
inputs: Dict[str, Any]
inputs: Dict[str, Union[str, List[str]]]
The input data text
multi_label: bool
Expand All @@ -93,7 +93,7 @@ class CreationTextClassificationRecord(BaseRecord[TextClassificationAnnotation])
The dictionary key must be aligned with provided record text. Optional
"""

inputs: Dict[str, Any]
inputs: Dict[str, Union[str, List[str]]]
multi_label: bool = False
explanation: Dict[str, List[TokenAttributions]] = None

Expand Down
7 changes: 7 additions & 0 deletions tests/client/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from pydantic import ValidationError

from rubrix.client.models import TextClassificationRecord
from rubrix.client.models import TokenClassificationRecord
Expand Down Expand Up @@ -36,3 +37,9 @@ def test_token_classification_record(annotation, status, expected_status):
text="test text", tokens=["test", "text"], annotation=annotation, status=status
)
assert record.status == expected_status


def test_text_classification_record_none_inputs():
"""Test validation error for None in inputs"""
with pytest.raises(ValidationError):
TextClassificationRecord(inputs={"text": None})
10 changes: 0 additions & 10 deletions tests/server/text_classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,6 @@
)


def test_flatten_inputs():
data = {
"inputs": {
"mail": {"subject": "The mail subject", "body": "This is a large text body"}
}
}
record = TextClassificationRecord.parse_obj(data)
assert list(record.inputs.keys()) == ["mail.subject", "mail.body"]


def test_flatten_metadata():
data = {
"inputs": {"text": "bogh"},
Expand Down

0 comments on commit 2f71bd9

Please sign in to comment.