chore: allow users to pass schema in encrypted data-frames #676

Merged
195 changes: 167 additions & 28 deletions src/concrete/ml/pandas/_processing.py
@@ -2,53 +2,70 @@

import copy
from collections import defaultdict
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import numpy
import pandas

from concrete.ml.pandas._development import get_min_max_allowed
from concrete.ml.quantization.quantizers import STABILITY_CONST

SCHEMA_FLOAT_KEYS = ["min", "max"]

def compute_scale_zero_point(column: pandas.Series, q_min: int, q_max: int) -> Tuple[float, float]:

def is_str_or_none(column: pandas.Series) -> bool:
"""Determine if the data-frames only contains string and None values or not.

Args:
column (pandas.Series): The data-frame to consider.

Returns:
bool: If the data-frames only contains string and None values or not.
"""
return column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()
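
A quick sketch of this helper's behavior (hypothetical data; a usage illustration, not part of the diff):

```python
import numpy
import pandas

# Mixed strings and missing values are accepted
assert is_str_or_none(pandas.Series(["a", None, numpy.nan, "b"]))

# A stray non-string value (here, an int) is rejected
assert not is_str_or_none(pandas.Series(["a", 3, None]))
```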


def compute_scale_zero_point(
f_min: float, f_max: float, q_min: int, q_max: int
) -> Tuple[float, float]:
"""Compute the scale and zero point to use for quantizing / de-quantizing the given column.

Note that the scale and zero point are computed so that values are quantized uniformly from
range [column.min(), column.max()] (float) to range [q_min, q_max] (int).
range [f_min, f_max] (float) to range [q_min, q_max] (int).

Args:
column (pandas.Series): The column to consider.
q_min (int): The minimum quantized value to consider.
q_max (int): The maximum quantized value to consider.
f_min (float): The minimum float value observed.
f_max (float): The maximum float value observed.
q_min (int): The minimum quantized value to target.
q_max (int): The maximum quantized value to target.

Returns:
Tuple[float, float]: The scale and zero-point.
"""
values_min, values_max = column.min(), column.max()

# If there is a single float value in the column, the scale and zero-point need to be handled
# differently
if values_max - values_min < STABILITY_CONST:
if f_max - f_min < STABILITY_CONST:

# If this single float value is 0, make sure it is not quantized to 0
if numpy.abs(values_max) < STABILITY_CONST:
if numpy.abs(f_max) < STABILITY_CONST:
scale = 1.0
zero_point = -q_min

# Else, quantize it to 1
else:
scale = 1 / values_max
scale = 1 / f_max
zero_point = 0

else:
scale = (q_max - q_min) / (values_max - values_min)
scale = (q_max - q_min) / (f_max - f_min)

# Zero-point must be rounded once NaN values are not represented by 0 anymore
# The issue is that we currently need to prevent quantized values from reaching 0, but having a
# round here + in the 'quant' method can make this happen.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
zero_point = values_min * scale - q_min
# Disable mypy until it is fixed
zero_point = f_min * scale - q_min # type: ignore[assignment]

return scale, zero_point
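
A minimal worked example of the formulas above (hypothetical bounds; the real q_min / q_max come from get_min_max_allowed(), and the quantizer form q = round(x * scale - zero_point) is assumed here for illustration):

```python
f_min, f_max, q_min, q_max = 0.0, 10.0, 1, 15

scale = (q_max - q_min) / (f_max - f_min)  # (15 - 1) / (10.0 - 0.0) = 1.4
zero_point = f_min * scale - q_min         # 0.0 * 1.4 - 1 = -1.0

# The float range endpoints land exactly on the integer bounds
assert round(f_min * scale - zero_point) == q_min  # 0.0  -> 1
assert round(f_max * scale - zero_point) == q_max  # 10.0 -> 15
```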

@@ -86,9 +103,49 @@ def dequant(
return x.astype(dtype)


# Provide a way for users to pass string mappings
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
def check_schema_format(pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None) -> None:
"""Check that the given schema has a proper expected format.

Args:
pandas_dataframe (pandas.DataFrame): The data-frame associated with the given schema.
schema (Optional[Dict]): The schema to check, which can be None. Defaults to None.

Raises:
ValueError: If the given schema is not a dict.
ValueError: If the given schema contains column names that do not appear in the data-frame.
ValueError: If one of the columns' mapping is not a dict.
"""
if schema is None:
return

if not isinstance(schema, dict):
raise ValueError(
"When set, parameter 'schema' must be a dictionary that associates some of the "
f"data-frame's column names to their value mappings. Got {type(schema)=}"
)

column_names = list(pandas_dataframe.columns)

for column_name, column_mapping in schema.items():
if column_name not in column_names:
# TODO: Is this check actually relevant? Can't the schema provide more columns than the
# ones found in the data-frame?
raise ValueError(
Review (Collaborator Author): Should we allow a schema with column names that do not match the ones found in the given data-frame?

Reply (Collaborator): IMO, raising the error as you are doing here is the correct behavior.

f"Column name '{column_name}' found in the given schema cannot be found in the "
f"input data-frame. Expected one of {column_names}"
)

if not isinstance(column_mapping, dict):
raise ValueError(
f"Mapping for column '{column_name}' is not a dictionary. Got "
f"{type(column_mapping)=}"
)
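
To illustrate the checks above, a sketch with a toy data-frame (hypothetical column names and values):

```python
import pandas

df = pandas.DataFrame({"age": [35, 41], "job": ["doctor", "nurse"]})

# Well-formed: schema keys are data-frame columns, each mapped to a dict
check_schema_format(df, {"job": {"doctor": 1, "nurse": 2}})

# Unknown column name -> ValueError
try:
    check_schema_format(df, {"salary": {"min": 0.0, "max": 1.0}})
except ValueError as error:
    print(error)

# Column mapping that is not a dict -> ValueError
try:
    check_schema_format(df, {"job": ["doctor", "nurse"]})
except ValueError as error:
    print(error)
```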


# pylint: disable=too-many-branches, too-many-statements
def pre_process_dtypes(
pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
) -> Tuple[pandas.DataFrame, Dict]:
"""Pre-process the Pandas data-frame and check that input dtypes and ranges are supported.

Currently, three input dtypes are supported: integers (within a specific range), floating
@@ -97,6 +154,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF

Args:
pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
schema (Optional[Dict]): The input schema to consider. Defaults to None.

Raises:
ValueError: If the values of a column with an integer dtype are out of bounds.
@@ -118,6 +176,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF

# Avoid sending column names to server, instead use hashes
# FIXME : https://github.com/zama-ai/concrete-ml-internal/issues/4342
# pylint: disable=too-many-nested-blocks
for column_name in pandas_dataframe.columns:
column = pandas_dataframe[column_name]
column_dtype = column.dtype
@@ -127,9 +186,13 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF

# If the column contains integers, make sure they are not out of bounds
if numpy.issubdtype(column_dtype, numpy.integer):
out_of_bounds = (column < q_min).any() or (column > q_max).any()
if schema is not None and column_name in schema:
raise ValueError(
f"Column '{column_name}' contains integer values and therefore does not "
"require any mappings. Please remove it"
)

if out_of_bounds:
if column.min() < q_min or column.max() > q_max:
raise ValueError(
f"Column '{column_name}' (dtype={column_dtype}) contains values that are out "
f"of bounds. Expected values to be in interval [min={q_min}, max={q_max}], but "
@@ -138,7 +201,33 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF

# If the column contains floats, quantize the values
elif numpy.issubdtype(column_dtype, numpy.floating):
scale, zero_point = compute_scale_zero_point(column, q_min, q_max)
if schema is not None and column_name in schema:
float_min_max = schema[column_name]

if not all(
float_mapping_key in SCHEMA_FLOAT_KEYS
for float_mapping_key in float_min_max.keys()
):
raise ValueError(
f"Column '{column_name}' contains float values but the associated mapping "
f"does not contain proper keys. Expected {sorted(SCHEMA_FLOAT_KEYS)}, but "
f"got {sorted(float_min_max.keys())}"
)

f_min, f_max = float_min_max["min"], float_min_max["max"]

if column.min() < f_min or column.max() > f_max:
raise ValueError(
f"Column '{column_name}' (dtype={column_dtype}) contains values that are "
f"out of bounds. Expected values to be in interval [min={f_min}, "
f"max={f_max}], as determined by the given schema, but found "
f"[min={column.min()}, max={column.max()}]."
)

else:
f_min, f_max = column.min(), column.max()

scale, zero_point = compute_scale_zero_point(f_min, f_max, q_min, q_max)

q_column = quant(column, scale, zero_point)

@@ -150,14 +239,58 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF

# If the column contains objects, make sure it is only made of strings or NaN values
elif column_dtype == "object":
is_str = column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()

if is_str:

# Build a mapping between the unique strings values and integers
str_to_int = {
str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
}
if is_str_or_none(column):
if schema is not None and column_name in schema:
str_to_int = schema[column_name]

column_values = set(column.values) - set([None, numpy.NaN])
string_mapping_keys = set(str_to_int.keys())

# Allow custom mapping for NaN values once they are not represented by 0 anymore
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
if numpy.NaN in string_mapping_keys:
raise ValueError(
f"String mapping for column '{column_name}' contains numpy.NaN as a "
"key, which is currently forbidden"
)

forgotten_string_values = column_values - string_mapping_keys

if forgotten_string_values:
raise ValueError(
f"String mapping keys for column '{column_name}' are not considering "
"all values from the data-frame. Missing values: "
f"{sorted(forgotten_string_values)}"
)

for string_mapping_key, string_mapping_value in str_to_int.items():
if not isinstance(string_mapping_value, int):
raise ValueError(
f"String mapping values for column '{column_name}' must be "
f"integers. Got {type(string_mapping_value)} for key "
f"{string_mapping_key}"
)

if string_mapping_value < q_min or string_mapping_value > q_max:
raise ValueError(
f"String mapping values for column '{column_name}' are out of "
f"bounds. Expected values to be in interval [min={q_min}, "
f"max={q_max}] but got {string_mapping_value} for key "
f"{string_mapping_key}"
)

if len(str_to_int.values()) != len(set(str_to_int.values())):
raise ValueError(
f"String mapping values for column '{column_name}' must be unique. Got "
f"{str_to_int.values()}"
)

else:

# Build a mapping between the unique strings values and integers
str_to_int = {
str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
}

# Make sure the number of unique values does not go over the maximum integer value
# allowed in an encrypted data-frame
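
The string-mapping rules enforced above can be summarized with a small sketch (hypothetical column and bounds; in the real code q_min / q_max come from get_min_max_allowed()):

```python
import pandas

column = pandas.Series(["low", "high", None, "medium"])
q_min, q_max = 1, 15  # hypothetical allowed quantized range

str_to_int = {"low": 1, "medium": 2, "high": 3}

# Every non-NaN string in the column must have a key...
assert set(column.dropna()) <= set(str_to_int)
# ...every value must be an integer within [q_min, q_max]...
assert all(isinstance(v, int) and q_min <= v <= q_max for v in str_to_int.values())
# ...and values must be unique
assert len(set(str_to_int.values())) == len(str_to_int)
```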
@@ -189,14 +322,20 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataF
"supported."
)

# TODO: Should all non-integer columns be considered by the schema if it is not None? Currently,
# mappings are computed automatically if schema is not set

Review (Collaborator Author): Should we raise an error/warning if not all non-integer columns from the data-frame are covered by the given schema?

Reply (Collaborator): What would happen if they are missing from the given schema?

Reply (Collaborator Author): They are automatically computed.

Reply (Collaborator Author): I'm just wondering because it could be an easy mistake to forget to put some columns, but no error will be raised.

return pandas_dataframe, dtype_mappings


def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.ndarray, Dict]:
def pre_process_from_pandas(
pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
) -> Tuple[numpy.ndarray, Dict]:
"""Pre-process the Pandas data-frame.

Args:
pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
schema (Optional[Dict]): The input schema to consider. Defaults to None.

Raises:
ValueError: If the data-frame's index has not been reset (meaning the index is not a
Expand All @@ -217,7 +356,7 @@ def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.n
)

# Check that values are supported and build the mappings
q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe)
q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe, schema=schema)

# Replace NaN values with 0
# Remove this once NaN values are not represented by 0 anymore
18 changes: 14 additions & 4 deletions src/concrete/ml/pandas/client_engine.py
@@ -1,13 +1,17 @@
"""Define the framework used for managing keys (encrypt, decrypt) for encrypted data-frames."""

from pathlib import Path
from typing import Optional, Union
from typing import Dict, Optional, Union

import pandas

from concrete import fhe
from concrete.ml.pandas._development import CLIENT_PATH, get_encrypt_config
from concrete.ml.pandas._processing import post_process_to_pandas, pre_process_from_pandas
from concrete.ml.pandas._processing import (
check_schema_format,
post_process_to_pandas,
pre_process_from_pandas,
)
from concrete.ml.pandas._utils import decrypt_elementwise, encrypt_elementwise, encrypt_value
from concrete.ml.pandas.dataframe import EncryptedDataFrame

@@ -37,16 +41,22 @@ def keygen(self, keys_path: Optional[Union[Path, str]] = None):
else:
self.client.keygen(True)

def encrypt_from_pandas(self, pandas_dataframe: pandas.DataFrame) -> EncryptedDataFrame:
def encrypt_from_pandas(
self, pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
Review (Collaborator Author): A schema is optional. If set, it should follow a specific format. If needed, we could also handle the output of get_schema (pandas data-frames) as an input here.
) -> EncryptedDataFrame:
"""Encrypt a Pandas data-frame using the loaded client.

Args:
pandas_dataframe (DataFrame): The Pandas data-frame to encrypt.
schema (Optional[Dict]): The input schema to consider. Defaults to None.

Returns:
EncryptedDataFrame: The encrypted data-frame.
"""
pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe)

check_schema_format(pandas_dataframe, schema)

pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe, schema=schema)

# Inputs need to be encrypted element-wise in order to be able to use a composable circuit
# Once multi-operator is supported, better handle encryption configuration parameters
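
Putting it together, a user-side sketch of passing a schema at encryption time (the data-frame, schema values, and keys_path are illustrative):

```python
import pandas

from concrete.ml.pandas import ClientEngine

df = pandas.DataFrame({"age": [35, 41], "bmi": [21.5, 28.1], "job": ["doctor", "nurse"]})

schema = {
    "bmi": {"min": 10.0, "max": 50.0},  # float column: quantization range
    "job": {"doctor": 1, "nurse": 2},   # string column: explicit integer mapping
    # "age" is an integer column and must not appear in the schema
}

client = ClientEngine(keys_path="keys")  # generates or loads keys
encrypted_df = client.encrypt_from_pandas(df, schema=schema)
```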