Skip to content

Commit

Permalink
Updating weaviate content retreiever
Browse files Browse the repository at this point in the history
  • Loading branch information
Jaland committed Nov 11, 2024
1 parent 6daa682 commit c20fefa
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 49 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ https://drive.google.com/drive/folders/1jZe0cEw8p_E-fghd6IFPjwiabDNAhtp7?usp=dri

We need to figure out a better way to handel this in the future (or put this on github)

## Tech Debt
## Code Standards

Figure out better secret management
### OWasp Security Scanning

The [OWASP Dependency-Check Plugin](https://owasp.org/www-project-dependency-check/) can be run using the following command:

```sh
mvn validate -P security-scanner
```
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ public ContentRetriever getContentRetriever(RetrieverRequest request) {
.scheme(scheme)
.host(host)
.apiKey(apiKey)
.metadataParentKey("")
.metadataFieldName("")
// .metadataFieldName(null)
.metadataKeys(weaviateRequest.getMetadataFields())
.objectClass(index)
.avoidDups(true)
.textKey(textKey)
.textFieldName(textKey)
.build();


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,15 @@
public class WeaviateEmbeddingStoreCustom implements EmbeddingStore<TextSegment> {

private static final String ADDITIONALS = "_additional";
private static final String METADATA = "_metadata";
private static final String NULL_VALUE = "<null>";

private final WeaviateClient client;
private final String objectClass;
private final boolean avoidDups;
private final String consistencyLevel;
private final String metadataParentKey;
private final String metadataFieldName;
private final Collection<String> metadataKeys;
private final String textKey;
private final String textFieldName;

/**
* Creates a new WeaviateEmbeddingStore instance.
Expand All @@ -64,12 +63,12 @@ public class WeaviateEmbeddingStoreCustom implements EmbeddingStore<TextSegment>
* provided text segment, which avoids duplicated entries in DB.
* If false, then random ID will be generated.
* @param consistencyLevel Consistency level: ONE, QUORUM (default) or ALL. Find more details <a href="https://weaviate.io/developers/weaviate/concepts/replication-architecture/consistency#tunable-write-consistency">here</a>.
* @param metadataParentKey The key in metadata that contains the metadata. Default is "_metadata". If set to empty string, then metadata will be stored in the root of the object.
* @param metadataKeys Metadata keys that should be persisted (optional)
* @param useGrpcForInserts Use GRPC instead of HTTP for batch inserts only. <b>You still need HTTP configured for search</b>
* @param securedGrpc The GRPC connection is secured
* @param grpcPort The port, e.g. 50051. This parameter is optional.
* @param textKey The key in metadata that contains the text. Default is "text".
* @param textFieldName The name of the field that contains the text of a {@link TextSegment}. Default is "text".
* @param metadataFieldName metadataFieldName The name of the field where {@link Metadata} entries are stored. Default is "_metadata". If set to empty string, {@link Metadata} entries will be stored in the root of the Weaviate object.
*/
@Builder
public WeaviateEmbeddingStoreCustom(
Expand All @@ -83,9 +82,9 @@ public WeaviateEmbeddingStoreCustom(
String objectClass,
Boolean avoidDups,
String consistencyLevel,
String metadataParentKey,
Collection<String> metadataKeys,
String textKey
String textFieldName,
String metadataFieldName
) {
try {

Expand All @@ -108,9 +107,9 @@ public WeaviateEmbeddingStoreCustom(
this.objectClass = getOrDefault(objectClass, "Default");
this.avoidDups = getOrDefault(avoidDups, true);
this.consistencyLevel = getOrDefault(consistencyLevel, QUORUM);
this.metadataParentKey = getOrDefault(metadataParentKey, "_metadata");
this.metadataFieldName = getOrDefault(metadataFieldName, "_metadata");
this.metadataKeys = getOrDefault(metadataKeys, Collections.emptyList());
this.textKey = getOrDefault(textKey, "text");
this.textFieldName = getOrDefault(textFieldName, "text");
}

private static String concatenate(String host, Integer port) {
Expand Down Expand Up @@ -187,7 +186,7 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(
double minCertainty
) {
List<Field> fields = new ArrayList<>();
fields.add(Field.builder().name(textKey).build());
fields.add(Field.builder().name(textFieldName).build());
fields.add(Field
.builder()
.name(ADDITIONALS)
Expand All @@ -202,8 +201,8 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(
for (String property : metadataKeys) {
metadataFields.add(Field.builder().name(property).build());
}
if (metadataParentKey != null && !metadataParentKey.isEmpty()) {
fields.add(Field.builder().name(metadataParentKey).fields(metadataFields.toArray(new Field[0])).build());
if (!metadataFieldName.isEmpty()) {
fields.add(Field.builder().name(metadataFieldName).fields(metadataFields.toArray(new Field[0])).build());
} else {
fields.addAll(metadataFields);
}
Expand All @@ -222,7 +221,6 @@ public List<EmbeddingMatch<TextSegment>> findRelevant(
)
.withLimit(maxResults)
.run();

if (result.hasErrors()) {
throw new IllegalArgumentException(
result.getError().getMessages().stream().map(WeaviateErrorMessage::getMessage).collect(joining("\n"))
Expand Down Expand Up @@ -275,7 +273,7 @@ private WeaviateObject buildObject(String id, Embedding embedding, TextSegment s
Map<String, Object> props = new HashMap<>();
Map<String, Object> metadata = prefillMetadata();
if (segment != null) {
props.put(textKey, segment.text());
props.put(textFieldName, segment.text());
if (!segment.metadata().toMap().isEmpty()) {
for (String property : metadataKeys) {
if (segment.metadata().containsKey(property)) {
Expand All @@ -285,7 +283,7 @@ private WeaviateObject buildObject(String id, Embedding embedding, TextSegment s
}
setMetadata(props, metadata);
} else {
props.put(textKey, "");
props.put(textFieldName, "");
setMetadata(props, metadata);
}
props.put("indexFilterable", true);
Expand All @@ -301,8 +299,8 @@ private WeaviateObject buildObject(String id, Embedding embedding, TextSegment s

private void setMetadata(Map<String, Object> props, Map<String, Object> metadata) {
if (metadata != null && !metadata.isEmpty()) {
if(metadataParentKey != null && !metadataParentKey.isEmpty()) {
props.put(metadataParentKey, metadata);
if(metadataFieldName != null && !metadataFieldName.isEmpty()) {
props.put(metadataFieldName, metadata);
} else {
props.putAll(metadata);
}
Expand All @@ -318,33 +316,34 @@ private Map<String, Object> prefillMetadata() {
}

private EmbeddingMatch<TextSegment> toEmbeddingMatch(Map<String, ?> item) {
Map<String, ?> additional = (Map<String, ?>) item.get(ADDITIONALS);
final Metadata metadata = new Metadata();
Map<String, ?> metadataMap = new HashMap();
if (metadataParentKey == null || metadataParentKey.isEmpty()) {
metadataKeys.stream().forEach(key ->
metadata.add(key, item.get(key))
);
}
else if (item.get(metadataParentKey) != null && item.get(metadataParentKey) instanceof Map) {
metadataMap = (Map<String, ?>) item.get(metadataParentKey);
}
if(metadataMap != null) {
for (Map.Entry<String, ?> entry : metadataMap.entrySet()) {
if (entry.getValue() != null && !NULL_VALUE.equals(entry.getValue())) {
metadata.add(entry.getKey(), entry.getValue());
}
}
Map<String, ?> additional = (Map<String, ?>) item.get(ADDITIONALS);
final Metadata metadata = new Metadata();
Map<String, ?> metadataMap = new HashMap<>();
if (metadataFieldName.isEmpty()) {
metadataMap = new HashMap<>(item);
// Remove text field from metadata if we store metadata in the root of the object
metadataMap.remove(textFieldName);
} else if (item.get(metadataFieldName) != null && item.get(metadataFieldName) instanceof Map) {
metadataMap = (Map<String, ?>) item.get(metadataFieldName);
}
if (metadataKeys != null && !metadataKeys.isEmpty()) {
metadataMap.keySet().retainAll(metadataKeys);
}
for (Map.Entry<String, ?> entry : metadataMap.entrySet()) {
if (entry.getValue() != null && !NULL_VALUE.equals(entry.getValue())) {
// TODO: Remove or replace use of deprecated method
metadata.add(entry.getKey(), entry.getValue());
}
String text = (String) item.get(textKey);
}
String text = (String) item.get(textFieldName);

return new EmbeddingMatch<>(
(Double) additional.get("certainty"),
(String) additional.get("id"),
Embedding.from(
((List<Double>) additional.get("vector")).stream().map(Double::floatValue).collect(toList())
),
isNullOrBlank(text) ? null : TextSegment.from(text, metadata)
);
return new EmbeddingMatch<>(
(Double) additional.get("certainty"),
(String) additional.get("id"),
Embedding.from(
((List<Double>) additional.get("vector")).stream().map(Double::floatValue).collect(toList())
),
isNullOrBlank(text) ? null : TextSegment.from(text, metadata)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import java.util.Objects;

import org.apache.commons.lang3.builder.EqualsBuilder;
import org.bson.codecs.pojo.annotations.BsonDiscriminator;

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.redhat.composer.model.enums.ContentRetrieverType;

public class WeaviateRequest extends BaseRetrieverRequest {
Expand Down

0 comments on commit c20fefa

Please sign in to comment.