Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compact memory data (#101) #127

Closed
wants to merge 11 commits into from
1 change: 0 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ tasks.register<JavaExec>("runLinguaOnConsole") {
dependencies {
implementation("org.jetbrains.kotlin:kotlin-stdlib:1.6.0")
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.3.1")
implementation("it.unimi.dsi:fastutil:8.5.6")

testImplementation("org.junit.jupiter:junit-jupiter:5.8.2")
testImplementation("org.assertj:assertj-core:3.21.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ class LanguageDetector internal constructor(
for (elem in ngram.rangeOfLowerOrderNgrams()) {
val probability = lookUpNgramProbability(language, elem)
if (probability > 0) {
probabilitiesSum += ln(probability)
probabilitiesSum += ln(probability.toDouble())
break
}
}
Expand All @@ -432,7 +432,7 @@ class LanguageDetector internal constructor(
internal fun lookUpNgramProbability(
language: Language,
ngram: Ngram
): Double {
): Float {
val ngramLength = ngram.value.length
val languageModels = when (ngramLength) {
5 -> fivegramLanguageModels
Expand Down
19 changes: 7 additions & 12 deletions src/main/kotlin/com/github/pemistahl/lingua/internal/Ngram.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder

@JvmInline
@Serializable(with = NgramSerializer::class)
internal data class Ngram(val value: String) : Comparable<Ngram> {
internal value class Ngram(val value: String) : Comparable<Ngram> {
init {
require(value.length in 0..5) {
"length of ngram '$value' is not in range 0..5"
Expand All @@ -33,20 +34,14 @@ internal data class Ngram(val value: String) : Comparable<Ngram> {

override fun toString() = value

override fun compareTo(other: Ngram) = when {
this.value.length > other.value.length -> 1
this.value.length < other.value.length -> -1
else -> 0
}
override fun compareTo(other: Ngram) = value.length.compareTo(other.value.length)

fun rangeOfLowerOrderNgrams() = NgramRange(this, Ngram(this.value[0].toString()))

operator fun dec(): Ngram = when {
this.value.length > 1 -> Ngram(this.value.substring(0, this.value.length - 1))
this.value.length == 1 -> Ngram("")
else -> throw IllegalStateException(
"Zerogram is ngram type of lowest order and can not be decremented"
)
operator fun dec(): Ngram = when (value.length) {
0 -> error("Zerogram is ngram type of lowest order and can not be decremented")
1 -> Ngram("")
else -> Ngram(this.value.substring(0, this.value.length - 1))
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package com.github.pemistahl.lingua.internal

import com.github.pemistahl.lingua.api.Language
import com.github.pemistahl.lingua.internal.util.extension.incrementCounter
import it.unimi.dsi.fastutil.objects.Object2DoubleMap
import it.unimi.dsi.fastutil.objects.Object2DoubleOpenHashMap
import kotlinx.serialization.Serializable
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import java.util.HashMap
import java.util.TreeMap

@Serializable
internal data class JsonLanguageModel(val language: Language, val ngrams: Map<Fraction, String>)
Expand All @@ -32,9 +32,9 @@ internal data class TrainingDataLanguageModel(
val language: Language,
val absoluteFrequencies: Map<Ngram, Int>,
val relativeFrequencies: Map<Ngram, Fraction>,
val jsonRelativeFrequencies: Object2DoubleMap<String>
val jsonRelativeFrequencies: RelativeFrequencies
) {
fun getRelativeFrequency(ngram: Ngram): Double = jsonRelativeFrequencies.getDouble(ngram.value)
fun getRelativeFrequency(ngram: Ngram): Float = jsonRelativeFrequencies[ngram.value]

fun toJson(): String {
val ngrams = mutableMapOf<Fraction, MutableList<Ngram>>()
Expand Down Expand Up @@ -77,29 +77,27 @@ internal data class TrainingDataLanguageModel(
language,
absoluteFrequencies,
relativeFrequencies,
Object2DoubleOpenHashMap()
RelativeFrequencies.build(emptySequence())
)
}

fun fromJson(json: String): TrainingDataLanguageModel {
val jsonLanguageModel = Json.decodeFromString<JsonLanguageModel>(json)
val jsonRelativeFrequencies = Object2DoubleOpenHashMap<String>()

for ((fraction, ngrams) in jsonLanguageModel.ngrams) {
val fractionAsDouble = fraction.toDouble()
for (ngram in ngrams.split(' ')) {
jsonRelativeFrequencies.put(ngram, fractionAsDouble)
val jsonDataSequence = sequence {
for ((fraction, ngrams) in jsonLanguageModel.ngrams) {
val fractionAsFloat = fraction.toFloat()
for (ngram in ngrams.split(' ')) {
yield(ngram to fractionAsFloat)
}
}
}

// Trim to reduce in-memory model size
jsonRelativeFrequencies.trim()

return TrainingDataLanguageModel(
language = jsonLanguageModel.language,
absoluteFrequencies = emptyMap(),
relativeFrequencies = emptyMap(),
jsonRelativeFrequencies = jsonRelativeFrequencies
jsonRelativeFrequencies = RelativeFrequencies.build(jsonDataSequence)
)
}

Expand Down Expand Up @@ -147,4 +145,94 @@ internal data class TrainingDataLanguageModel(
return ngramProbabilities
}
}

/**
 * Memory-compact, read-only lookup table from ngram to relative frequency.
 *
 * Ngrams are grouped by a key made of their length and the high byte of each
 * character (see [computeHighHash]). Inside a group only the low byte of each
 * character is stored, packed back to back in a [ByteArray] sorted by unsigned
 * low-byte order, with the matching frequencies in a parallel [FloatArray].
 *
 * Instances are created through [build].
 */
internal class RelativeFrequencies private constructor(private val data: Map<Long, Entries>) {

    /**
     * Returns the frequency stored for [ngram], or `0F` when the ngram is unknown.
     */
    operator fun get(ngram: String): Float = data[computeHighHash(ngram)]?.get(ngram) ?: 0F

    /**
     * One group of ngrams sharing the same length and the same high bytes.
     *
     * @param chars low bytes of all ngrams of the group, concatenated in sorted order
     * @param frequencies frequency of the ngram at the same position in [chars]
     */
    private class Entries(private val chars: ByteArray, private val frequencies: FloatArray) {

        val size get() = frequencies.size

        /**
         * Looks up [ngram] in this group and returns its frequency, or `0F` if absent.
         * The caller must pass an ngram whose length matches the group's ngram length.
         */
        operator fun get(ngram: String): Float {
            var low = 0
            var high = size - 1

            // Bisection search while the candidate window holds more than 9 entries.
            while (low + 8 < high) {
                val middle = (low + high) ushr 1
                val diff = compareNgram(middle, ngram)
                when {
                    diff < 0 -> low = middle + 1
                    diff > 0 -> high = middle - 1
                    else -> return frequencies[middle]
                }
            }

            // Linear search over the remaining window; every candidate must be
            // inspected before the ngram can be reported as missing.
            for (i in low..high) {
                if (compareNgram(i, ngram) == 0) return frequencies[i]
            }
            return 0F
        }

        /**
         * Compares the stored ngram at index [pos] with [ngram], looking at the low
         * byte of each character only. Bytes are widened as unsigned values so the
         * ordering matches the one used to sort the group when it was built.
         *
         * @return negative, zero or positive if the stored ngram is less than,
         * equal to or greater than [ngram]
         */
        private fun compareNgram(pos: Int, ngram: String): Int {
            val base = pos * ngram.length
            for (i in ngram.indices) {
                val stored = chars[base + i].toInt() and 0xFF
                val diff = stored - (ngram[i].code and 0xFF)
                if (diff != 0) return diff
            }
            return 0
        }
    }

    companion object {

        /**
         * Orders strings by the low byte (unsigned) of each character.
         * Both strings must have the same length.
         */
        private object LowByteComparator : Comparator<String> {
            override fun compare(o1: String, o2: String): Int {
                for (i in o1.indices) {
                    val res = (o1[i].code and 0xFF) - (o2[i].code and 0xFF)
                    if (res != 0) return res
                }
                return 0
            }
        }

        /**
         * Builds the compact table from `(ngram, frequency)` pairs.
         * If an ngram occurs more than once, the last frequency wins.
         */
        internal fun build(relativeFrequencies: Sequence<Pair<String, Float>>): RelativeFrequencies {
            // Group by high-byte key; each group is kept sorted by low bytes.
            val groups = LinkedHashMap<Long, TreeMap<String, Float>>()
            for ((ngram, frequency) in relativeFrequencies) {
                val group = groups.getOrPut(computeHighHash(ngram)) { TreeMap<String, Float>(LowByteComparator) }
                group[ngram] = frequency
            }

            val data = HashMap<Long, Entries>()
            for ((highHash, group) in groups) {
                // All ngrams of a group have the same length because the length is part of the key.
                val ngramLength = group.firstKey().length
                val chars = ByteArray(group.size * ngramLength)
                val frequencies = FloatArray(group.size)
                var index = 0
                for ((ngram, frequency) in group) {
                    val base = index * ngramLength
                    for (i in 0 until ngramLength) {
                        // Keep the low byte only; the high byte is encoded in the group key.
                        chars[base + i] = (ngram[i].code and 0xFF).toByte()
                    }
                    frequencies[index] = frequency
                    index++
                }
                data[highHash] = Entries(chars, frequencies)
            }

            return RelativeFrequencies(data)
        }

        /**
         * Computes a key from the ngram length followed by the high byte of each
         * character, eight bits per character.
         * Max ngram supported length: 7.
         */
        private fun computeHighHash(ngram: String): Long {
            var hash = ngram.length.toLong()
            for (c in ngram) {
                hash = (hash shl 8) or (c.code shr 8).toLong()
            }
            return hash
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,4 @@

package com.github.pemistahl.lingua.internal.util.extension

internal fun <T> MutableMap<T, Int>.incrementCounter(key: T) {
this[key] = this.getOrDefault(key, 0) + 1
}
/**
 * Adds one to the counter stored under [key]; a key that is not yet present starts at 1.
 * The value produced by the underlying `merge` call is handed back to the caller.
 */
internal fun <T> MutableMap<T, Int>.incrementCounter(key: T) = merge(key, 1) { count, step -> count + step }
Loading