Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added plot_gp_slice.py #499

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

added plot_gp_slice.py #499

wants to merge 6 commits into from

Conversation

Jimbo994
Copy link
Collaborator

@Jimbo994 Jimbo994 commented Jan 10, 2025

Added a plot function to plot the "slice" of a GP prediction given some fixed input features, two input features that are varied between the bounds of the domain and the output feature of interest.

To run some quick tests on the ValueErrors:

x1 = ContinuousInput(key="x1", bounds=(0, 1))
x2 = ContinuousInput(key="x2", bounds=(0, 1))
x3 = ContinuousInput(key="x3", bounds=(0, 1))
x4 = ContinuousInput(key="x4", bounds=(0, 1))
x5 = ContinuousInput(key="x5", bounds=(0, 1))

y = ContinuousOutput(key="y")
y2 = ContinuousOutput(key="y2")


# Define the domain
inputs = Inputs(features=[x1, x2, x3, x4, x5])
outputs = Outputs(features=[y])
domain = Domain(inputs=inputs, outputs=outputs)

# generate some data
# Generate some synthetic data
data = pd.DataFrame({
    "x1": np.random.rand(100),
    "x2": np.random.rand(100),
    "x3": np.random.rand(100),
    "x4": np.random.rand(100),
    "x5": np.random.rand(100),
})

data2 = pd.DataFrame({
    "x1": np.random.rand(100),
    "x2": np.random.rand(100),
    "x3": np.random.rand(100),
    "x4": np.random.rand(100),
    "x5": np.random.rand(100),
})

# add a datapoint in the slice to test the plotting
new_data = pd.DataFrame([{
    "x1": 0.5,
    "x2": 0.5,
    "x3": 0.5,
    "x4": 0.5,
    "x5": 0.5,
}])

data = pd.concat([data, new_data], ignore_index=True)

data["y"] = data["x1"] + data["x2"] + np.random.normal(0, 0.1, 101)
data["valid_y"] = 1

# Define the surrogate model
surrogate_data = SingleTaskGPSurrogate(inputs=domain.inputs.get_by_keys(domain.inputs.get_keys()), outputs=domain.outputs)
surrogate_gp = surrogates.map(surrogate_data)
surrogate_gp.fit(data)

input_features = [x1, x2]
fixed_input_features = [x3, x4, x5]
fixed_values = [0.5, 0.5, 0.5]
fig, fig_sd = plot_gp_slice_plotly(surrogate_gp, fixed_input_features, fixed_values, input_features, y, observed_data=data)

assert isinstance(fig, go.Figure)
assert isinstance(fig_sd, go.Figure)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants