-
Notifications
You must be signed in to change notification settings - Fork 11
/
index.html
152 lines (133 loc) · 6.78 KB
/
index.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
<!DOCTYPE html>
<html>
<header>
<title>Segment Anything Frontend Only Demo</title>
</header>
<body>
<h1>Segement Anything in the Browser</h1>
<div>
<input title="Image from File" type="file" id="file-in" name="file-in">
</div>
<div style="display: none;">
<img id="original-image" src="#" />
</div>
<h3>Status</h3>
<span id="status">No image uploaded</span>
<h3>Resized Image and Mask</h3>
<canvas id="canvas"></canvas>
<!-- import ONNXRuntime Web from CDN -->
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>
<script>
var canvas = document.getElementById('canvas');
const dimension = 1024;
var image_embeddings;
var imageImageData
async function handleClick(event) {
const rect = canvas.getBoundingClientRect();
const x = event.clientX - rect.left;
const y = event.clientY - rect.top;
console.log('Clicked position:', x, y);
document.getElementById("status").textContent = `Clicked on (${x}, ${y}). Downloading the decoder model if needed and generating mask...`;
let context = canvas.getContext('2d');
context.clearRect(0, 0, canvas.width, canvas.height);
canvas.width = imageImageData.width;
canvas.height = imageImageData.height;
context.putImageData(imageImageData, 0, 0);
context.fillStyle = 'green';
context.fillRect(x, y, 10, 10);
const pointCoords = new ort.Tensor(new Float32Array([x, y, 0, 0]), [1, 2, 2]);
const pointLabels = new ort.Tensor(new Float32Array([0, -1]), [1, 2]);
// generate a (1, 1, 256, 256) tensor with all values set to 0
const maskInput = new ort.Tensor(new Float32Array(256 * 256), [1, 1, 256, 256]);
const hasMask = new ort.Tensor(new Float32Array([0]), [1,]);
const origianlImageSize = new ort.Tensor(new Float32Array([684, 1024]), [2,]);
const decodingSession = await ort.InferenceSession.create('./models/sam_vit_b_01ec64.decoder.onnx');
console.log("Decoder session", decodingSession);
const decodingFeeds = {
"image_embeddings": image_embeddings,
"point_coords": pointCoords,
"point_labels": pointLabels,
"mask_input": maskInput,
"has_mask_input": hasMask,
"orig_im_size": origianlImageSize
}
console.log("generating mask...");
start = Date.now();
try {
results = await decodingSession.run(decodingFeeds);
console.log("Generated mask:", results);
const mask = results.masks;
const maskImageData = mask.toImageData();
context.globalAlpha = 0.5;
// convert image data to image bitmap
let imageBitmap = await createImageBitmap(maskImageData);
context.drawImage(imageBitmap, 0, 0);
} catch (error) {
console.log(`caught error: ${error}`)
}
end = Date.now();
console.log(`generating masks took ${(end - start)/1000} seconds`);
document.getElementById("status").textContent = `Mask generated. Click on the image to generate new mask.`;
}
async function handleImage(img) {
console.log(status);
document.getElementById("status").textContent = `Uploaded image of size ${img.width}x${img.height}. Downloading the encoder model (~100 MB) if not cached and generating embedding. this will take a minute...`;
console.log(`Uploaded image of size ${img.width}x${img.height}`);
const scaleX = dimension / img.width;
const scaleY = dimension / img.height;
const resizedTensor = await ort.Tensor.fromImage(img, options = { resizedWidth: 1024, resizedHeight: 684 });
const resizeImage = resizedTensor.toImageData();
let imageDataTensor = await ort.Tensor.fromImage(resizeImage);
imageImageData = imageDataTensor.toImageData();
console.log("image data tensor:", imageDataTensor);
// Presenting the images on dom
canvas.width = imageImageData.width;
canvas.height = imageImageData.height;
let context = canvas.getContext('2d');
context.putImageData(imageImageData, 0, 0);
let tf_tensor = tf.tensor(imageDataTensor.data, imageDataTensor.dims);
tf_tensor = tf_tensor.reshape([3, 684, 1024]);
tf_tensor = tf_tensor.transpose([1, 2, 0]).mul(255);
imageDataTensor = new ort.Tensor(tf_tensor.dataSync(), tf_tensor.shape);
const session = await ort.InferenceSession.create('https://files.sunu.in/sam_vit_b_01ec64.encoder.preprocess.quant.onnx');
console.log("Encoder Session", session);
const feeds = { "input_image": imageDataTensor }
console.log("Computing image embedding; this will take a minute...")
let start = Date.now();
let results;
try {
results = await session.run(feeds);
console.log("Encoding result:", results);
image_embeddings = results.image_embeddings;
} catch (error) {
console.log(`caught error: ${error}`)
document.getElementById("status").textContent = `Error: ${error}`;
}
let end = Date.now();
let time_taken = (end - start) / 1000;
console.log(`Computing image embedding took ${time_taken} seconds`);
document.getElementById("status").textContent = `Embedding generated in ${time_taken} seconds. Click on the image to generate mask.`;
canvas.addEventListener('click', handleClick);
}
function loadImage(fileReader) {
var img = document.getElementById("original-image");
img.onload = () => handleImage(img);
img.src = fileReader.result;
}
// use an async context to call onnxruntime functions.
async function main() {
var img = document.getElementById("original-image");
document.getElementById("file-in").onchange = function (evt) {
let target = evt.target || window.event.src, files = target.files;
if (FileReader && files && files.length) {
var fileReader = new FileReader();
fileReader.onload = () => loadImage(fileReader);
fileReader.readAsDataURL(files[0]);
}
};
}
main();
</script>
</body>
</html>