Commit
·
478d418
1
Parent(s):
b7e9713
Added new dataset
Browse files- scikit-learn/convert2onnx.py +11 -2
scikit-learn/convert2onnx.py
CHANGED
|
@@ -4,18 +4,27 @@
|
|
| 4 |
import argparse
|
| 5 |
import joblib
|
| 6 |
|
| 7 |
-
from sklearn.datasets import fetch_california_housing, load_diabetes, load_iris
|
| 8 |
from skl2onnx import convert_sklearn
|
| 9 |
from skl2onnx.common.data_types import FloatTensorType
|
| 10 |
|
| 11 |
|
| 12 |
def load_dataset(dataset_name):
|
| 13 |
if dataset_name == 'california':
|
|
|
|
| 14 |
dataset = fetch_california_housing()
|
| 15 |
elif dataset_name == 'diabetes':
|
|
|
|
| 16 |
dataset = load_diabetes()
|
| 17 |
elif dataset_name == 'iris':
|
|
|
|
| 18 |
dataset = load_iris()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
else:
|
| 20 |
raise ValueError("Invalid dataset name")
|
| 21 |
return dataset.data, dataset.target
|
|
@@ -53,7 +62,7 @@ python convert2onnx.py california adaboost_regressor.joblib adaboost_regressor.o
|
|
| 53 |
"""
|
| 54 |
if __name__ == "__main__":
|
| 55 |
parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
|
| 56 |
-
parser.add_argument('dataset_name', type=str, help='Name of the dataset. Choose from: "california", "diabetes", or "
|
| 57 |
parser.add_argument('model_path', type=str, help='Path to the trained model file.')
|
| 58 |
parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
|
| 59 |
args = parser.parse_args()
|
|
|
|
| 4 |
import argparse
|
| 5 |
import joblib
|
| 6 |
|
|
|
|
| 7 |
from skl2onnx import convert_sklearn
|
| 8 |
from skl2onnx.common.data_types import FloatTensorType
|
| 9 |
|
| 10 |
|
| 11 |
def load_dataset(dataset_name):
|
| 12 |
if dataset_name == 'california':
|
| 13 |
+
from sklearn.datasets import fetch_california_housing
|
| 14 |
dataset = fetch_california_housing()
|
| 15 |
elif dataset_name == 'diabetes':
|
| 16 |
+
from sklearn.datasets import load_diabetes
|
| 17 |
dataset = load_diabetes()
|
| 18 |
elif dataset_name == 'iris':
|
| 19 |
+
from sklearn.datasets import load_iris
|
| 20 |
dataset = load_iris()
|
| 21 |
+
elif dataset_name == "cardiotocography":
|
| 22 |
+
from sklearn.datasets import fetch_openml
|
| 23 |
+
dataset = fetch_openml(name=dataset_name, version=1, as_frame=False)
|
| 24 |
+
X, y = dataset.data, dataset.target
|
| 25 |
+
s = y == "3"
|
| 26 |
+
y = s.astype(int)
|
| 27 |
+
return X, y
|
| 28 |
else:
|
| 29 |
raise ValueError("Invalid dataset name")
|
| 30 |
return dataset.data, dataset.target
|
|
|
|
| 62 |
"""
|
| 63 |
if __name__ == "__main__":
|
| 64 |
parser = argparse.ArgumentParser(description='Converts a sklearn model to ONNX format.')
|
| 65 |
+
parser.add_argument('dataset_name', type=str, help='Name of the dataset. Choose from: "california", "diabetes", "iris" or "cardiotocography".')
|
| 66 |
parser.add_argument('model_path', type=str, help='Path to the trained model file.')
|
| 67 |
parser.add_argument('onnx_filename', type=str, help='The filename for the output ONNX file.')
|
| 68 |
args = parser.parse_args()
|