Spaces:
Build error
Build error
Commit
·
e525bd5
1
Parent(s):
5ff9afc
fix: update example
Browse files
app.py
CHANGED
|
@@ -50,7 +50,7 @@ def load_examples():
|
|
| 50 |
|
| 51 |
|
| 52 |
# Create Gradio examples
|
| 53 |
-
examples = load_examples()
|
| 54 |
|
| 55 |
|
| 56 |
def process_fields(fields):
|
|
@@ -112,34 +112,44 @@ from gradio_client import Client
|
|
| 112 |
import argilla as rg
|
| 113 |
|
| 114 |
# Initialize Argilla client
|
| 115 |
-
|
|
|
|
| 116 |
api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
|
| 117 |
)
|
| 118 |
|
| 119 |
# Load the dataset
|
| 120 |
-
dataset =
|
| 121 |
-
|
| 122 |
-
#
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
|
|
|
| 126 |
payload = {
|
| 127 |
-
"records": [
|
| 128 |
-
"fields": [
|
| 129 |
-
"question":
|
|
|
|
|
|
|
| 130 |
}
|
| 131 |
|
| 132 |
-
|
| 133 |
-
client = Client("davidberenstein1957/distilabel-argilla-labeller")
|
| 134 |
-
|
| 135 |
-
result = client.predict(
|
| 136 |
-
records=json.dumps(payload["records"]),
|
| 137 |
-
example_records=json.dumps(payload["example_records"]),
|
| 138 |
-
fields=json.dumps(payload["fields"]),
|
| 139 |
-
question=json.dumps(payload["question"]),
|
| 140 |
-
api_name="/predict"
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
```
|
| 144 |
"""
|
| 145 |
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
# Create Gradio examples
|
| 53 |
+
examples = load_examples()[:1]
|
| 54 |
|
| 55 |
|
| 56 |
def process_fields(fields):
|
|
|
|
| 112 |
import argilla as rg
|
| 113 |
|
| 114 |
# Initialize Argilla client
|
| 115 |
+
gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")
|
| 116 |
+
argilla_client = rg.Argilla(
|
| 117 |
api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
|
| 118 |
)
|
| 119 |
|
| 120 |
# Load the dataset
|
| 121 |
+
dataset = argilla_client.datasets(name="my_dataset", workspace="my_workspace")
|
| 122 |
+
|
| 123 |
+
# Get the field and question
|
| 124 |
+
field = dataset.settings.fields["text"]
|
| 125 |
+
question = dataset.settings.questions["sentiment"]
|
| 126 |
+
|
| 127 |
+
# Get completed and pending records
|
| 128 |
+
completed_records_filter = rg.Filter(("status", "==", "completed"))
|
| 129 |
+
pending_records_filter = rg.Filter(("status", "==", "pending"))
|
| 130 |
+
example_records = list(
|
| 131 |
+
dataset.records(
|
| 132 |
+
query=rg.Query(filter=completed_records_filter),
|
| 133 |
+
limit=5,
|
| 134 |
+
)
|
| 135 |
+
)
|
| 136 |
+
some_pending_records = list(
|
| 137 |
+
dataset.records(
|
| 138 |
+
query=rg.Query(filter=pending_records_filter),
|
| 139 |
+
limit=5,
|
| 140 |
+
)
|
| 141 |
+
)
|
| 142 |
|
| 143 |
+
# Process the records
|
| 144 |
payload = {
|
| 145 |
+
"records": [record.to_dict() for record in some_pending_records],
|
| 146 |
+
"fields": [field.serialize()],
|
| 147 |
+
"question": question.serialize(),
|
| 148 |
+
"example_records": [record.to_dict() for record in example_records],
|
| 149 |
+
"api_name": "/predict",
|
| 150 |
}
|
| 151 |
|
| 152 |
+
response = gradio_client.predict(**payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
```
|
| 154 |
"""
|
| 155 |
|