// Copyright 2013 trananh
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import java.io._
import scala.Array
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/** A simple binary file reader.
 * @constructor Create a binary file reader.
 * @param file The binary file to be read.
 *
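 * A minimal usage sketch (the file path is illustrative; assumes a word2vec-style binary file
 * whose header line is `numWords vecSize` followed by a line feed):
 * {{{
 * val reader = new VecBinaryReader("vectors.bin")
 * val numWords = Integer.parseInt(reader.readToken())
 * val vecSize = Integer.parseInt(reader.readToken())
 * reader.close()
 * }}}
 *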
 * @author trananh
 */
class VecBinaryReader(val file: File) {

  /** Overloaded constructor */
  def this(filename: String) = this(new File(filename))

  /** ASCII values for common delimiter characters */
  private val SPACE = 32
  private val LF = 10

  /** Open input streams */
  private val fis = new FileInputStream(file)
  private val bis = new BufferedInputStream(fis)
  private val dis = new DataInputStream(bis)

  /** Close the underlying streams. */
  def close() { dis.close(); bis.close(); fis.close() }

  /** Read the next byte.
   * @return The next byte from the file.
   */
  def read(): Byte = dis.readByte()

  /** Read the next token as a string, using the provided delimiters as breaking points.
   * @param delimiters ASCII codes of the delimiter characters (default to SPACE and LINE-FEED).
   * @return String representation of the next token.
   */
  def readToken(delimiters: Set[Int] = Set(SPACE, LF)): String = {
    val bytes = new ArrayBuffer[Byte]()
    val sb = new StringBuilder()
    var byte = dis.readByte()
    while (!delimiters.contains(byte)) {
      bytes.append(byte)
      byte = dis.readByte()
    }
    sb.append(new String(bytes.toArray[Byte])).toString()
  }

  /** Read the next 4 bytes as a floating-point number.
   * @return The floating-point value of the next 4 bytes.
   */
  def readFloat(): Float = {
    // Reverse the byte order: the word2vec binary format stores little-endian floats,
    // while DataInputStream reads big-endian.
    java.lang.Float.intBitsToFloat(java.lang.Integer.reverseBytes(dis.readInt()))
  }

}

/** A Scala port of the word2vec model. This interface allows the user to access the vector representations
 * output by the word2vec tool, as well as perform some common operations on those vectors. It does NOT
 * implement the actual continuous bag-of-words and skip-gram architectures for computing the vectors.
 *
 * More information on word2vec can be found here: https://code.google.com/p/word2vec/
 *
 * Example usage:
 * {{{
 * val model = new Word2Vec()
 * model.load("vectors.bin")
 * val results = model.distance(List("france"), N = 10)
 *
 * model.pprint(results)
 * }}}
 *
 * @constructor Create a word2vec model.
 *
 * @author trananh
 */
class Word2Vec {

  /** Map of words and their associated vector representations */
  private val vocab = new mutable.HashMap[String, Array[Float]]()

  /** Number of words */
  private var numWords = 0

  /** Number of floating-point values associated with each word (i.e., length of the vectors) */
  private var vecSize = 0

  /** Load data from a binary file.
   * @param filename Path to file containing word projections in the BINARY FORMAT.
   * @param limit Maximum number of words to load from file (a.k.a. max vocab size).
   * @param normalize Normalize the loaded vectors if true (default to true).
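   *
   * A minimal usage sketch (the path and the vocab limit are illustrative):
   * {{{
   * val model = new Word2Vec()
   * model.load("vectors.bin", limit = 100000)
   * }}}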
   */
  def load(filename: String, limit: Integer = Int.MaxValue, normalize: Boolean = true): Unit = {
    // Check edge case
    val file = new File(filename)
    if (!file.exists()) {
      throw new FileNotFoundException("Binary vector file not found <" + file.toString + ">")
    }

    // Create new reader to read data
    val reader = new VecBinaryReader(file)

    // Read header info
    numWords = Integer.parseInt(reader.readToken())
    vecSize = Integer.parseInt(reader.readToken())
    println("\nFile contains " + numWords + " words with vector size " + vecSize)

    // Read the vocab words and their associated vector representations
    var word = ""
    val vector = new Array[Float](vecSize)
    var normFactor = 1f
    for (_ <- 0 until math.min(numWords, limit)) {
      // Read the word
      word = reader.readToken()

      // Read the vector representation (each vector contains vecSize number of floats)
      for (i <- 0 until vector.length) vector(i) = reader.readFloat()

      // Store the normalized vector representation, keyed by the word
      normFactor = if (normalize) magnitude(vector).toFloat else 1f
      vocab.put(word, vector.map(_ / normFactor))

      // Eat up the next delimiter character
      reader.read()
    }
    println("Loaded " + math.min(numWords, limit) + " words.\n")

    // Finally, close the reader
    reader.close()
  }

  /** Return the number of words in the vocab.
   * @return Number of words in the vocab.
   */
  def wordsCount: Int = numWords

  /** Size of the vectors.
   * @return Size of the vectors.
   */
  def vectorSize: Int = vecSize

  /** Clear internal data. */
  def clear() {
    vocab.clear()
    numWords = 0
    vecSize = 0
  }

  /** Check if the word is present in the vocab map.
   * @param word Word to be checked.
   * @return True if the word is in the vocab map.
   */
  def contains(word: String): Boolean = {
    vocab.get(word).isDefined
  }

  /** Get the vector representation for the word.
   * @param word Word to retrieve the vector for.
   * @return The vector representation of the word, or an empty array if the word is not in the vocab.
   */
  def vector(word: String): Array[Float] = {
    vocab.getOrElse(word, Array[Float]())
  }

  /** Compute the Euclidean distance between two vectors.
   * @param vec1 The first vector.
   * @param vec2 The other vector.
   * @return The Euclidean distance between the two vectors.
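   *
   * Worked example (illustrative values): for vec1 = (1, 0) and vec2 = (0, 1), the distance is
   * sqrt((1 - 0)^2 + (0 - 1)^2) = sqrt(2), approximately 1.414.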
   */
  def euclidean(vec1: Array[Float], vec2: Array[Float]): Double = {
    assert(vec1.length == vec2.length, "Uneven vectors!")
    var sum = 0.0
    for (i <- 0 until vec1.length) sum += math.pow(vec1(i) - vec2(i), 2)
    math.sqrt(sum)
  }

  /** Compute the Euclidean distance between the vector representations of the words.
   * @param word1 The first word.
   * @param word2 The other word.
   * @return The Euclidean distance between the vector representations of the words.
   */
  def euclidean(word1: String, word2: String): Double = {
    assert(contains(word1) && contains(word2), "Out of dictionary word! " + word1 + " or " + word2)
    euclidean(vocab.get(word1).get, vocab.get(word2).get)
  }

  /** Compute the cosine similarity score between two vectors.
   * @param vec1 The first vector.
   * @param vec2 The other vector.
   * @return The cosine similarity score of the two vectors.
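   *
   * Worked example (illustrative values): vectors pointing in the same direction score approximately 1.0,
   * and orthogonal vectors score 0.0, e.g. cosine(Array(1f, 2f, 0f), Array(2f, 4f, 0f)) is approximately 1.0
   * (up to floating-point rounding) and cosine(Array(1f, 0f), Array(0f, 1f)) == 0.0.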
   */
  def cosine(vec1: Array[Float], vec2: Array[Float]): Double = {
    assert(vec1.length == vec2.length, "Uneven vectors!")
    var dot, sum1, sum2 = 0.0
    for (i <- 0 until vec1.length) {
      dot += (vec1(i) * vec2(i))
      sum1 += (vec1(i) * vec1(i))
      sum2 += (vec2(i) * vec2(i))
    }
    dot / (math.sqrt(sum1) * math.sqrt(sum2))
  }

  /** Compute the cosine similarity score between the vector representations of the words.
   * @param word1 The first word.
   * @param word2 The other word.
   * @return The cosine similarity score between the vector representations of the words.
   */
  def cosine(word1: String, word2: String): Double = {
    assert(contains(word1) && contains(word2), "Out of dictionary word! " + word1 + " or " + word2)
    cosine(vocab.get(word1).get, vocab.get(word2).get)
  }

  /** Compute the magnitude of the vector.
   * @param vec The vector.
   * @return The magnitude of the vector.
   */
  def magnitude(vec: Array[Float]): Double = {
    math.sqrt(vec.foldLeft(0.0) { (sum, x) => sum + (x * x) })
  }

  /** Normalize the vector.
   * @param vec The vector.
   * @return A normalized vector.
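   *
   * A small sketch (here `w2v` is any Word2Vec instance; normalization does not require a loaded model):
   * {{{
   * val w2v = new Word2Vec()
   * val unit = w2v.normalize(Array(3f, 4f))   // (0.6, 0.8); magnitude ~= 1.0, original magnitude is 5.0
   * }}}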
   */
  def normalize(vec: Array[Float]): Array[Float] = {
    val mag = magnitude(vec).toFloat
    vec.map(_ / mag)
  }

  /** Find the vector representation for the given list of word(s) by aggregating (summing) the
   * vector for each word.
   * @param input The input word(s).
   * @return The sum vector (aggregated from the input vectors).
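   *
   * A minimal sketch (`model` is a loaded Word2Vec instance; the words are illustrative and must all
   * be in the vocab):
   * {{{
   * val v = model.sumVector(List("france", "usa"))   // component-wise sum of the two word vectors
   * }}}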
   */
  def sumVector(input: List[String]): Array[Float] = {
    // Find the vector representation for the input. If multiple words, then aggregate (sum) their vectors.
    input.foreach(w => assert(contains(w), "Out of dictionary word! " + w))
    val vector = new Array[Float](vecSize)
    input.foreach(w => for (j <- 0 until vector.length) vector(j) += vocab.get(w).get(j))
    vector
  }

  /** Find the N closest terms in the vocab to the given vector, using only words from the in-set (if defined)
   * and excluding all words from the out-set (if non-empty). Defining both an in-set and an out-set is
   * allowed, but rarely useful.
   * @param vector The vector.
   * @param inSet Set of words to consider. Specify None to use all words in the vocab (default behavior).
   * @param outSet Set of words to exclude (default to empty).
   * @param N The maximum number of terms to return (default to 40).
   * @return The N closest terms in the vocab to the given vector and their associated cosine similarity scores.
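   *
   * A minimal sketch restricting the search to an illustrative candidate set (`model` is a loaded
   * Word2Vec instance and "france" is assumed to be in the vocab):
   * {{{
   * val hits = model.nearestNeighbors(model.vector("france"),
   *                                   inSet = Option(Set("spain", "italy", "germany")), N = 2)
   * }}}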
   */
  def nearestNeighbors(vector: Array[Float], inSet: Option[Set[String]] = None,
                       outSet: Set[String] = Set[String](), N: Integer = 40)
    : List[(String, Float)] = {
    // For performance efficiency, we maintain the top/closest terms using a priority queue.
    // Note: We invert the distance here because a priority queue will dequeue the highest priority element,
    // but we would like it to dequeue the lowest scoring element instead.
    val top = new mutable.PriorityQueue[(String, Float)]()(Ordering.by(-_._2))

    // Iterate over each token in the vocab and compute its cosine score to the input.
    var dist = 0f
    val iterator = if (inSet.isDefined) vocab.filterKeys(k => inSet.get.contains(k)).iterator else vocab.iterator
    iterator.foreach(entry => {
      // Skip tokens in the out set
      if (!outSet.contains(entry._1)) {
        dist = cosine(vector, entry._2).toFloat
        if (top.size < N || top.head._2 < dist) {
          top.enqueue((entry._1, dist))
          if (top.length > N) {
            // If the queue contains over N elements, then dequeue the highest priority element
            // (which will be the element with the lowest cosine score).
            top.dequeue()
          }
        }
      }
    })

    // Return the top N results as a sorted list.
    assert(top.length <= N)
    top.toList.sortWith(_._2 > _._2)
  }

  /** Find the N closest terms in the vocab to the input word(s).
   * @param input The input word(s).
   * @param N The maximum number of terms to return (default to 40).
   * @return The N closest terms in the vocab to the input word(s) and their associated cosine similarity scores.
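   *
   * A short sketch (the words are illustrative; any out-of-vocabulary input returns an empty list):
   * {{{
   * model.pprint(model.distance(List("paris", "berlin"), N = 5))
   * }}}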
   */
  def distance(input: List[String], N: Integer = 40): List[(String, Float)] = {
    // Check for edge cases
    if (input.size == 0) return List[(String, Float)]()
    input.foreach(w => {
      if (!contains(w)) {
        println("Out of dictionary word! " + w)
        return List[(String, Float)]()
      }
    })

    // Find the vector representation for the input. If multiple words, then aggregate (sum) their vectors.
    val vector = sumVector(input)

    nearestNeighbors(normalize(vector), outSet = input.toSet, N = N)
  }

  /** Find the N closest terms in the vocab to the analogy:
   * - [word1] is to [word2] as [word3] is to ???
   *
   * The algorithm operates as follows:
   * - Find a vector approximation of the missing word = vec([word2]) - vec([word1]) + vec([word3]).
   * - Return the words closest to the approximated vector.
   *
   * @param word1 First word in the analogy [word1] is to [word2] as [word3] is to ???.
   * @param word2 Second word in the analogy [word1] is to [word2] as [word3] is to ???.
   * @param word3 Third word in the analogy [word1] is to [word2] as [word3] is to ???.
   * @param N The maximum number of terms to return (default to 40).
   *
   * @return The N closest terms in the vocab to the analogy and their associated cosine similarity scores.
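   *
   * A short sketch (the words are illustrative; with typical pretrained vectors, "woman" tends to rank
   * near the top of the results):
   * {{{
   * model.pprint(model.analogy("king", "queen", "man", N = 10))
   * }}}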
   */
  def analogy(word1: String, word2: String, word3: String, N: Integer = 40): List[(String, Float)] = {
    // Check for edge cases
    if (!contains(word1) || !contains(word2) || !contains(word3)) {
      println("Out of dictionary word! " + Array(word1, word2, word3).mkString(" or "))
      return List[(String, Float)]()
    }

    // Find the vector approximation for the missing analogy.
    val vector = new Array[Float](vecSize)
    for (j <- 0 until vector.length)
      vector(j) = vocab.get(word2).get(j) - vocab.get(word1).get(j) + vocab.get(word3).get(j)

    nearestNeighbors(normalize(vector), outSet = Set(word1, word2, word3), N = N)
  }

  /** Rank a set of words by their respective distance to some central term.
   * @param word The central word.
   * @param set Set of words to rank.
   * @return Ordered list of words and their associated scores.
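   *
   * A short sketch (the words are illustrative; all of them must be in the vocab):
   * {{{
   * model.pprint(model.rank("apple", Set("orange", "soda", "lettuce")))
   * }}}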
   */
  def rank(word: String, set: Set[String]): List[(String, Float)] = {
    // Check for edge cases
    if (set.size == 0) return List[(String, Float)]()
    (set + word).foreach(w => {
      if (!contains(w)) {
        println("Out of dictionary word! " + w)
        return List[(String, Float)]()
      }
    })

    nearestNeighbors(vocab.get(word).get, inSet = Option(set), N = set.size)
  }

  /** Pretty print the list of words and their associated scores.
   * @param words List of (word, score) pairs to be printed.
   */
  def pprint(words: List[(String, Float)]) = {
    println("\n%50s".format("Word") + (" " * 7) + "Cosine distance\n" + ("-" * 72))
    println(words.map(s => "%50s".format(s._1) + (" " * 7) + "%15f".format(s._2)).mkString("\n"))
  }

}


/** ********************************************************************************
 * Demo of the Scala port of the word2vec model.
 * ********************************************************************************
 */
object RunWord2Vec {

  /** Demo. */
  def main(args: Array[String]) {
    // Load the word2vec model from a binary file.
    val model = new Word2Vec()
    model.load("../word2vec-scala/vectors.bin")

    // distance: Find N closest words
    model.pprint(model.distance(List("france"), N = 10))
    model.pprint(model.distance(List("france", "usa")))
    model.pprint(model.distance(List("france", "usa", "usa")))

    // analogy: "king" is to "queen", as "man" is to ?
    model.pprint(model.analogy("king", "queen", "man", N = 10))

    // rank: Rank a set of words by their respective distance to the central term
    model.pprint(model.rank("apple", Set("orange", "soda", "lettuce")))
  }
}