Skip to content

Commit

Permalink
Unified interface for all cases, removed unnecessary attributes to im…
Browse files Browse the repository at this point in the history
…prove loading time of datasets
  • Loading branch information
dvmoritzschoefl committed Feb 10, 2025
1 parent 7ba84a8 commit 97fd3d0
Show file tree
Hide file tree
Showing 10 changed files with 3,310 additions and 1,997,601 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
"jszip": "^3.10.1",
"lineupjs": "4.12.0",
"lodash": "~4.17.21",
"ml-pca": "^4.1.1",
"plotly.js-dist-min": "~2.12.1",
"rbush": "^4.0.1",
"react": "~18.3.1",
Expand Down
172 changes: 172 additions & 0 deletions src/demo/Cases/CimeFlameTree.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import React from 'react';

import { Button, Group, Select, Slider, Switch, Text } from '@mantine/core';
import * as d3 from 'd3v7';
import { map, uniq } from 'lodash';
import * as vsup from 'vsup';

import { FlameTree, FlameTreeAPI } from '../FlameTree';
import { AggregateSelect } from '../FlameTree/AggregateSelect';
import { CutoffSlider } from '../FlameTree/CutoffSlider';
import { useCutoffFilter, useStateReset } from '../FlameTree/hooks';
import { AggregationType, ParameterColumn, adjustDomain, aggregateBy, createParameterHierarchy } from '../FlameTree/math';

export default function CimeFlameTree({
dataset,
columnKeys,
mode,
maxIterations,
}: {
dataset: Record<string, unknown>[];
columnKeys: string[];
mode: 'experiment' | 'prediction';
maxIterations?: number;
}) {
const dataKey = mode === 'experiment' ? 'measured_yield' : 'meas_yield';

const definitions = React.useMemo(() => {
return columnKeys.map((key) => {
return {
key,
domain: uniq(map(dataset, key)),
type: 'categorical',
} as ParameterColumn;
});
}, [columnKeys, dataset]);

const [iteration, setIteration] = React.useState<number>(0);
const [layering, setLayering] = React.useState<string[]>(definitions.map((column) => column.key));
const [aggregation, setAggregation] = React.useState<AggregationType>('max');
const [uncertaintyAggregation, setUncertaintyAggregation] = React.useState<AggregationType>('min');
const [coloring, setColoring] = React.useState<'yield' | 'yield+uncertainty'>('yield+uncertainty');

const bins = React.useMemo(() => {
return createParameterHierarchy(definitions, dataset, layering, [0, 100], (items) => {
return {
value:
mode === 'experiment'
? aggregateBy(aggregation, map(items, dataKey) as number[])
: aggregateBy(aggregation, map(items, `pred_yield_mean_${iteration}`) as number[]),
uncertainty: mode === 'experiment' ? 0 : aggregateBy(uncertaintyAggregation, map(items, `pred_yield_var_${iteration}`) as number[]),
};
});
}, [aggregation, dataKey, dataset, definitions, iteration, layering, mode, uncertaintyAggregation]);

const experiments = React.useMemo(() => {
return dataset.filter((entry) => (entry.experiment_cycle as number) <= iteration && entry[dataKey] !== -1);
}, [dataKey, dataset, iteration]);

const scales = React.useMemo(() => {
const binDomain = d3.extent(Object.values(bins).map((bin) => bin.value.value as number)) as number[];
const binVariance = d3.extent(Object.values(bins).map((bin) => bin.value.uncertainty as number)) as number[];
// const dataDomain = d3.extent(dataset.map((entry) => entry[dataKey] as number)) as number[];
const dataDomain = [0, 100];

let squareQuantization: any;

if (mode === 'experiment') {
squareQuantization = vsup.squareQuantization().n(5).valueDomain(binDomain).uncertaintyDomain([0, 1]);
} else if (mode === 'prediction') {
const yieldDomain = d3.extent([...binDomain, ...dataDomain]) as number[];
squareQuantization = vsup.squareQuantization().n(10).valueDomain(yieldDomain).uncertaintyDomain(binVariance);
}

return {
squareScale: vsup.scale().quantize(squareQuantization).range(d3.interpolateCividis),
cutoffDomain: adjustDomain(binDomain),
};
}, [bins, mode]);

const [cutoff, setCutoff] = React.useState<number>(scales.cutoffDomain[0]!);

useStateReset(() => {
setCutoff(scales.cutoffDomain[0]!);
}, scales);

const filter = useCutoffFilter(bins, 'value', cutoff);

const [synchronizeHover, setSynchronizeHover] = React.useState<boolean>(true);

const apiRef = React.useRef<FlameTreeAPI>();

return (
<div>
<FlameTree
bins={bins}
definitions={definitions}
layering={layering}
setLayering={setLayering}
experiments={experiments}
filter={filter}
itemHeight={90}
apiRef={apiRef}
synchronizeHover={synchronizeHover}
colorScale={(item) => {
if (coloring === 'yield') {
return scales.squareScale(item.value as number, 0);
}

if (coloring === 'yield+uncertainty') {
return scales.squareScale(item.value as number, item.uncertainty as number);
}

return 'black';
}}
experimentsColorScale={(item) => {
return mode === 'experiment' ? scales.squareScale(item[dataKey], 0) : scales.squareScale(item[dataKey], 0);
}}
>
<FlameTree.Toolbar>
<Group align="flex-end" gap="xl">
<Select
label="Coloring"
value={coloring}
onChange={setColoring as (v: string | null) => void}
data={[
{
label: 'Yield',
value: 'yield',
},
{
label: 'Yield + Uncertainty',
value: 'yield+uncertainty',
},
]}
/>

<AggregateSelect label="Value aggregation" aggregation={aggregation} setAggregation={setAggregation} />
{mode === 'prediction' ? (
<AggregateSelect label="Uncertainty aggregation" aggregation={uncertaintyAggregation} setAggregation={setUncertaintyAggregation} />
) : null}

{mode === 'prediction' ? (
<Group mb={8}>
<Text size="sm">Iteration</Text>
<Slider value={iteration} onChange={setIteration} min={0} max={maxIterations} w={200} />
</Group>
) : null}

<Switch
label="Synchronize hover"
mb={8}
checked={synchronizeHover}
onChange={(event) => {
setSynchronizeHover(event.currentTarget.checked);
}}
/>

<CutoffSlider mb={8} domain={scales.cutoffDomain} value={cutoff} onChange={setCutoff} />

<Button
onClick={() => {
apiRef.current?.resetZoom();
}}
>
Reset zoom
</Button>
</Group>
</FlameTree.Toolbar>
</FlameTree>
</div>
);
}
120 changes: 6 additions & 114 deletions src/demo/Cases/FlameCase1.tsx
Original file line number Diff line number Diff line change
@@ -1,122 +1,14 @@
import React from 'react';

import { css } from '@emotion/css';
import { Group, Switch, Text } from '@mantine/core';
import * as d3 from 'd3v7';
import { map, uniq } from 'lodash';
import * as vsup from 'vsup';

import { FlameTree } from '../FlameTree';
import { AggregateSelect } from '../FlameTree/AggregateSelect';
import { CutoffSlider } from '../FlameTree/CutoffSlider';
import { useCutoffFilter, useStateReset } from '../FlameTree/hooks';
import { AggregationType, ParameterColumn, adjustDomain, aggregateBy, createParameterHierarchy } from '../FlameTree/math';
import { TooltipContent, TooltipContentBin } from '../FlameTree/TooltipContent';
import CimeFlameTree from './CimeFlameTree';

const { UseCase1 } = await import('./case_study_1');

export default function FlameCase1() {
const definitions = React.useMemo(() => {
const ArylColumn: ParameterColumn = {
key: 'aryl_halide_file_name_exp_param',
domain: uniq(map(UseCase1, 'aryl_halide_file_name_exp_param')),
type: 'categorical',
};

const AdditiveColumn: ParameterColumn = {
key: 'additive_file_name_exp_param',
domain: uniq(map(UseCase1, 'additive_file_name_exp_param')),
type: 'categorical',
};

const LigandColumn: ParameterColumn = {
key: 'ligand_file_name_exp_param',
domain: uniq(map(UseCase1, 'ligand_file_name_exp_param')),
type: 'categorical',
};

const BaseColumn: ParameterColumn = {
key: 'base_file_name_exp_param',
domain: uniq(map(UseCase1, 'base_file_name_exp_param')),
type: 'categorical',
};

return [ArylColumn, BaseColumn, LigandColumn, AdditiveColumn];
}, []);

const [layering, setLayering] = React.useState<string[]>(definitions.map((column) => column.key));
const [aggregation, setAggregation] = React.useState<AggregationType>('max');

const bins = React.useMemo(() => {
return createParameterHierarchy(definitions, UseCase1, layering, [0, 100], (items) => {
return {
value: aggregateBy(aggregation, map(items, 'measured_yield') as number[]),
uncertainty: 0,
};
});
}, [aggregation, definitions, layering]);

const scales = React.useMemo(() => {
const binDomain = d3.extent(Object.values(bins).map((bin) => bin.value.value as number)) as number[];

const squareQuantization = vsup.squareQuantization().n(5).valueDomain(binDomain).uncertaintyDomain([0, 1]);
const squareScale = vsup.scale().quantize(squareQuantization).range(d3.interpolateCividis);

const heatLegend = vsup.legend.heatmapLegend().scale(squareScale).size(150).x(10).y(20);

return {
squareQuantization,
squareScale,
heatLegend,
cutoffDomain: adjustDomain(binDomain),
};
}, [bins]);

const [cutoff, setCutoff] = React.useState<number>(scales.cutoffDomain[0]!);

useStateReset(() => {
setCutoff(scales.cutoffDomain[0]!);
}, scales);

const filter = useCutoffFilter(bins, 'value', cutoff);

const [synchronizeHover, setSynchronizeHover] = React.useState<boolean>(true);

return (
<div>
<FlameTree
bins={bins}
definitions={definitions}
layering={layering}
setLayering={setLayering}
experiments={UseCase1}
filter={filter}
itemHeight={90}
synchronizeHover={synchronizeHover}
colorScale={(item) => {
return scales.squareScale(item.value as number, item.uncertainty as number);
}}
experimentsColorScale={(item) => {
return scales.squareScale(item.measured_yield, 0);
}}
>
<FlameTree.Toolbar>
<Group align="flex-end" gap="xl">
<AggregateSelect label="Value aggregation" aggregation={aggregation} setAggregation={setAggregation} />

<Switch
label="Synchronize hover"
mb={8}
checked={synchronizeHover}
onChange={(event) => {
setSynchronizeHover(event.currentTarget.checked);
}}
/>

<CutoffSlider mb={8} domain={scales.cutoffDomain} value={cutoff} onChange={setCutoff} />
</Group>
</FlameTree.Toolbar>
</FlameTree>
</div>
const columnKeys = React.useMemo(
() => ['aryl_halide_file_name_exp_param', 'additive_file_name_exp_param', 'ligand_file_name_exp_param', 'base_file_name_exp_param'],
[],
);

return <CimeFlameTree dataset={UseCase1} columnKeys={columnKeys} mode="experiment" />;
}
Loading

0 comments on commit 97fd3d0

Please sign in to comment.