Fix to normalization. Global normalization was being applied, when ele… #324

Closed
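What changed, in short: the old helper normalize_np_array computed a single L2 norm over the whole embedding matrix (global normalization), but FAISS retrieval needs each embedding row normalized to unit length on its own (element-wise / row-wise). A minimal sketch of the difference, using made-up data rather than anything from the PR:

import numpy as np

xb = np.array([[3.0, 4.0], [6.0, 8.0]])  # two embeddings, d = 2

# Global normalization (old behavior): one norm for the whole matrix,
# so the individual rows do NOT end up with unit length.
global_norm = xb / np.linalg.norm(xb)
print(np.linalg.norm(global_norm, axis=1))  # approx [0.447, 0.894]

# Row-wise normalization (intended behavior): every embedding has norm 1,
# which is what cosine-similarity / inner-product search expects.
row_wise = xb / np.linalg.norm(xb, axis=1, keepdims=True)
print(np.linalg.norm(row_wise, axis=1))  # [1.0, 1.0]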
18 changes: 11 additions & 7 deletions adalflow/adalflow/components/retriever/faiss_retriever.py
@@ -1,5 +1,6 @@
"""Semantic search/embedding-based retriever using FAISS."""

import faiss
from typing import (
List,
Optional,
@@ -24,22 +25,23 @@
RetrieverStrQueryType,
EmbedderOutputType,
)
from adalflow.core.functional import normalize_np_array, is_normalized
from adalflow.core.functional import normalize_embeddings, is_normalized

from adalflow.utils.lazy_import import safe_import, OptionalPackages

safe_import(OptionalPackages.FAISS.value[0], OptionalPackages.FAISS.value[1])
import faiss

log = logging.getLogger(__name__)

FAISSRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray] # single embedding
# single embedding
FAISSRetrieverDocumentEmbeddingType = Union[List[float], np.ndarray]
FAISSRetrieverDocumentsType = Sequence[FAISSRetrieverDocumentEmbeddingType]

FAISSRetrieverEmbeddingQueryType = Union[
List[float], List[List[float]], np.ndarray
] # single embedding or list of embeddings
FAISSRetrieverQueryType = Union[RetrieverStrQueryType, FAISSRetrieverEmbeddingQueryType]
FAISSRetrieverQueryType = Union[RetrieverStrQueryType,
FAISSRetrieverEmbeddingQueryType]
FAISSRetrieverQueriesType = Sequence[FAISSRetrieverQueryType]
FAISSRetrieverQueriesStrType = Sequence[RetrieverStrQueryType]
FAISSRetrieverQueriesEmbeddingType = Sequence[FAISSRetrieverEmbeddingQueryType]
@@ -161,7 +163,8 @@ def build_index_from_documents(
If you are using Document format, pass them as [doc.vector for doc in documents]
"""
if document_map_func:
assert callable(document_map_func), "document_map_func should be callable"
assert callable(
document_map_func), "document_map_func should be callable"
documents = [document_map_func(doc) for doc in documents]
try:
self.documents = documents
@@ -183,7 +186,7 @@ def build_index_from_documents(
log.warning(
"Embeddings are not normalized, normalizing the embeddings"
)
self.xb = normalize_np_array(self.xb)
self.xb = normalize_embeddings(self.xb)

self._preprare_faiss_index_from_np_array(self.xb)
log.info(f"Index built with {self.total_documents} chunks")
@@ -295,7 +298,8 @@ def retrieve_string_queries(
output: RetrieverOutputType = [
RetrieverOutput(doc_indices=[], query=query) for query in queries
]
retrieved_output: RetrieverOutputType = self._to_retriever_output(Ind, D)
retrieved_output: RetrieverOutputType = self._to_retriever_output(
Ind, D)

# fill in the doc_indices and score for valid queries
for i, per_query_output in enumerate(retrieved_output):
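For context on why build_index_from_documents checks is_normalized and normalizes before indexing: when cosine similarity is served from an inner-product FAISS index, inner product equals cosine similarity only if every vector has unit L2 norm. A rough sketch of that pattern, assuming an IndexFlatIP index and random placeholder embeddings (this is not the retriever's exact code):

import faiss
import numpy as np

d = 128
xb = np.random.rand(1000, d).astype(np.float32)  # document embeddings
xq = np.random.rand(5, d).astype(np.float32)     # query embeddings

# Row-wise L2 normalization in place, so inner product == cosine similarity.
faiss.normalize_L2(xb)
faiss.normalize_L2(xq)

index = faiss.IndexFlatIP(d)  # exact inner-product search
index.add(xb)
scores, indices = index.search(xq, 10)  # top-10 document indices per query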
104 changes: 72 additions & 32 deletions adalflow/adalflow/core/functional.py
@@ -55,7 +55,8 @@ def custom_asdict(
tuples, lists, and dicts.
"""
if not is_dataclass_instance(obj):
raise TypeError("custom_asdict() should be called on dataclass instances")
raise TypeError(
"custom_asdict() should be called on dataclass instances")
return _asdict_inner(obj, dict_factory, exclude or {})


@@ -254,15 +255,18 @@ class TrecDataList:
): # Optional[Address] will be false, and true for each check

log.debug(
f"{is_dataclass(cls)} of {cls}, {is_potential_dataclass(cls)} of {cls}"
f"{is_dataclass(cls)} of {cls}, {
is_potential_dataclass(cls)} of {cls}"
)
# Ensure the data is a dictionary
if not isinstance(data, dict):
raise ValueError(
f"Expected data of type dict for {cls}, but got {type(data).__name__}"
f"Expected data of type dict for {
cls}, but got {type(data).__name__}"
)
cls_type = extract_dataclass_type(cls)
fieldtypes = {f.name: f.type for f in cls_type.__dataclass_fields__.values()}
fieldtypes = {
f.name: f.type for f in cls_type.__dataclass_fields__.values()}

restored_data = cls_type(
**{
@@ -277,11 +281,13 @@ class TrecDataList:
for item in data:
if check_data_class_field_args_zero(cls):
# restore the value to its dataclass type
restored_data.append(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.append(
dataclass_obj_from_dict(cls.__args__[0], item))

elif check_if_class_field_args_zero_exists(cls):
# Use the original data [Any]
restored_data.append(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.append(
dataclass_obj_from_dict(cls.__args__[0], item))

else:
restored_data.append(item)
@@ -293,10 +299,12 @@ class TrecDataList:
for item in data:
if check_data_class_field_args_zero(cls):
# restore the value to its dataclass type
restored_data.add(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.add(
dataclass_obj_from_dict(cls.__args__[0], item))
elif check_if_class_field_args_zero_exists(cls):
# Use the original data [Any]
restored_data.add(dataclass_obj_from_dict(cls.__args__[0], item))
restored_data.add(
dataclass_obj_from_dict(cls.__args__[0], item))

else:
# Use the original data [Any]
@@ -319,7 +327,8 @@ class TrecDataList:
return data
# else normal data like int, str, float, etc.
else:
log.debug(f"Not datclass, or list, or dict: {cls}, use the original data.")
log.debug(f"Not datclass, or list, or dict: {
cls}, use the original data.")
return data


@@ -393,7 +402,8 @@ def get_type_schema(
if arg is not type(None)
]
return (
f"Optional[{types[0]}]" if len(types) == 1 else f"Union[{', '.join(types)}]"
f"Optional[{types[0]}]" if len(
types) == 1 else f"Union[{', '.join(types)}]"
)
elif origin in {List, list}:
args = get_args(type_obj)
@@ -414,21 +424,22 @@ def get_type_schema(
elif origin in {Set, set}:
args = get_args(type_obj)
return (
f"Set[{get_type_schema(args[0],exclude, type_var_map)}]" if args else "Set"
f"Set[{get_type_schema(
args[0], exclude, type_var_map)}]" if args else "Set"
)

elif origin is Sequence:
args = get_args(type_obj)
return (
f"Sequence[{get_type_schema(args[0], exclude,type_var_map)}]"
f"Sequence[{get_type_schema(args[0], exclude, type_var_map)}]"
if args
else "Sequence"
)

elif origin in {Tuple, tuple}:
args = get_args(type_obj)
if args:
return f"Tuple[{', '.join(get_type_schema(arg,exclude,type_var_map) for arg in args)}]"
return f"Tuple[{', '.join(get_type_schema(arg, exclude, type_var_map) for arg in args)}]"
return "Tuple"

elif is_dataclass(type_obj):
@@ -496,7 +507,8 @@ def get_dataclass_schema(
# prepare field schema, it will be done recursively for nested dataclasses

field_type = type_var_map.get(f.type, f.type)
field_schema = {"type": get_type_schema(field_type, exclude, type_var_map)}
field_schema = {"type": get_type_schema(
field_type, exclude, type_var_map)}

# check required field
is_required = _is_required_field(f)
@@ -588,7 +600,8 @@ def example_function(x: int, y: str = "default") -> int:
param_type = type_hints.get(param_name, "Any")
if parameter.default == Parameter.empty:
schema["required"].append(param_name)
schema["properties"][param_name] = {"type": get_type_schema(param_type)}
schema["properties"][param_name] = {
"type": get_type_schema(param_type)}
else:
schema["properties"][param_name] = {
"type": get_type_schema(param_type),
@@ -659,7 +672,8 @@ def evaluate_ast_node(node: ast.AST, context_map: Dict[str, Any] = None):
return output_fun
# TODO: raise the error back to the caller so that the llm can get the error message
except KeyError as e:
log.error(f"Error: {e}, {node.id} does not exist in the context_map.")
log.error(f"Error: {e}, {
node.id} does not exist in the context_map.")
raise ValueError(
f"Error: {e}, {node.id} does not exist in the context_map."
)
@@ -669,7 +683,8 @@ def evaluate_ast_node(node: ast.AST, context_map: Dict[str, Any] = None):

elif isinstance(
node, ast.Call
): # another fun or class as argument and value, e.g. add( multiply(4,5), 3)
# another fun or class as argument and value, e.g. add( multiply(4,5), 3)
):
func = evaluate_ast_node(node.func, context_map)
args = [evaluate_ast_node(arg, context_map) for arg in node.args]
kwargs = {
@@ -712,11 +727,13 @@ def parse_function_call_expr(
if isinstance(tree.body, ast.Call):
# Extract the function name
func_name = (
tree.body.func.id if isinstance(tree.body.func, ast.Name) else None
tree.body.func.id if isinstance(
tree.body.func, ast.Name) else None
)

# Prepare the list of arguments and keyword arguments
args = [evaluate_ast_node(arg, context_map) for arg in tree.body.args]
args = [evaluate_ast_node(arg, context_map)
for arg in tree.body.args]
keywords = {
kw.arg: evaluate_ast_node(kw.value, context_map)
for kw in tree.body.keywords
@@ -889,13 +906,32 @@ def is_normalized(v: VECTOR_TYPE, tol=1e-4) -> bool:
return np.abs(norm - 1) < tol


def normalize_np_array(v: np.ndarray) -> np.ndarray:
# Compute the norm of the vector (assuming v is 1D)
norm = np.linalg.norm(v)
# Normalize the vector
normalized_v = v / norm
# Return the normalized vector
return normalized_v
def normalize_embeddings(v: np.ndarray) -> np.ndarray:
"""
Normalize embeddings to have L2 norm = 1.


Handles both:
- 1D arrays: a single embedding (shape = (d,))
- 2D arrays: multiple embeddings (shape = (N, d))
"""
if v.ndim == 1:
# Single embedding vector
norm = np.linalg.norm(v)
if norm == 0:
norm = 1e-12 # Avoid division by zero
return (v / norm).astype(np.float32)
elif v.ndim == 2:
# Multiple embeddings: row-wise normalization
# norms: shape = (N,1)
norms = np.linalg.norm(v, axis=1, keepdims=True)
# Avoid division by zero for rows that might be zero
norms[norms < 1e-12] = 1e-12
return (v / norms).astype(np.float32)
else:
raise ValueError(
f"normalize_np_array expects 1D or 2D input. Got shape {v.shape}"
)


def normalize_vector(v: VECTOR_TYPE) -> List[float]:
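A quick sanity check of the new normalize_embeddings above (a usage sketch, not part of the diff): a single 1D embedding and a 2D batch both come back with unit L2 norms, and an all-zero row stays at zero instead of producing NaNs.

import numpy as np
from adalflow.core.functional import normalize_embeddings

single = np.array([3.0, 4.0])
batch = np.array([[3.0, 4.0], [0.0, 0.0], [6.0, 8.0]])

print(np.linalg.norm(normalize_embeddings(single)))          # ~1.0
print(np.linalg.norm(normalize_embeddings(batch), axis=1))   # ~[1.0, 0.0, 1.0]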
@@ -1086,7 +1122,7 @@ def extract_json_str(text: str, add_missing_right_brace: bool = True) -> str:
"Incomplete JSON object found and add_missing_right_brace is False."
)

return text[start : end + 1]
return text[start: end + 1]


def extract_list_str(text: str, add_missing_right_bracket: bool = True) -> str:
Expand Down Expand Up @@ -1137,7 +1173,7 @@ def extract_list_str(text: str, add_missing_right_bracket: bool = True) -> str:
"Incomplete list found and add_missing_right_bracket is False."
)

return text[start : end + 1]
return text[start: end + 1]


def extract_yaml_str(text: str) -> str:
Expand Down Expand Up @@ -1222,7 +1258,8 @@ def parse_json_str_to_obj(json_str: str) -> Union[Dict[str, Any], List[Any]]:
return json_obj
except json.JSONDecodeError as e:
log.info(
f"Got invalid JSON object with json.loads. Error: {e}. Got JSON string: {json_str}"
f"Got invalid JSON object with json.loads. Error: {
e}. Got JSON string: {json_str}"
)
# 2nd attempt after fixing the json string
try:
@@ -1246,7 +1283,8 @@ def parse_json_str_to_obj(json_str: str) -> Union[Dict[str, Any], List[Any]]:
return json_obj
except yaml.YAMLError as e:
raise ValueError(
f"Got invalid JSON object with yaml.safe_load. Error: {e}. Got JSON string: {json_str}"
f"Got invalid JSON object with yaml.safe_load. Error: {
e}. Got JSON string: {json_str}"
)


@@ -1269,7 +1307,8 @@ def random_sample(

if not replace and num_shots > dataset_size:
log.debug(
f"num_shots {num_shots} is larger than the dataset size {dataset_size}"
f"num_shots {num_shots} is larger than the dataset size {
dataset_size}"
)
num_shots = dataset_size

@@ -1282,6 +1321,7 @@ def random_sample(
# Normalize weights to sum to 1
weights = weights / weights.sum()

indices = np.random.choice(len(dataset), size=num_shots, replace=replace, p=weights)
indices = np.random.choice(
len(dataset), size=num_shots, replace=replace, p=weights)

return [dataset[i] for i in indices]
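On the weighted-sampling path touched in the last hunk: np.random.choice requires the probabilities passed as p to sum to 1, which is why the weights are renormalized before sampling. A small illustrative sketch with made-up values:

import numpy as np

dataset = ["a", "b", "c", "d"]
weights = np.array([1.0, 1.0, 2.0, 4.0])
weights = weights / weights.sum()  # p must sum to 1 for np.random.choice

indices = np.random.choice(len(dataset), size=2, replace=False, p=weights)
print([dataset[i] for i in indices])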