-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unified interface for all cases, removed unnecessary attributes to im…
…prove loading time of datasets
- Loading branch information
1 parent
7ba84a8
commit 97fd3d0
Showing
10 changed files
with
3,310 additions
and
1,997,601 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import React from 'react'; | ||
|
||
import { Button, Group, Select, Slider, Switch, Text } from '@mantine/core'; | ||
import * as d3 from 'd3v7'; | ||
import { map, uniq } from 'lodash'; | ||
import * as vsup from 'vsup'; | ||
|
||
import { FlameTree, FlameTreeAPI } from '../FlameTree'; | ||
import { AggregateSelect } from '../FlameTree/AggregateSelect'; | ||
import { CutoffSlider } from '../FlameTree/CutoffSlider'; | ||
import { useCutoffFilter, useStateReset } from '../FlameTree/hooks'; | ||
import { AggregationType, ParameterColumn, adjustDomain, aggregateBy, createParameterHierarchy } from '../FlameTree/math'; | ||
|
||
export default function CimeFlameTree({ | ||
dataset, | ||
columnKeys, | ||
mode, | ||
maxIterations, | ||
}: { | ||
dataset: Record<string, unknown>[]; | ||
columnKeys: string[]; | ||
mode: 'experiment' | 'prediction'; | ||
maxIterations?: number; | ||
}) { | ||
const dataKey = mode === 'experiment' ? 'measured_yield' : 'meas_yield'; | ||
|
||
const definitions = React.useMemo(() => { | ||
return columnKeys.map((key) => { | ||
return { | ||
key, | ||
domain: uniq(map(dataset, key)), | ||
type: 'categorical', | ||
} as ParameterColumn; | ||
}); | ||
}, [columnKeys, dataset]); | ||
|
||
const [iteration, setIteration] = React.useState<number>(0); | ||
const [layering, setLayering] = React.useState<string[]>(definitions.map((column) => column.key)); | ||
const [aggregation, setAggregation] = React.useState<AggregationType>('max'); | ||
const [uncertaintyAggregation, setUncertaintyAggregation] = React.useState<AggregationType>('min'); | ||
const [coloring, setColoring] = React.useState<'yield' | 'yield+uncertainty'>('yield+uncertainty'); | ||
|
||
const bins = React.useMemo(() => { | ||
return createParameterHierarchy(definitions, dataset, layering, [0, 100], (items) => { | ||
return { | ||
value: | ||
mode === 'experiment' | ||
? aggregateBy(aggregation, map(items, dataKey) as number[]) | ||
: aggregateBy(aggregation, map(items, `pred_yield_mean_${iteration}`) as number[]), | ||
uncertainty: mode === 'experiment' ? 0 : aggregateBy(uncertaintyAggregation, map(items, `pred_yield_var_${iteration}`) as number[]), | ||
}; | ||
}); | ||
}, [aggregation, dataKey, dataset, definitions, iteration, layering, mode, uncertaintyAggregation]); | ||
|
||
const experiments = React.useMemo(() => { | ||
return dataset.filter((entry) => (entry.experiment_cycle as number) <= iteration && entry[dataKey] !== -1); | ||
}, [dataKey, dataset, iteration]); | ||
|
||
const scales = React.useMemo(() => { | ||
const binDomain = d3.extent(Object.values(bins).map((bin) => bin.value.value as number)) as number[]; | ||
const binVariance = d3.extent(Object.values(bins).map((bin) => bin.value.uncertainty as number)) as number[]; | ||
// const dataDomain = d3.extent(dataset.map((entry) => entry[dataKey] as number)) as number[]; | ||
const dataDomain = [0, 100]; | ||
|
||
let squareQuantization: any; | ||
|
||
if (mode === 'experiment') { | ||
squareQuantization = vsup.squareQuantization().n(5).valueDomain(binDomain).uncertaintyDomain([0, 1]); | ||
} else if (mode === 'prediction') { | ||
const yieldDomain = d3.extent([...binDomain, ...dataDomain]) as number[]; | ||
squareQuantization = vsup.squareQuantization().n(10).valueDomain(yieldDomain).uncertaintyDomain(binVariance); | ||
} | ||
|
||
return { | ||
squareScale: vsup.scale().quantize(squareQuantization).range(d3.interpolateCividis), | ||
cutoffDomain: adjustDomain(binDomain), | ||
}; | ||
}, [bins, mode]); | ||
|
||
const [cutoff, setCutoff] = React.useState<number>(scales.cutoffDomain[0]!); | ||
|
||
useStateReset(() => { | ||
setCutoff(scales.cutoffDomain[0]!); | ||
}, scales); | ||
|
||
const filter = useCutoffFilter(bins, 'value', cutoff); | ||
|
||
const [synchronizeHover, setSynchronizeHover] = React.useState<boolean>(true); | ||
|
||
const apiRef = React.useRef<FlameTreeAPI>(); | ||
|
||
return ( | ||
<div> | ||
<FlameTree | ||
bins={bins} | ||
definitions={definitions} | ||
layering={layering} | ||
setLayering={setLayering} | ||
experiments={experiments} | ||
filter={filter} | ||
itemHeight={90} | ||
apiRef={apiRef} | ||
synchronizeHover={synchronizeHover} | ||
colorScale={(item) => { | ||
if (coloring === 'yield') { | ||
return scales.squareScale(item.value as number, 0); | ||
} | ||
|
||
if (coloring === 'yield+uncertainty') { | ||
return scales.squareScale(item.value as number, item.uncertainty as number); | ||
} | ||
|
||
return 'black'; | ||
}} | ||
experimentsColorScale={(item) => { | ||
return mode === 'experiment' ? scales.squareScale(item[dataKey], 0) : scales.squareScale(item[dataKey], 0); | ||
}} | ||
> | ||
<FlameTree.Toolbar> | ||
<Group align="flex-end" gap="xl"> | ||
<Select | ||
label="Coloring" | ||
value={coloring} | ||
onChange={setColoring as (v: string | null) => void} | ||
data={[ | ||
{ | ||
label: 'Yield', | ||
value: 'yield', | ||
}, | ||
{ | ||
label: 'Yield + Uncertainty', | ||
value: 'yield+uncertainty', | ||
}, | ||
]} | ||
/> | ||
|
||
<AggregateSelect label="Value aggregation" aggregation={aggregation} setAggregation={setAggregation} /> | ||
{mode === 'prediction' ? ( | ||
<AggregateSelect label="Uncertainty aggregation" aggregation={uncertaintyAggregation} setAggregation={setUncertaintyAggregation} /> | ||
) : null} | ||
|
||
{mode === 'prediction' ? ( | ||
<Group mb={8}> | ||
<Text size="sm">Iteration</Text> | ||
<Slider value={iteration} onChange={setIteration} min={0} max={maxIterations} w={200} /> | ||
</Group> | ||
) : null} | ||
|
||
<Switch | ||
label="Synchronize hover" | ||
mb={8} | ||
checked={synchronizeHover} | ||
onChange={(event) => { | ||
setSynchronizeHover(event.currentTarget.checked); | ||
}} | ||
/> | ||
|
||
<CutoffSlider mb={8} domain={scales.cutoffDomain} value={cutoff} onChange={setCutoff} /> | ||
|
||
<Button | ||
onClick={() => { | ||
apiRef.current?.resetZoom(); | ||
}} | ||
> | ||
Reset zoom | ||
</Button> | ||
</Group> | ||
</FlameTree.Toolbar> | ||
</FlameTree> | ||
</div> | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,122 +1,14 @@ | ||
import React from 'react'; | ||
|
||
import { css } from '@emotion/css'; | ||
import { Group, Switch, Text } from '@mantine/core'; | ||
import * as d3 from 'd3v7'; | ||
import { map, uniq } from 'lodash'; | ||
import * as vsup from 'vsup'; | ||
|
||
import { FlameTree } from '../FlameTree'; | ||
import { AggregateSelect } from '../FlameTree/AggregateSelect'; | ||
import { CutoffSlider } from '../FlameTree/CutoffSlider'; | ||
import { useCutoffFilter, useStateReset } from '../FlameTree/hooks'; | ||
import { AggregationType, ParameterColumn, adjustDomain, aggregateBy, createParameterHierarchy } from '../FlameTree/math'; | ||
import { TooltipContent, TooltipContentBin } from '../FlameTree/TooltipContent'; | ||
import CimeFlameTree from './CimeFlameTree'; | ||
|
||
const { UseCase1 } = await import('./case_study_1'); | ||
|
||
export default function FlameCase1() { | ||
const definitions = React.useMemo(() => { | ||
const ArylColumn: ParameterColumn = { | ||
key: 'aryl_halide_file_name_exp_param', | ||
domain: uniq(map(UseCase1, 'aryl_halide_file_name_exp_param')), | ||
type: 'categorical', | ||
}; | ||
|
||
const AdditiveColumn: ParameterColumn = { | ||
key: 'additive_file_name_exp_param', | ||
domain: uniq(map(UseCase1, 'additive_file_name_exp_param')), | ||
type: 'categorical', | ||
}; | ||
|
||
const LigandColumn: ParameterColumn = { | ||
key: 'ligand_file_name_exp_param', | ||
domain: uniq(map(UseCase1, 'ligand_file_name_exp_param')), | ||
type: 'categorical', | ||
}; | ||
|
||
const BaseColumn: ParameterColumn = { | ||
key: 'base_file_name_exp_param', | ||
domain: uniq(map(UseCase1, 'base_file_name_exp_param')), | ||
type: 'categorical', | ||
}; | ||
|
||
return [ArylColumn, BaseColumn, LigandColumn, AdditiveColumn]; | ||
}, []); | ||
|
||
const [layering, setLayering] = React.useState<string[]>(definitions.map((column) => column.key)); | ||
const [aggregation, setAggregation] = React.useState<AggregationType>('max'); | ||
|
||
const bins = React.useMemo(() => { | ||
return createParameterHierarchy(definitions, UseCase1, layering, [0, 100], (items) => { | ||
return { | ||
value: aggregateBy(aggregation, map(items, 'measured_yield') as number[]), | ||
uncertainty: 0, | ||
}; | ||
}); | ||
}, [aggregation, definitions, layering]); | ||
|
||
const scales = React.useMemo(() => { | ||
const binDomain = d3.extent(Object.values(bins).map((bin) => bin.value.value as number)) as number[]; | ||
|
||
const squareQuantization = vsup.squareQuantization().n(5).valueDomain(binDomain).uncertaintyDomain([0, 1]); | ||
const squareScale = vsup.scale().quantize(squareQuantization).range(d3.interpolateCividis); | ||
|
||
const heatLegend = vsup.legend.heatmapLegend().scale(squareScale).size(150).x(10).y(20); | ||
|
||
return { | ||
squareQuantization, | ||
squareScale, | ||
heatLegend, | ||
cutoffDomain: adjustDomain(binDomain), | ||
}; | ||
}, [bins]); | ||
|
||
const [cutoff, setCutoff] = React.useState<number>(scales.cutoffDomain[0]!); | ||
|
||
useStateReset(() => { | ||
setCutoff(scales.cutoffDomain[0]!); | ||
}, scales); | ||
|
||
const filter = useCutoffFilter(bins, 'value', cutoff); | ||
|
||
const [synchronizeHover, setSynchronizeHover] = React.useState<boolean>(true); | ||
|
||
return ( | ||
<div> | ||
<FlameTree | ||
bins={bins} | ||
definitions={definitions} | ||
layering={layering} | ||
setLayering={setLayering} | ||
experiments={UseCase1} | ||
filter={filter} | ||
itemHeight={90} | ||
synchronizeHover={synchronizeHover} | ||
colorScale={(item) => { | ||
return scales.squareScale(item.value as number, item.uncertainty as number); | ||
}} | ||
experimentsColorScale={(item) => { | ||
return scales.squareScale(item.measured_yield, 0); | ||
}} | ||
> | ||
<FlameTree.Toolbar> | ||
<Group align="flex-end" gap="xl"> | ||
<AggregateSelect label="Value aggregation" aggregation={aggregation} setAggregation={setAggregation} /> | ||
|
||
<Switch | ||
label="Synchronize hover" | ||
mb={8} | ||
checked={synchronizeHover} | ||
onChange={(event) => { | ||
setSynchronizeHover(event.currentTarget.checked); | ||
}} | ||
/> | ||
|
||
<CutoffSlider mb={8} domain={scales.cutoffDomain} value={cutoff} onChange={setCutoff} /> | ||
</Group> | ||
</FlameTree.Toolbar> | ||
</FlameTree> | ||
</div> | ||
const columnKeys = React.useMemo( | ||
() => ['aryl_halide_file_name_exp_param', 'additive_file_name_exp_param', 'ligand_file_name_exp_param', 'base_file_name_exp_param'], | ||
[], | ||
); | ||
|
||
return <CimeFlameTree dataset={UseCase1} columnKeys={columnKeys} mode="experiment" />; | ||
} |
Oops, something went wrong.