Skip to content

Commit 133f9e2

Browse files
authored
feat: add a base model to the repository
1 parent 4c6eee7 commit 133f9e2

File tree

7 files changed

+338
-80
lines changed

7 files changed

+338
-80
lines changed

model/ver20220624/model.json

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
{
2+
"modelTopology": {
3+
"class_name": "Sequential",
4+
"config": {
5+
"name": "sequential_1",
6+
"layers": [
7+
{
8+
"class_name": "Dense",
9+
"config": {
10+
"units": 32,
11+
"activation": "relu",
12+
"use_bias": true,
13+
"kernel_initializer": {
14+
"class_name": "VarianceScaling",
15+
"config": {
16+
"scale": 1,
17+
"mode": "fan_avg",
18+
"distribution": "normal",
19+
"seed": null
20+
}
21+
},
22+
"bias_initializer": { "class_name": "Zeros", "config": {} },
23+
"kernel_regularizer": null,
24+
"bias_regularizer": null,
25+
"activity_regularizer": null,
26+
"kernel_constraint": null,
27+
"bias_constraint": null,
28+
"name": "dense_Dense1",
29+
"trainable": true,
30+
"batch_input_shape": [null, 512],
31+
"dtype": "float32"
32+
}
33+
},
34+
{
35+
"class_name": "BatchNormalization",
36+
"config": {
37+
"axis": -1,
38+
"momentum": 0.99,
39+
"epsilon": 0.001,
40+
"center": true,
41+
"scale": true,
42+
"beta_initializer": { "class_name": "Zeros", "config": {} },
43+
"gamma_initializer": { "class_name": "Ones", "config": {} },
44+
"moving_mean_initializer": { "class_name": "Zeros", "config": {} },
45+
"moving_variance_initializer": {
46+
"class_name": "Ones",
47+
"config": {}
48+
},
49+
"beta_regularizer": null,
50+
"gamma_regularizer": null,
51+
"beta_constraint": null,
52+
"gamma_constraint": null,
53+
"name": "batch_normalization_BatchNormalization1",
54+
"trainable": true
55+
}
56+
},
57+
{
58+
"class_name": "Dense",
59+
"config": {
60+
"units": 32,
61+
"activation": "relu",
62+
"use_bias": true,
63+
"kernel_initializer": {
64+
"class_name": "VarianceScaling",
65+
"config": {
66+
"scale": 1,
67+
"mode": "fan_avg",
68+
"distribution": "normal",
69+
"seed": null
70+
}
71+
},
72+
"bias_initializer": { "class_name": "Zeros", "config": {} },
73+
"kernel_regularizer": null,
74+
"bias_regularizer": null,
75+
"activity_regularizer": null,
76+
"kernel_constraint": null,
77+
"bias_constraint": null,
78+
"name": "dense_Dense2",
79+
"trainable": true
80+
}
81+
},
82+
{
83+
"class_name": "BatchNormalization",
84+
"config": {
85+
"axis": -1,
86+
"momentum": 0.99,
87+
"epsilon": 0.001,
88+
"center": true,
89+
"scale": true,
90+
"beta_initializer": { "class_name": "Zeros", "config": {} },
91+
"gamma_initializer": { "class_name": "Ones", "config": {} },
92+
"moving_mean_initializer": { "class_name": "Zeros", "config": {} },
93+
"moving_variance_initializer": {
94+
"class_name": "Ones",
95+
"config": {}
96+
},
97+
"beta_regularizer": null,
98+
"gamma_regularizer": null,
99+
"beta_constraint": null,
100+
"gamma_constraint": null,
101+
"name": "batch_normalization_BatchNormalization2",
102+
"trainable": true
103+
}
104+
},
105+
{
106+
"class_name": "Dense",
107+
"config": {
108+
"units": 1,
109+
"activation": "sigmoid",
110+
"use_bias": true,
111+
"kernel_initializer": {
112+
"class_name": "VarianceScaling",
113+
"config": {
114+
"scale": 1,
115+
"mode": "fan_avg",
116+
"distribution": "normal",
117+
"seed": null
118+
}
119+
},
120+
"bias_initializer": { "class_name": "Zeros", "config": {} },
121+
"kernel_regularizer": null,
122+
"bias_regularizer": null,
123+
"activity_regularizer": null,
124+
"kernel_constraint": null,
125+
"bias_constraint": null,
126+
"name": "dense_Dense3",
127+
"trainable": true
128+
}
129+
}
130+
]
131+
},
132+
"keras_version": "tfjs-layers 3.18.0",
133+
"backend": "tensor_flow.js"
134+
},
135+
"weightsManifest": [
136+
{
137+
"paths": ["weights.bin"],
138+
"weights": [
139+
{
140+
"name": "dense_Dense1/kernel",
141+
"shape": [512, 32],
142+
"dtype": "float32"
143+
},
144+
{ "name": "dense_Dense1/bias", "shape": [32], "dtype": "float32" },
145+
{
146+
"name": "batch_normalization_BatchNormalization1/gamma",
147+
"shape": [32],
148+
"dtype": "float32"
149+
},
150+
{
151+
"name": "batch_normalization_BatchNormalization1/beta",
152+
"shape": [32],
153+
"dtype": "float32"
154+
},
155+
{
156+
"name": "dense_Dense2/kernel",
157+
"shape": [32, 32],
158+
"dtype": "float32"
159+
},
160+
{ "name": "dense_Dense2/bias", "shape": [32], "dtype": "float32" },
161+
{
162+
"name": "batch_normalization_BatchNormalization2/gamma",
163+
"shape": [32],
164+
"dtype": "float32"
165+
},
166+
{
167+
"name": "batch_normalization_BatchNormalization2/beta",
168+
"shape": [32],
169+
"dtype": "float32"
170+
},
171+
{ "name": "dense_Dense3/kernel", "shape": [32, 1], "dtype": "float32" },
172+
{ "name": "dense_Dense3/bias", "shape": [1], "dtype": "float32" },
173+
{
174+
"name": "batch_normalization_BatchNormalization1/moving_mean",
175+
"shape": [32],
176+
"dtype": "float32"
177+
},
178+
{
179+
"name": "batch_normalization_BatchNormalization1/moving_variance",
180+
"shape": [32],
181+
"dtype": "float32"
182+
},
183+
{
184+
"name": "batch_normalization_BatchNormalization2/moving_mean",
185+
"shape": [32],
186+
"dtype": "float32"
187+
},
188+
{
189+
"name": "batch_normalization_BatchNormalization2/moving_variance",
190+
"shape": [32],
191+
"dtype": "float32"
192+
}
193+
]
194+
}
195+
],
196+
"format": "layers-model",
197+
"generatedBy": "TensorFlow.js tfjs-layers v3.18.0",
198+
"convertedBy": null
199+
}

model/ver20220624/weights.bin

69.4 KB
Binary file not shown.

trainer/datasets.ts

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import * as use from '@tensorflow-models/universal-sentence-encoder'
2+
import * as tf from '@tensorflow/tfjs-node'
3+
import { isObjectLiteralElement, sys } from 'typescript'
4+
5+
async function loadUnsmileData({
6+
filepath,
7+
encoder,
8+
}: {
9+
filepath: string
10+
encoder: use.UniversalSentenceEncoder
11+
}): Promise<tf.data.Dataset<tf.TensorContainer>> {
12+
return tf.data
13+
.csv(filepath, {
14+
delimiter: '\t',
15+
hasHeader: true,
16+
configuredColumnsOnly: true,
17+
columnConfigs: {
18+
clean: {
19+
dtype: 'int32',
20+
isLabel: true,
21+
},
22+
문장: {
23+
dtype: 'string',
24+
},
25+
},
26+
})
27+
.mapAsync(async (data: any) => {
28+
const out = await encoder.embed(data.xs['문장'])
29+
return {
30+
xs: out.flatten(),
31+
ys: Object.values(data.ys),
32+
}
33+
})
34+
.batch(32)
35+
.shuffle(32)
36+
}
37+
38+
export async function loadUnsmileTrainValidData(
39+
encoder: use.UniversalSentenceEncoder,
40+
): Promise<{
41+
trainData: tf.data.Dataset<tf.TensorContainer>
42+
valData: tf.data.Dataset<tf.TensorContainer>
43+
}> {
44+
const trainData = await loadUnsmileData({
45+
filepath: getUnsmileDataUrl('train', 'v1.0'),
46+
encoder,
47+
})
48+
const valData = await loadUnsmileData({
49+
filepath: getUnsmileDataUrl('valid', 'v1.0'),
50+
encoder,
51+
})
52+
return { trainData, valData }
53+
}
54+
55+
/**
56+
*
57+
* @param type "train" or "valid"
58+
* @param version "v1.0"
59+
* @returns full url path
60+
*/
61+
function getUnsmileDataUrl(type: string, version: string): string {
62+
return `https://raw.githubusercontent.com/smilegate-ai/korean_unsmile_dataset/main/unsmile_${type}_${version}.tsv`
63+
}

trainer/model.ts

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import * as tf from '@tensorflow/tfjs-node'
2+
import path from 'path'
3+
4+
const FILE_SCHEME = 'file://'
5+
6+
export async function getModel(
7+
modelDirectoryPath: string,
8+
): Promise<tf.LayersModel | tf.Sequential> {
9+
try {
10+
const modelPath =
11+
FILE_SCHEME +
12+
path.join(modelDirectoryPath.replace(FILE_SCHEME, ''), 'model.json')
13+
console.info(`Trying to load a model from ${modelPath}`)
14+
return await tf.loadLayersModel(modelPath)
15+
} catch (e) {
16+
console.warn(e)
17+
console.warn(`Unable to load a model. Creating a new model`)
18+
return tf.sequential({
19+
layers: [
20+
tf.layers.dense({
21+
inputDim: 512,
22+
units: 32,
23+
activation: 'relu',
24+
}),
25+
tf.layers.batchNormalization(),
26+
tf.layers.dense({
27+
units: 32,
28+
activation: 'relu',
29+
}),
30+
tf.layers.batchNormalization(),
31+
tf.layers.dense({
32+
units: 1,
33+
activation: 'sigmoid',
34+
}),
35+
],
36+
})
37+
}
38+
}

trainer/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"description": "",
55
"main": "index.js",
66
"scripts": {
7-
"build": "npx ts-node trainer.ts"
7+
"start": "ts-node trainer.ts"
88
},
99
"keywords": [],
1010
"author": "",

0 commit comments

Comments
 (0)