Skip to content

Commit b5c9524

Browse files
htahir1 and schustmi authored
Update scikit-learn requirement in SklearnIntegration (#3551)
* Update scikit-learn requirement in SklearnIntegration
* Fix function annotations in SklearnIntegration class and dynamic_importer step
* Update return type annotation for get_data_from_api function
* Try installing scikit image in tests
* Correctly sample to 8x8
* Remove dev installation

---------

Co-authored-by: Michael Schuster <michael.schuster.ffb@googlemail.com>
1 parent 31b36c2 commit b5c9524

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

src/zenml/integrations/sklearn/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ class SklearnIntegration(Integration):
2121
"""Definition of sklearn integration for ZenML."""
2222

2323
NAME = SKLEARN
24-
REQUIREMENTS = ["scikit-learn", "scikit-image"]
24+
REQUIREMENTS = ["scikit-learn"]
2525

2626
@classmethod
2727
def activate(cls) -> None:
2828
"""Activates the integration."""
2929
from zenml.integrations.sklearn import materializers # noqa
3030

31-
SklearnIntegration.check_installation()
3231

32+
SklearnIntegration.check_installation()

tests/integration/examples/mlflow/steps/dynamic_importer_step.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""This step downloads the latest data from a mock API and returns it as a numpy array."""
2+
13
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
24
#
35
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,13 +16,17 @@
1416
import numpy as np # type: ignore [import]
1517
import pandas as pd # type: ignore [import]
1618
import requests # type: ignore [import]
17-
from skimage.transform import resize
1819
from typing_extensions import Annotated
1920

2021
from zenml import step
2122

2223

23-
def get_data_from_api():
24+
def get_data_from_api() -> Annotated[np.ndarray, "api_data"]:
25+
"""Downloads the latest data from a mock API.
26+
27+
Returns:
28+
Annotated[np.ndarray, "data"]: Downsampled image data as a numpy array.
29+
"""
2430
url = (
2531
"https://storage.googleapis.com/zenml-public-bucket/mnist"
2632
"/mnist_handwritten_test.json"
@@ -30,12 +36,10 @@ def get_data_from_api():
3036
data = df["image"].map(lambda x: np.array(x)).values
3137
data = np.array(
3238
[
33-
resize(
34-
x.reshape(28, 28).astype("uint8"),
35-
(8, 8),
36-
anti_aliasing=False,
37-
preserve_range=True,
38-
)
39+
# Pad the image to 32x32 to enable downsampling to 8x8
40+
np.pad(x.reshape(28, 28).astype("float64"), 2)[
41+
::4, ::4
42+
] # Downsample to 8x8 by taking every 4th pixel
3943
for x in data
4044
]
4145
)
@@ -44,6 +48,10 @@ def get_data_from_api():
4448

4549
@step(enable_cache=False)
4650
def dynamic_importer() -> Annotated[np.ndarray, "data"]:
47-
"""Downloads the latest data from a mock API."""
51+
"""Downloads the latest data from a mock API.
52+
53+
Returns:
54+
Annotated[np.ndarray, "data"]: Downsampled image data as a numpy array.
55+
"""
4856
data = get_data_from_api()
4957
return data

0 commit comments

Comments
 (0)