Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added colorscale to axes.plot() #3148

Merged
merged 19 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions manim/mobject/graphing/coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from manim.mobject.text.tex_mobject import MathTex
from manim.mobject.three_d.three_dimensions import Surface
from manim.mobject.types.vectorized_mobject import (
CurvesAsSubmobjects,
VDict,
VectorizedPoint,
VGroup,
Expand All @@ -47,6 +48,7 @@
ManimColor,
ParsableManimColor,
color_gradient,
interpolate_color,
invert_color,
)
from manim.utils.config_ops import merge_dicts_recursively, update_dict_recursively
Expand Down Expand Up @@ -619,6 +621,8 @@ def plot(
function: Callable[[float], float],
x_range: Sequence[float] | None = None,
use_vectorized: bool = False,
colorscale: Union[Iterable[Color], Iterable[Color, float]] | None = None,
colorscale_axis: int = 1,
**kwargs: Any,
) -> ParametricFunction:
"""Generates a curve based on a function.
Expand All @@ -632,6 +636,12 @@ def plot(
use_vectorized
Whether to pass in the generated t value array to the function. Only use this if your function supports it.
Output should be a numpy array of shape ``[y_0, y_1, ...]``
colorscale
Colors of the function. Optional parameter used when coloring a function by values. Passing a list of colors
and a colorscale_axis will color the function by y-value. Passing a list of tuples in the form ``(color, pivot)``
allows user-defined pivots where the color transitions.
colorscale_axis
Defines the axis on which the colorscale is applied (0 = x, 1 = y), default is y-axis (1).
kwargs
Additional parameters to be passed to :class:`~.ParametricFunction`.

Expand Down Expand Up @@ -710,7 +720,57 @@ def log_func(x):
use_vectorized=use_vectorized,
**kwargs,
)

graph.underlying_function = function

if colorscale:
if type(colorscale[0]) in (list, tuple):
new_colors, pivots = [
[i for i, j in colorscale],
[j for i, j in colorscale],
]
else:
new_colors = colorscale

ranges = [self.x_range, self.y_range]
pivot_min = ranges[colorscale_axis][0]
pivot_max = ranges[colorscale_axis][1]
pivot_frequency = (pivot_max - pivot_min) / (len(new_colors) - 1)
pivots = np.arange(
start=pivot_min,
stop=pivot_max + pivot_frequency,
step=pivot_frequency,
)

resolution = 0.01 if len(x_range) == 2 else x_range[2]
sample_points = np.arange(x_range[0], x_range[1] + resolution, resolution)
color_list = []
for samp_x in sample_points:
axis_value = (samp_x, function(samp_x))[colorscale_axis]
if axis_value <= pivots[0]:
color_list.append(new_colors[0])
elif axis_value >= pivots[-1]:
color_list.append(new_colors[-1])
else:
for i, pivot in enumerate(pivots):
if pivot > axis_value:
color_index = (axis_value - pivots[i - 1]) / (
pivots[i] - pivots[i - 1]
)
color_index = min(color_index, 1)
mob_color = interpolate_color(
new_colors[i - 1],
new_colors[i],
color_index,
)
color_list.append(mob_color)
break
if config.renderer == RendererType.OPENGL:
graph.set_color(color_list)
else:
graph.set_stroke(color_list)
graph.set_sheen_direction(RIGHT)

return graph

def plot_implicit_curve(
Expand Down
Binary file not shown.
Binary file not shown.
34 changes: 34 additions & 0 deletions tests/opengl/test_coordinate_system_opengl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from manim import LEFT, ORIGIN, PI, UR, Axes, Circle, ComplexPlane
from manim import CoordinateSystem as CS
from manim import NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig
from manim.utils.color import BLUE, GREEN, ORANGE, RED, YELLOW
from manim.utils.testing.frames_comparison import frames_comparison

__module_test__ = "coordinate_system_opengl"


def test_initial_config(using_opengl_renderer):
Expand Down Expand Up @@ -126,3 +130,33 @@ def test_input_to_graph_point(using_opengl_renderer):
# test the line_graph implementation
position = np.around(ax.input_to_graph_point(x=PI, graph=line_graph), decimals=4)
np.testing.assert_array_equal(position, (2.6928, 1.2876, 0))


@frames_comparison
def test_gradient_line_graph_x_axis(scene, using_opengl_renderer):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=0,
)

scene.add(axes, curve)


@frames_comparison
def test_gradient_line_graph_y_axis(scene, using_opengl_renderer):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=1,
)

scene.add(axes, curve)
Binary file not shown.
Binary file not shown.
30 changes: 30 additions & 0 deletions tests/test_graphical_units/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,33 @@ def test_number_plane_log(scene):
)

scene.add(VGroup(plane1, plane2).arrange())


@frames_comparison
def test_gradient_line_graph_x_axis(scene):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=0,
)

scene.add(axes, curve)


@frames_comparison
def test_gradient_line_graph_y_axis(scene):
"""Test that using `colorscale` generates a line whose gradient matches the y-axis"""
axes = Axes(x_range=[-3, 3], y_range=[-3, 3])

curve = axes.plot(
lambda x: 0.1 * x**3,
x_range=(-3, 3, 0.001),
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
colorscale_axis=1,
)

scene.add(axes, curve)
Loading