Skip to content

Commit

Permalink
feat(sdk): add method to use sdk charts (#217)
Browse files Browse the repository at this point in the history
* feat(sdk) chart in WIP

* feat(sdk): add binary dataQuaility charts

* feat(sdk): confusion_matrix(WIP)

* feat(sdk): add methods to show chart for every section

* feat(sdk): fix rounded values

* feat(sdk): ruff fixies

* feat(sdk): fix y_axis_label

* feat(sdk): remove ruff's check for multiclass methods

* feat(sdk): ruff format

* feat(sdk): fix poetry extras install and fix round on regressionlinear charts

* feat(sdk): add default case

* feat(sdk): ruff fix

* feat(sdk): check chart notebook
  • Loading branch information
dvalleri authored Dec 18, 2024
1 parent 2c713d5 commit baadd9e
Show file tree
Hide file tree
Showing 29 changed files with 1,797 additions and 1,033 deletions.
1 change: 1 addition & 0 deletions sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ requests = "^2.31.0"
pydantic = "^2.7.1"
boto3 = "^1.34.111"
pandas = "^2.2.2"
ipecharts = { version="1.0.8", optional = true }

[tool.poetry.group.dev.dependencies]
responses = "^0.25.0"
Expand Down
8 changes: 3 additions & 5 deletions sdk/radicalbit_platform_sdk/charts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from .chart_data import NumericalBarChartData, ConfusionMatrixChartData
from .chart import Chart
from .radicalbit_sdk_chart import RadicalbitChart,RbitChartData

__all__ = [
'ConfusionMatrixChartData',
'NumericalBarChartData',
'Chart'
'RadicalbitChart',
'RbitChartData'
]
Original file line number Diff line number Diff line change
@@ -1,8 +0,0 @@
from .binary_chart import BinaryChart
from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData

__all__ = [
'BinaryChart',
'BinaryDistributionChartData',
'BinaryLinearChartData'
]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ipecharts import EChartsRawWidget

from ..utils import get_chart_header
from .binary_chart_data import BinaryDistributionChartData, BinaryLinearChartData
from ..common.utils import get_chart_header
from .binary_chart_data import BinaryDistributionChartData


class BinaryChart:
Expand All @@ -10,16 +10,35 @@ def __init__(self) -> None:

def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWidget:
assert len(data.reference_data) <= 2
assert len(data.y_axis_label) <= 2

if data.current_data:
assert len(data.current_data) <= 2

reference_json_data = [
binary_data.model_dump() for binary_data in data.reference_data
y_axis_label = [
metric['name']
for metric in [
binary_data.model_dump() for binary_data in data.reference_data
]
]
current_data_json = (
[binary_data.model_dump() for binary_data in data.current_data]

reference_data = [
{
'percentage': metric.percentage,
'value': metric.count,
'count': metric.count,
}
for metric in data.reference_data
]

current_data = (
[
{
'percentage': metric.percentage,
'value': metric.count,
'count': metric.count,
}
for metric in data.current_data
]
if data.current_data
else []
)
Expand All @@ -28,7 +47,7 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
'title': data.title,
'type': 'bar',
'itemStyle': {'color': '#9B99A1'},
'data': reference_json_data,
'data': reference_data,
'color': '#9B99A1',
'name': 'Reference',
'label': {
Expand All @@ -43,7 +62,7 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
'title': data.title + '_current',
'type': 'bar',
'itemStyle': {},
'data': current_data_json,
'data': current_data,
'color': '#3695d9',
'name': 'Current',
'label': {
Expand Down Expand Up @@ -79,7 +98,7 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {'fontSize': 12, 'color': '#9B99A1'},
'data': data.y_axis_label,
'data': y_axis_label,
},
'emphasis': {'disabled': True},
'barCategoryGap': '21%',
Expand All @@ -91,64 +110,3 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
option.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=option)

def linear_chart(self, data: BinaryLinearChartData) -> EChartsRawWidget:
reference_series_data = {
'name': 'Reference',
'type': 'line',
'lineStyle': {'width': 2.2, 'color': '#9B99A1', 'type': 'dotted'},
'symbol': 'none',
'data': data.reference_data,
'itemStyle': {'color': '#9B99A1'},
'endLabel': {'show': True, 'color': '#9B99A1'},
'color': '#9B99A1',
}

current_series_data = {
'name': data.title,
'type': 'line',
'lineStyle': {'width': 2.2, 'color': '#73B2E0'},
'symbol': 'none',
'data': data.current_data,
'itemStyle': {'color': '#73B2E0'},
}

series = [reference_series_data, current_series_data]

options = {
'tooltip': {
'trigger': 'axis',
'crosshairs': True,
'axisPointer': {'type': 'cross', 'label': {'show': True}},
},
'yAxis': {
'type': 'value',
'axisLabel': {'fontSize': 9, 'color': '#9b99a1'},
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}},
'scale': True,
},
'xAxis': {
'type': 'time',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {'fontSize': 12, 'color': '#9b99a1'},
'scale': True,
},
'grid': {
'bottom': 0,
'top': 32,
'left': 0,
'right': 64,
'containLabel': True,
},
'series': series,
'legend': {
'show': True,
'textStyle': {'color': '#9B99A1'},
},
}

options.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=options)
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,10 @@

from pydantic import BaseModel


class BinaryDistributionData(BaseModel):
percentage: float
count: float
value: float
from radicalbit_platform_sdk.models import ClassMetrics


class BinaryDistributionChartData(BaseModel):
title: str
y_axis_label: List[str]
reference_data: List[BinaryDistributionData]
current_data: Optional[List[BinaryDistributionData]] = None


class BinaryLinearChartData(BaseModel):
title: str
reference_data: List[List[str]]
current_data: List[List[str]]
reference_data: List[ClassMetrics]
current_data: Optional[List[ClassMetrics]] = None
Empty file.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import reduce

from ipecharts import EChartsRawWidget
import numpy as np

from .chart_data import ConfusionMatrixChartData, NumericalBarChartData
from .chart_data import ConfusionMatrixChartData, LinearChartData, NumericalBarChartData
from .utils import get_chart_header, get_formatted_bucket_data


Expand Down Expand Up @@ -74,20 +76,24 @@ def numerical_bar_chart(self, data: NumericalBarChartData) -> EChartsRawWidget:
def confusion_matrix_chart(
self, data: ConfusionMatrixChartData
) -> EChartsRawWidget:
assert len(data.matrix) == len(data.axis_label) * len(
data.axis_label
), 'axis_label count and matrix item count are not compatibile'

np_matrix = np.matrix(data.matrix)

matrix_data = reduce(
lambda x, y: x + y,
[
[[xIdx, yIdx, value] for xIdx, value in enumerate(datas)]
for yIdx, datas in enumerate(reversed(data.matrix))
],
)

options = {
'yAxis': {
'type': 'category',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {'fontSize': 12, 'color': '#9B99A1'},
'data': data.axis_label,
'data': reversed(data.axis_label),
'name': 'Actual',
'nameGap': 25,
'nameLocation': 'middle',
Expand All @@ -103,7 +109,7 @@ def confusion_matrix_chart(
'color': '#9b99a1',
'rotate': 45,
},
'data': data.axis_label.reverse(),
'data': data.axis_label,
'name': 'Predicted',
'nameGap': 25,
'nameLocation': 'middle',
Expand All @@ -124,8 +130,69 @@ def confusion_matrix_chart(
'name': '',
'type': 'heatmap',
'label': {'show': True},
'data': data.matrix,
'data': matrix_data,
},
}

return EChartsRawWidget(option=options)

def linear_chart(self, data: LinearChartData) -> EChartsRawWidget:
reference_series_data = {
'name': 'Reference',
'type': 'line',
'lineStyle': {'width': 2.2, 'color': '#9B99A1', 'type': 'dotted'},
'symbol': 'none',
'data': data.reference_data,
'itemStyle': {'color': '#9B99A1'},
'endLabel': {'show': True, 'color': '#9B99A1'},
'color': '#9B99A1',
}

current_series_data = {
'name': data.title,
'type': 'line',
'lineStyle': {'width': 2.2, 'color': '#73B2E0'},
'symbol': 'none',
'data': data.current_data,
'itemStyle': {'color': '#73B2E0'},
}

series = [reference_series_data, current_series_data]

options = {
'tooltip': {
'trigger': 'axis',
'crosshairs': True,
'axisPointer': {'type': 'cross', 'label': {'show': True}},
},
'yAxis': {
'type': 'value',
'axisLabel': {'fontSize': 9, 'color': '#9b99a1'},
'splitLine': {'lineStyle': {'color': '#9f9f9f54'}},
'scale': True,
},
'xAxis': {
'type': 'time',
'axisTick': {'show': False},
'axisLine': {'show': False},
'splitLine': {'show': False},
'axisLabel': {'fontSize': 12, 'color': '#9b99a1'},
'scale': True,
},
'grid': {
'bottom': 0,
'top': 32,
'left': 0,
'right': 64,
'containLabel': True,
},
'series': series,
'legend': {
'show': True,
'textStyle': {'color': '#9B99A1'},
},
}

options.update(get_chart_header(title=data.title))

return EChartsRawWidget(option=options)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class NumericalBarChartData(BaseModel):
title: str
bucket_data: List[str]
bucket_data: List[float]
reference_data: List[float]
current_data: Optional[List[float]] = None

Expand All @@ -14,3 +14,9 @@ class ConfusionMatrixChartData(BaseModel):
axis_label: List[str]
matrix: List[List[float]]
color: Optional[List[str]] = ['#FFFFFF', '#9B99A1']


class LinearChartData(BaseModel):
title: str
reference_data: List[List[str]]
current_data: List[List[str]]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def get_formatted_bucket_data(bucket_data: List[str]) -> List[str]:
for idx, d in enumerate(bucket_data):
close_bracket = ']' if idx == bucket_data_len - 1 else ')'
if idx < bucket_data_len:
element = '[' + d + ',' + bucket_data[idx + 1] + close_bracket
element = '[' + str(d) + ',' + str(bucket_data[idx + 1]) + close_bracket
bucket_data_formatted.append(element)

return bucket_data_formatted
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
from .multi_class_chart import MultiClassificationChart
from .multi_class_chart_data import MultiClassificationDistributionChartData, MultiClassificationLinearChartData, MultiClassificationLinearData

__all__ = [
'MultiClassificationChart',
'MultiClassificationDistributionChartData',
'MultiClassificationLinearChartData',
'MultiClassificationLinearData'
]
Loading

0 comments on commit baadd9e

Please sign in to comment.