Skip to content

Commit

Permalink
Make FSTCompiler.compile() to only return the FSTMetadata (apache#12831)
Browse files Browse the repository at this point in the history
* Make FSTCompiler.compile() to only return the FSTMetadata

* tidy code
  • Loading branch information
dungba88 authored Feb 5, 2024
1 parent c02f547 commit 63d4ba9
Show file tree
Hide file tree
Showing 35 changed files with 149 additions and 90 deletions.
4 changes: 4 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ API Changes
* GITHUB#12875: Ensure token position is always increased in PathHierarchyTokenizer and ReversePathHierarchyTokenizer
and resulting tokens do not overlap. (Michael Froh, Lukáš Vlček)

* GITHUB#12624, GITHUB#12831: Allow FSTCompiler to stream to any DataOutput while building, and
make compile() only return the FSTMetadata. For on-heap (default) use case, please use
FST.fromFSTReader(fstMetadata, fstCompiler.getFSTReader()) to create the FST. (Anh Dung Bui)

New Features
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public NormalizeCharMap build() {
for (Map.Entry<String, String> ent : pendingPairs.entrySet()) {
fstCompiler.add(Util.toUTF16(ent.getKey(), scratch), new CharsRef(ent.getValue()));
}
map = fstCompiler.compile();
map = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
pendingPairs.clear();
} catch (IOException ioe) {
// Bogus FST IOExceptions!! (will never happen)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ConvTable {
fstCompiler.add(scratchInts.get(), new CharsRef(entry.getValue()));
}

fst = fstCompiler.compile();
fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
} catch (IOException bogus) {
throw new RuntimeException(bogus);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ private FST<IntsRef> affixFST(TreeMap<String, List<Integer>> affixes) throws IOE
}
fstCompiler.add(scratch.get(), output);
}
return fstCompiler.compile();
return FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ public StemmerOverrideMap build() throws IOException {
intsSpare.copyUTF8Bytes(bytesRef);
fstCompiler.add(intsSpare.get(), new BytesRef(outputValues.get(id)));
}
return new StemmerOverrideMap(fstCompiler.compile(), ignoreCase);
return new StemmerOverrideMap(
FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader()), ignoreCase);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ public SynonymMap build() throws IOException {
fstCompiler.add(Util.toUTF32(input, scratchIntsRef), scratch.toBytesRef());
}

FST<BytesRef> fst = fstCompiler.compile();
FST<BytesRef> fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
return new SynonymMap(fst, words, maxHorizontalContext);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ private TokenInfoDictionaryWriter buildDictionary(List<Path> csvFiles) throws IO
dictionary.addMapping((int) ord, offset);
offset = next;
}
dictionary.setFST(fstCompiler.compile());
dictionary.setFST(FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader()));
return dictionary;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ public int compare(String[] left, String[] right) {
segmentations.add(wordIdAndLength);
ord++;
}
this.fst = new TokenInfoFST(fstCompiler.compile(), false);
this.fst =
new TokenInfoFST(
FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader()), false);
this.morphAtts = new UserMorphData(data.toArray(new String[0]));
this.segmentations = segmentations.toArray(new int[segmentations.size()][]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ private TokenInfoDictionaryWriter buildDictionary(List<Path> csvFiles) throws IO
dictionary.addMapping((int) ord, offset);
offset = next;
}
dictionary.setFST(fstCompiler.compile());
dictionary.setFST(FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader()));
return dictionary;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ private UserDictionary(List<String> entries) throws IOException {
lastToken = token;
ord++;
}
this.fst = new TokenInfoFST(fstCompiler.compile());
this.fst =
new TokenInfoFST(FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader()));
int[][] segmentations = _segmentations.toArray(new int[_segmentations.size()][]);
short[] rightIds = new short[_rightIds.size()];
for (int i = 0; i < _rightIds.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ public void compileIndex(
}
}

index = fstCompiler.compile();
index = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());

assert subIndices == null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ private void updateFST(SortedMap<String, Double> weights) throws IOException {
fstCompiler.add(
Util.toIntsRef(scratchBytes.get(), scratchInts), entry.getValue().longValue());
}
fst = fstCompiler.compile();
fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ public void add(BytesRef text, TermStats stats, long termsFilePointer) throws IO

@Override
public void finish(long termsFilePointer) throws IOException {
fst = fstCompiler.compile();
fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
if (fst != null) {
fst.save(out, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ public void compileIndex(

assert sumTotalTermCount == totFloorTermCount;

index = fstCompiler.compile();
index = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
assert subIndices == null;

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ public void finishTerm(BytesRef text, BlockTermState state) throws IOException {
public void finish(long sumTotalTermFreq, long sumDocFreq, int docCount) throws IOException {
// save FST dict
if (numTerms > 0) {
final FST<FSTTermOutputs.TermData> fst = fstCompiler.compile();
final FST<FSTTermOutputs.TermData> fst =
FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
fields.add(
new FieldMetaData(fieldInfo, numTerms, sumTotalTermFreq, sumDocFreq, docCount, fst));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ private void loadTerms() throws IOException {
}
}
docCount = visitedDocs.cardinality();
fst = fstCompiler.compile();
fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
/*
PrintStream ps = new PrintStream("out.dot");
fst.toDot(ps);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ public void add(BytesRef blockKey, long blockFilePointer) throws IOException {

@Override
public FSTDictionary build() throws IOException {
return new FSTDictionary(fstCompiler.compile());
return new FSTDictionary(
FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ public void compileIndex(
}
}

index = fstCompiler.compile();
index = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());

assert subIndices == null;

Expand Down
19 changes: 16 additions & 3 deletions lucene/core/src/java/org/apache/lucene/util/fst/FST.java
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,21 @@ public FST(FSTMetadata<T> metadata, DataInput in, FSTStore fstStore) throws IOEx
this.fstReader = fstReader;
}

/**
* Create a FST from a {@link FSTReader}. Return null if the metadata is null.
*
* @param fstMetadata the metadata
* @param fstReader the FSTReader
* @return the FST
*/
public static <T> FST<T> fromFSTReader(FSTMetadata<T> fstMetadata, FSTReader fstReader) {
// FSTMetadata could be null if there is no node accepted by the FST
if (fstMetadata == null) {
return null;
}
return new FST<>(fstMetadata, Objects.requireNonNull(fstReader, "FSTReader cannot be null"));
}

/**
* Read the FST metadata from DataInput
*
Expand Down Expand Up @@ -516,9 +531,7 @@ public FSTMetadata<T> getMetadata() {
}

/**
* Save the FST to DataOutput. If you use an {@link org.apache.lucene.store.IndexOutput} to build
* the FST, then you should not and do not need to call this method, as the FST is already saved.
* Doing so will throw an {@link UnsupportedOperationException}.
* Save the FST to DataOutput.
*
* @param metaOut the DataOutput to write the metadata to
* @param out the DataOutput to write the FST bytes to
Expand Down
62 changes: 44 additions & 18 deletions lucene/core/src/java/org/apache/lucene/util/fst/FSTCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public class FSTCompiler<T> {
private static final FSTReader NULL_FST_READER = new NullFSTReader();

private final NodeHash<T> dedupHash;
// a temporary FST used during building for NodeHash cache
final FST<T> fst;
private final T NO_OUTPUT;

Expand Down Expand Up @@ -173,9 +174,7 @@ private FSTCompiler(
paddingBytePending = true;
this.dataOutput = dataOutput;
fst =
new FST<>(
new FST.FSTMetadata<>(inputType, outputs, null, -1, version, 0),
toFSTReader(dataOutput));
new FST<>(new FST.FSTMetadata<>(inputType, outputs, null, -1, version, 0), NULL_FST_READER);
if (suffixRAMLimitMB < 0) {
throw new IllegalArgumentException("ramLimitMB must be >= 0; got: " + suffixRAMLimitMB);
} else if (suffixRAMLimitMB > 0) {
Expand All @@ -193,16 +192,6 @@ private FSTCompiler(
}
}

// Get the respective FSTReader of the DataOutput. If the DataOutput is also a FSTReader then we
// will use it, otherwise we will return a NullFSTReader. Attempting to read from a FST with
// NullFSTReader will throw UnsupportedOperationException
private FSTReader toFSTReader(DataOutput dataOutput) {
if (dataOutput instanceof FSTReader) {
return (FSTReader) dataOutput;
}
return NULL_FST_READER;
}

/**
* This class is used for FST backed by non-FSTReader DataOutput. It does not allow getting the
* reverse BytesReader nor writing to a DataOutput.
Expand All @@ -227,6 +216,22 @@ public void writeTo(DataOutput out) {
}
}

/**
* Get the respective {@link FSTReader} of the {@link DataOutput}. To call this method, you need
* to use the default DataOutput or {@link #getOnHeapReaderWriter(int)}, otherwise we will throw
* an exception.
*
* @return the DataOutput as FSTReader
* @throws IllegalStateException if the DataOutput does not implement FSTReader
*/
public FSTReader getFSTReader() {
if (dataOutput instanceof FSTReader) {
return (FSTReader) dataOutput;
}
throw new IllegalStateException(
"The DataOutput must implement FSTReader, but got " + dataOutput);
}

/**
* Fluent-style constructor for FST {@link FSTCompiler}.
*
Expand Down Expand Up @@ -967,10 +972,31 @@ private boolean validOutput(T output) {
return output == NO_OUTPUT || !output.equals(NO_OUTPUT);
}

/** Returns final FST. NOTE: this will return null if nothing is accepted by the FST. */
// TODO: make this method to only return the FSTMetadata and user needs to construct the FST
// themselves
public FST<T> compile() throws IOException {
/**
* Returns the metadata of the final FST. NOTE: this will return null if nothing is accepted by
* the FST themselves.
*
* <p>To create the FST, you need to:
*
* <p>- If a FSTReader DataOutput was used, such as the one returned by {@link
* #getOnHeapReaderWriter(int)}
*
* <pre class="prettyprint">
* fstMetadata = fstCompiler.compile();
* fst = FST.fromFSTReader(fstMetadata, fstCompiler.getFSTReader());
* </pre>
*
* <p>- If a non-FSTReader DataOutput was used, such as {@link
* org.apache.lucene.store.IndexOutput}, you need to first create the corresponding {@link
* org.apache.lucene.store.DataInput}, such as {@link org.apache.lucene.store.IndexInput} then
* pass it to the FST construct
*
* <pre class="prettyprint">
* fstMetadata = fstCompiler.compile();
* fst = new FST&lt;&gt;(fstMetadata, dataInput, new OffHeapFSTStore());
* </pre>
*/
public FST.FSTMetadata<T> compile() throws IOException {

final UnCompiledNode<T> root = frontier[0];

Expand All @@ -990,7 +1016,7 @@ public FST<T> compile() throws IOException {
// root.output=" + root.output);
finish(compileNode(root).node);

return fst;
return fst.metadata;
}

/** Expert: holds a pending (seen but not yet serialized) arc. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
* scratchBytes.copyChars(inputValues[i]);
* fstCompiler.add(Util.toIntsRef(scratchBytes.toBytesRef(), scratchInts), outputValues[i]);
* }
* FST&lt;Long&gt; fst = fstCompiler.compile();
* FST&lt;Long&gt; fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
* </pre>
*
* Retrieval by key:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void test() throws Exception {
nextInput(r, ints2);
}

FST<Object> fst = fstCompiler.compile();
FST<Object> fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());

for (int verify = 0; verify < 2; verify++) {
System.out.println(
Expand Down Expand Up @@ -183,7 +183,7 @@ public void test() throws Exception {
nextInput(r, ints);
}

FST<BytesRef> fst = fstCompiler.compile();
FST<BytesRef> fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
for (int verify = 0; verify < 2; verify++) {

System.out.println(
Expand Down Expand Up @@ -273,7 +273,7 @@ public void test() throws Exception {
nextInput(r, ints);
}

FST<Long> fst = fstCompiler.compile();
FST<Long> fst = FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());

for (int verify = 0; verify < 2; verify++) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ public void test() throws Exception {
nextInput(r, ints2);
}

FST<Object> fst = fstCompiler.compile();
FST.FSTMetadata<Object> fstMetadata = fstCompiler.compile();
indexOutput.close();
try (IndexInput indexInput = dir.openInput("fst", IOContext.DEFAULT)) {
fst = new FST<>(fst.getMetadata(), indexInput, new OffHeapFSTStore());
FST<Object> fst = new FST<>(fstMetadata, indexInput, new OffHeapFSTStore());

for (int verify = 0; verify < 2; verify++) {
System.out.println(
Expand Down Expand Up @@ -180,10 +180,10 @@ public void test() throws Exception {
nextInput(r, ints);
}

FST<BytesRef> fst = fstCompiler.compile();
FST.FSTMetadata<BytesRef> fstMetadata = fstCompiler.compile();
indexOutput.close();
try (IndexInput indexInput = dir.openInput("fst", IOContext.DEFAULT)) {
fst = new FST<>(fst.getMetadata(), indexInput, new OffHeapFSTStore());
FST<BytesRef> fst = new FST<>(fstMetadata, indexInput, new OffHeapFSTStore());
for (int verify = 0; verify < 2; verify++) {

System.out.println(
Expand Down Expand Up @@ -265,10 +265,10 @@ public void test() throws Exception {
nextInput(r, ints);
}

FST<Long> fst = fstCompiler.compile();
FST.FSTMetadata<Long> fstMetadata = fstCompiler.compile();
indexOutput.close();
try (IndexInput indexInput = dir.openInput("fst", IOContext.DEFAULT)) {
fst = new FST<>(fst.getMetadata(), indexInput, new OffHeapFSTStore());
FST<Long> fst = new FST<>(fstMetadata, indexInput, new OffHeapFSTStore());

for (int verify = 0; verify < 2; verify++) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ private static FST<Object> buildFST(List<BytesRef> entries, FSTCompiler<Object>
}
last = entry;
}
return fstCompiler.compile();
return FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
}

public static void main(String... args) throws Exception {
Expand Down Expand Up @@ -333,7 +333,7 @@ private static FST<CharsRef> recompile(FST<CharsRef> fst, float oversizingFactor
while ((inputOutput = fstEnum.next()) != null) {
fstCompiler.add(inputOutput.input, CharsRef.deepCopyOf(inputOutput.output));
}
return fstCompiler.compile();
return FST.fromFSTReader(fstCompiler.compile(), fstCompiler.getFSTReader());
}

private static int walk(FST<CharsRef> read) throws IOException {
Expand Down
Loading

0 comments on commit 63d4ba9

Please sign in to comment.