Skip to content

Commit

Permalink
Merge pull request #336 from Visual-Intelligence-UMN/chongwei_dev
Browse files Browse the repository at this point in the history
  • Loading branch information
HarryLuUMN authored Sep 6, 2024
2 parents 63e7c7f + 052fe64 commit db74242
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/pages/GraphVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ const GraphVisualizer: React.FC<GraphVisualizerProps> = ({
[],
0,
0,
0,
0
);

Expand Down
13 changes: 9 additions & 4 deletions src/pages/link_classifier/LinkGraphVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ const LinkGraphVisualizer: React.FC<LinkVisualizerProps> = ({
value = intmData.conv2;
}
if (i === 3) {
value = intmData.prob_adj[hubNodeA * hubNodeB + hubNodeA];
value = intmData.prob_adj[hubNodeA * 7126 + hubNodeB];
}
}
console.log("VAW", value, intmData.prob_adj)
Expand All @@ -276,6 +276,7 @@ const LinkGraphVisualizer: React.FC<LinkVisualizerProps> = ({
}
if (value != null && i === 3) {
node.features = [sigmoid(value)]


allValues = allValues.concat(node.features)

Expand Down Expand Up @@ -344,6 +345,7 @@ const LinkGraphVisualizer: React.FC<LinkVisualizerProps> = ({
const centerX = (point1.x + point3.x) / 2;
const centerY = (point1.y + point3.y) / 2;
if (i === 3) {

let bool = "True"
if (value[0] > 0.5) {
bool = "False"
Expand Down Expand Up @@ -391,7 +393,7 @@ const LinkGraphVisualizer: React.FC<LinkVisualizerProps> = ({


let featureCoords = [{ x: 0, y: 0 }, { x: 0, y: 0 }];
if (i == 4) {
if (i === 3) {
featureCoords[0] = { x: centerX, y: centerY };
}
if (i == 5) {
Expand Down Expand Up @@ -453,6 +455,7 @@ const LinkGraphVisualizer: React.FC<LinkVisualizerProps> = ({


if (i === graphs.length - 1) {

connectCrossGraphNodes(
allNodes,
svg,
Expand All @@ -461,13 +464,15 @@ const LinkGraphVisualizer: React.FC<LinkVisualizerProps> = ({
subgraph,
2,
hubNodeA,
hubNodeB
hubNodeB,
centerY

);
svg.selectAll("circle")
.attr("opacity", 0);

if (intmData) {
linkPredFeatureVisualizer(svg, allNodes, offset, height, graphs, 1200, 900, 15, 2, 3, 20, colorSchemes, 2, subgraph, innerComputationMode);
linkPredFeatureVisualizer(svg, allNodes, offset, height, graphs, 1200, 900, 15, 2, 3, 20, colorSchemes, 2, subgraph, innerComputationMode, centerY);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/pages/node_classifier/NodeGraphVisualizer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ const NodeGraphVisualizer: React.FC<NodeGraphVisualizerProps> = ({
[],
1,
0,
0,
0
);

Expand Down
25 changes: 23 additions & 2 deletions src/utils/graphUtils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,7 @@ export function calculationVisualizer(


frame.on("click", function(this: any, event: any) {

d3.selectAll(".weightUnit").lower()
d3.selectAll(".columnGroup").lower()
d3.selectAll(".weightMatrixText").lower()
Expand Down Expand Up @@ -1405,9 +1406,11 @@ export function calculationVisualizer(
const inputVector = featureMap[node.graphIndex][Number(d3.select(this).attr("index"))];
let jthIndexElement = lgIndices[node.id][1];
drawEScoreEquation(lgIndices, eDisplayer, jthIndexElement, start_x - 1200, start_y, usingVectors[1], usingVectors[0], myColor, inputVector, node.graphIndex - 1);
d3.selectAll(".button-group").lower()

d3.selectAll("#my_dataviz").on("click", function(event) {
event.stopPropagation();
d3.selectAll(".button-group").raise()

d3.selectAll(".attn-displayer").remove();
d3.selectAll(".e-displayer").remove()
Expand Down Expand Up @@ -1490,6 +1493,9 @@ export function calculationVisualizer(

paths.push(aggregatedToCalculated);




start_x =
3.5 * offset +
node.relatedNodes[0].features.length * prevRectHeight * 2 +
Expand Down Expand Up @@ -1557,7 +1563,7 @@ export function calculationVisualizer(


// relu
if (!(mode === 0 && node.graphIndex === 3)) {
if (!(mode === 0 && node.graphIndex === 3) && !(mode === 2 && node.graphIndex === 2)) {
const relu = g3.append("g");
let svgPath = "./assets/SVGs/ReLU.svg";
let labelText = "ReLU";
Expand Down Expand Up @@ -1623,6 +1629,8 @@ export function calculationVisualizer(
d3.selectAll(".bias").style("opacity", 1);
d3.selectAll(".softmax").attr("opacity", 0.07);
d3.selectAll(".relu").style("opacity", 1);


d3.selectAll(".output-path").attr("opacity", 1);
d3.selectAll(".softmaxLabel").attr("opacity", 1);
d3.selectAll(".intermediate-path").attr("opacity", 0)
Expand Down Expand Up @@ -1890,6 +1898,17 @@ function weightAnimation(
}

// Pause and replay button

const weightMatrixToBTN = svg
.append("path")
.attr("d", `M${weightsLocation[0][weightsLocation[0].length / 2][0]} ${weightsLocation[0][weightsLocation[0].length / 2][1]} C ${weightsLocation[0][weightsLocation[0].length / 2][0]} ${endCoordList[0][1]}, ${endCoordList[0][0] - 60} ${weightsLocation[0][weightsLocation[0].length / 2][1]} ${endCoordList[0][0] - 60} ${endCoordList[0][1] + 13}`)
.style("stroke", "black")
.attr("class", "to-be-removed procVis")
.style('stroke-width', 1)
.style("fill", "none")
.lower()


const btn = svg.append("g").attr("class", "button-group to-be-removed");


Expand All @@ -1904,6 +1923,7 @@ function weightAnimation(




const gLabel = svg.append("g");
injectSVG(gLabel, endCoordList[0][0] - 80-120-64, endCoordList[0][1] - 22.5-120-64, "./assets/SVGs/interactionHint.svg", "to-be-removed procVis");

Expand All @@ -1914,7 +1934,8 @@ function weightAnimation(
.text("Matrix Multiplication")
.attr("fill", "grey")
.attr("class", "to-be-removed procVis weight-matrix-text")
btn.lower()



btn.on("mouseover", function() {
if (!state.isAnimating) {
Expand Down
31 changes: 13 additions & 18 deletions src/utils/linkPredGraphVisUtil.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ export function linkPredFeatureVisualizer(
colorSchemes:any,
mode: number,
subgraph: any,
innerComputationMode: string
innerComputationMode: string,
centerY: number,
) {
state.isClicked = false;

Expand Down Expand Up @@ -356,18 +357,18 @@ export function linkPredFeatureVisualizer(
rectName = "output";
}
let prevRectHeight = 3;
let groupCentralHeight = currRectHeight * features.length / 2;
let yOffset = groupCentralHeight - (height / 5);
let groupCentralHeight = 175;
let yOffset = groupCentralHeight;

const featureGroup = g2.append("g")
.attr("transform", `translate(${node.x - 7.5}, ${node.y - yOffset})`);
.attr("transform", `translate(${node.x - 7.5}, ${centerY})`);

featureGroup.selectAll("rect")
.data(features)
.enter()
.append("rect")
.attr("x", -10) //adjust x and y coordination so it locates in the middle of the graph
.attr("y", (d: any, i: number) => i * currRectHeight - 190)
.attr("y", (d: any, i: number) => i * currRectHeight)
.attr("width", rectWidth)
.attr("id", (d: any, i: number) => rectName +"-layer-rect-" + i)
.attr("height", currRectHeight)
Expand All @@ -380,7 +381,7 @@ export function linkPredFeatureVisualizer(

const frame = featureGroup.append("rect")
.attr("x", -10)
.attr("y", -190)
.attr("y", 0)
.attr("width", 15)
.attr("class", "node-features")
.attr("height", currRectHeight * (node.features.length))
Expand All @@ -391,14 +392,14 @@ export function linkPredFeatureVisualizer(


const featureGroupCopy = g2.append("g")
.attr("transform", `translate(${node.x - 7.5}, ${node.y - yOffset})`);
.attr("transform", `translate(${node.x - 7.5}, ${centerY})`);

featureGroupCopy.selectAll("rect")
.data(features)
.enter()
.append("rect")
.attr("x", -10) //adjust x and y coordination so it locates in the middle of the graph
.attr("y", (d: any, i: number) => i * currRectHeight - 190)
.attr("y", (d: any, i: number) => i * currRectHeight)
.attr("width", rectWidth)
.attr("height", currRectHeight)
.attr("class", "node-features-Copy")
Expand All @@ -411,7 +412,7 @@ export function linkPredFeatureVisualizer(

const frameCopy = featureGroupCopy.append("rect")
.attr("x", -10)
.attr("y", -190)
.attr("y", 0)
.attr("width", 15)
.attr("class", "node-features-Copy")
.attr("height", currRectHeight * (node.features.length))
Expand All @@ -427,7 +428,7 @@ export function linkPredFeatureVisualizer(

node.featureGroup = featureGroup;
xPos = node.x;
yPos = node.y + rectHeight * node.features.length + yOffset;
yPos = centerY;
let featureGroupLocation: FeatureGroupLocation = {xPos, yPos};
node.featureGroupLocation = featureGroupLocation; // this will be used in calculationvisualizer

Expand Down Expand Up @@ -650,7 +651,7 @@ export function linkPredOutputVisualizer(
.attr("class", "to-be-removed dot-product")
.attr("y", height / 3 - 30)
.text("Sigmoid")
.attr("fill", "black")
.attr("fill", "grey")
.attr("font-size", "17")
.style("opacity", 0)

Expand Down Expand Up @@ -746,13 +747,7 @@ export function linkPredOutputVisualizer(

moveFeaturesBack(node.relatedNodes, originalCoordinates);

node.featureGroup
.transition()
.duration(1000)
.attr(
"transform",
`translate(${node.x - 7.5}, ${node.y + 170 + 5}) rotate(0)`
);


handleClickEvent(originalSvg, node, event, moveOffset, colorSchemes, allNodes, convNum, mode, state)

Expand Down
6 changes: 3 additions & 3 deletions src/utils/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,7 @@ export function calculateAverage(arr: number[]): number {
return average * 10;
}

export function connectCrossGraphNodes(nodes: any, svg: any, graphs: any[], offset: number, subgraph: any, mode: number, hubNodeA: number, hubNodeB: number) {
export function connectCrossGraphNodes(nodes: any, svg: any, graphs: any[], offset: number, subgraph: any, mode: number, hubNodeA: number, hubNodeB: number, centerY: number) {
const nodesByIndex = d3.group(nodes, (d: any) => d.graphIndex);


Expand Down Expand Up @@ -1356,11 +1356,11 @@ export function connectCrossGraphNodes(nodes: any, svg: any, graphs: any[], offs
const controlX1 = node.x + xOffset1 + (nextNode.x + xOffset2 - node.x - xOffset1) * 0.3;
const controlY1 = node.y + 10;
const controlX2 = node.x + xOffset1 + (nextNode.x + xOffset2 - node.x - xOffset1) * 0.7;
const controlY2 = nextNode.y + 10;
const controlY2 = centerY + 20
if (isValidNode(subgraph, node)) {

const path = svg.append("path")
.attr("d", `M ${node.x + xOffset1} ${node.y + 10} Q ${controlX2} ${controlY2}, ${nextNode.x + xOffset2 - 20} ${nextNode.y + 10}`)
.attr("d", `M ${node.x + xOffset1} ${node.y + 10} Q ${controlX2} ${controlY2}, ${nextNode.x + xOffset2 - 20} ${centerY + 20}`)
.style("stroke", linkStrength(avg))
.style("opacity", 0)
.style('stroke-width', 1)
Expand Down

0 comments on commit db74242

Please sign in to comment.