chore: allow users to pass schema in encrypted data-frames #676
@@ -2,53 +2,70 @@
 import copy
 from collections import defaultdict
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import numpy
 import pandas

 from concrete.ml.pandas._development import get_min_max_allowed
 from concrete.ml.quantization.quantizers import STABILITY_CONST

+SCHEMA_FLOAT_KEYS = ["min", "max"]

-def compute_scale_zero_point(column: pandas.Series, q_min: int, q_max: int) -> Tuple[float, float]:
+
+def is_str_or_none(column: pandas.Series) -> bool:
+    """Determine whether the column only contains string and None values.
+
+    Args:
+        column (pandas.Series): The column to consider.
+
+    Returns:
+        bool: Whether the column only contains string and None values.
+    """
+    return column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()
+
+
+def compute_scale_zero_point(
+    f_min: float, f_max: float, q_min: int, q_max: int
+) -> Tuple[float, float]:
     """Compute the scale and zero point to use for quantizing / de-quantizing the given column.

     Note that the scale and zero point are computed so that values are quantized uniformly from
-    range [column.min(), column.max()] (float) to range [q_min, q_max] (int).
+    range [f_min, f_max] (float) to range [q_min, q_max] (int).

     Args:
-        column (pandas.Series): The column to consider.
-        q_min (int): The minimum quantized value to consider.
-        q_max (int): The maximum quantized value to consider.
+        f_min (float): The minimum float value observed.
+        f_max (float): The maximum float value observed.
+        q_min (int): The minimum quantized value to target.
+        q_max (int): The maximum quantized value to target.

     Returns:
         Tuple[float, float]: The scale and zero-point.
     """
-    values_min, values_max = column.min(), column.max()
-
     # If there is a single float value in the column, the scale and zero-point need to be handled
     # differently
-    if values_max - values_min < STABILITY_CONST:
+    if f_max - f_min < STABILITY_CONST:

         # If this single float value is 0, make sure it is not quantized to 0
-        if numpy.abs(values_max) < STABILITY_CONST:
+        if numpy.abs(f_max) < STABILITY_CONST:
             scale = 1.0
             zero_point = -q_min

         # Else, quantize it to 1
         else:
-            scale = 1 / values_max
+            scale = 1 / f_max
             zero_point = 0

     else:
-        scale = (q_max - q_min) / (values_max - values_min)
+        scale = (q_max - q_min) / (f_max - f_min)

         # Zero-point must be rounded once NaN values are not represented by 0 anymore
         # The issue is that we currently need to avoid quantized values to reach 0, but having a
         # round here + in the 'quant' method can make this happen.
         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
-        zero_point = values_min * scale - q_min
+        # Disable mypy until it is fixed
+        zero_point = f_min * scale - q_min  # type: ignore[assignment]

     return scale, zero_point
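
As a quick sanity check of the new signature, here is a worked example (a sketch: the exact rounding in 'quant' is not shown in this diff, so the affine mapping q = round(f * scale - zero_point) is an assumption, derived from how zero_point is defined above):

# Hypothetical values: a float range [0.0, 10.0] quantized to integers [1, 15]
f_min, f_max = 0.0, 10.0
q_min, q_max = 1, 15

scale = (q_max - q_min) / (f_max - f_min)  # (15 - 1) / (10.0 - 0.0) = 1.4
zero_point = f_min * scale - q_min         # 0.0 * 1.4 - 1 = -1.0

# The range endpoints land exactly on the quantized bounds:
assert round(f_min * scale - zero_point) == q_min  # 0.0 -> 1
assert round(f_max * scale - zero_point) == q_max  # 10.0 -> 15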
@@ -86,9 +103,49 @@ def dequant(
     return x.astype(dtype)


-# Provide a way for users to pass string mappings
-# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
-def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
+def check_schema_format(pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None) -> None:
+    """Check that the given schema has the expected format.
+
+    Args:
+        pandas_dataframe (pandas.DataFrame): The data-frame associated with the given schema.
+        schema (Optional[Dict]): The schema to check, which can be None. Defaults to None.
+
+    Raises:
+        ValueError: If the given schema is not a dict.
+        ValueError: If the given schema contains column names that do not appear in the data-frame.
+        ValueError: If one of the columns' mappings is not a dict.
+    """
+    if schema is None:
+        return
+
+    if not isinstance(schema, dict):
+        raise ValueError(
+            "When set, parameter 'schema' must be a dictionary that associates some of the "
+            f"data-frame's column names to their value mappings. Got {type(schema)=}"
+        )
+
+    column_names = list(pandas_dataframe.columns)
+
+    for column_name, column_mapping in schema.items():
+        if column_name not in column_names:
+            # TODO: Is this check actually relevant? Can't the schema provide more columns than
+            # the ones found in the data-frame?
+            raise ValueError(
+                f"Column name '{column_name}' found in the given schema cannot be found in the "
+                f"input data-frame. Expected one of {column_names}"
+            )
+
+        if not isinstance(column_mapping, dict):
+            raise ValueError(
+                f"Mapping for column '{column_name}' is not a dictionary. Got "
+                f"{type(column_mapping)=}"
+            )
+
+
+# pylint: disable=too-many-branches, too-many-statements
+def pre_process_dtypes(
+    pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
+) -> Tuple[pandas.DataFrame, Dict]:
     """Pre-process the Pandas data-frame and check that input dtypes and ranges are supported.

     Currently, three input dtypes are supported : integers (within a specific range), floating
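
To make the format being validated here concrete, a minimal sketch (the column names and values are hypothetical; the per-column shapes follow the checks in pre_process_dtypes below, where float columns expect {"min", "max"} keys and string columns expect a value-to-integer mapping):

import pandas

df = pandas.DataFrame({
    "age": [25.0, 47.5, 33.0],          # float column
    "job": ["teacher", "nurse", None],  # object column (strings / None)
})

schema = {
    "age": {"min": 0.0, "max": 120.0},  # float column: explicit range
    "job": {"teacher": 1, "nurse": 2},  # string column: value -> int mapping
}

check_schema_format(df, schema)  # passes
check_schema_format(df, None)    # passes, the schema is optional
# check_schema_format(df, {"salary": {}})  # raises: unknown column name
# check_schema_format(df, [("age", 0)])    # raises: schema must be a dict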
@@ -97,6 +154,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
     Args:
         pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
+        schema (Optional[Dict]): The input schema to consider. Defaults to None.

     Raises:
         ValueError: If the values of a column with an integer dtype are out of bounds.
@@ -118,6 +176,7 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
     # Avoid sending column names to server, instead use hashes
     # FIXME : https://github.com/zama-ai/concrete-ml-internal/issues/4342
+    # pylint: disable=too-many-nested-blocks
     for column_name in pandas_dataframe.columns:
         column = pandas_dataframe[column_name]
         column_dtype = column.dtype
@@ -127,9 +186,13 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
         # If the column contains integers, make sure they are not out of bounds
         if numpy.issubdtype(column_dtype, numpy.integer):
-            out_of_bounds = (column < q_min).any() or (column > q_max).any()
+            if schema is not None and column_name in schema:
+                raise ValueError(
+                    f"Column '{column_name}' contains integer values and therefore does not "
+                    "require any mappings. Please remove it"
+                )

-            if out_of_bounds:
+            if column.min() < q_min or column.max() > q_max:
                 raise ValueError(
                     f"Column '{column_name}' (dtype={column_dtype}) contains values that are out "
                     f"of bounds. Expected values to be in interval [min={q_min}, max={q_max}], but "
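
For clarity, a short sketch of the integer-column rule above (hypothetical data-frame; q_min / q_max come from get_min_max_allowed() in practice):

import pandas

df = pandas.DataFrame({"count": [1, 2, 3]})  # integer dtype

# Integer columns are used as-is (within [q_min, q_max]) and need no
# quantization parameters, so giving them a schema entry raises:
# pre_process_dtypes(df, schema={"count": {"min": 0, "max": 10}})  # ValueError
pre_process_dtypes(df)  # fine: integer columns are simply range-checked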
@@ -138,7 +201,33 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
         # If the column contains floats, quantize the values
         elif numpy.issubdtype(column_dtype, numpy.floating):
-            scale, zero_point = compute_scale_zero_point(column, q_min, q_max)
+            if schema is not None and column_name in schema:
+                float_min_max = schema[column_name]
+
+                if not all(
+                    float_mapping_key in SCHEMA_FLOAT_KEYS
+                    for float_mapping_key in float_min_max.keys()
+                ):
+                    raise ValueError(
+                        f"Column '{column_name}' contains float values but the associated "
+                        f"mapping does not contain proper keys. Expected "
+                        f"{sorted(SCHEMA_FLOAT_KEYS)}, but got {sorted(float_min_max.keys())}"
+                    )
+
+                f_min, f_max = float_min_max["min"], float_min_max["max"]
+
+                if column.min() < f_min or column.max() > f_max:
+                    raise ValueError(
+                        f"Column '{column_name}' (dtype={column_dtype}) contains values that are "
+                        f"out of bounds. Expected values to be in interval [min={f_min}, "
+                        f"max={f_max}], as determined by the given schema, but found "
+                        f"[min={column.min()}, max={column.max()}]."
+                    )
+
+            else:
+                f_min, f_max = column.min(), column.max()
+
+            scale, zero_point = compute_scale_zero_point(f_min, f_max, q_min, q_max)

             q_column = quant(column, scale, zero_point)
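
A note on the design: taking f_min / f_max from the schema, rather than from the observed column, presumably lets users pin the quantization range so that it stays identical across data-frames. A minimal sketch of what changes (values hypothetical):

# Without a schema entry, the range is data-dependent:
#     f_min, f_max = column.min(), column.max()
# With schema = {"price": {"min": 0.0, "max": 100.0}}, the declared range
# drives quantization instead, regardless of the values actually observed:
float_min_max = {"min": 0.0, "max": 100.0}
f_min, f_max = float_min_max["min"], float_min_max["max"]
scale, zero_point = compute_scale_zero_point(f_min, f_max, q_min=1, q_max=15)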
@@ -150,14 +239,58 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
         # If the column contains objects, make sure it is only made of strings or NaN values
         elif column_dtype == "object":
-            is_str = column.apply(lambda x: isinstance(x, str) or not pandas.notna(x)).all()
-
-            if is_str:
-
-                # Build a mapping between the unique string values and integers
-                str_to_int = {
-                    str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
-                }
+            if is_str_or_none(column):
+                if schema is not None and column_name in schema:
+                    str_to_int = schema[column_name]
+
+                    column_values = set(column.values) - set([None, numpy.NaN])
+                    string_mapping_keys = set(str_to_int.keys())
+
+                    # Allow custom mapping for NaN values once they are not represented by 0
+                    # anymore
+                    # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4342
+                    if numpy.NaN in string_mapping_keys:
+                        raise ValueError(
+                            f"String mapping for column '{column_name}' contains numpy.NaN as a "
+                            "key, which is currently forbidden"
+                        )
+
+                    forgotten_string_values = column_values - string_mapping_keys
+
+                    if forgotten_string_values:
+                        raise ValueError(
+                            f"String mapping keys for column '{column_name}' do not cover all "
+                            "values from the data-frame. Missing values: "
+                            f"{sorted(forgotten_string_values)}"
+                        )
+
+                    for string_mapping_key, string_mapping_value in str_to_int.items():
+                        if not isinstance(string_mapping_value, int):
+                            raise ValueError(
+                                f"String mapping values for column '{column_name}' must be "
+                                f"integers. Got {type(string_mapping_value)} for key "
+                                f"{string_mapping_key}"
+                            )
+
+                        if string_mapping_value < q_min or string_mapping_value > q_max:
+                            raise ValueError(
+                                f"String mapping values for column '{column_name}' are out of "
+                                f"bounds. Expected values to be in interval [min={q_min}, "
+                                f"max={q_max}] but got {string_mapping_value} for key "
+                                f"{string_mapping_key}"
+                            )
+
+                    if len(str_to_int.values()) != len(set(str_to_int.values())):
+                        raise ValueError(
+                            f"String mapping values for column '{column_name}' must be unique. "
+                            f"Got {str_to_int.values()}"
+                        )
+
+                else:
+
+                    # Build a mapping between the unique string values and integers
+                    str_to_int = {
+                        str_value: i + 1 for i, str_value in enumerate(column.dropna().unique())
+                    }

                 # Make sure the number of unique values does not go over the maximum integer
                 # value allowed in an encrypted data-frame
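
A sketch of a string mapping that would pass all of the checks above, next to variants that would not (column values and bounds are hypothetical):

import numpy
import pandas

column = pandas.Series(["cat", "dog", None, "cat"])
q_min, q_max = 1, 15  # placeholder bounds from get_min_max_allowed()

str_to_int = {"cat": 3, "dog": 7}        # OK: covers all non-NaN values,
                                         # integer, unique, within bounds
# str_to_int = {"cat": 3}                # raises: 'dog' is missing
# str_to_int = {"cat": 3, "dog": 3}      # raises: values must be unique
# str_to_int = {"cat": 0, "dog": 7}      # raises: 0 is below q_min
# str_to_int = {numpy.NaN: 1, "dog": 7}  # raises: NaN keys are forbidden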
@@ -189,14 +322,20 @@ def pre_process_dtypes(pandas_dataframe: pandas.DataFrame) -> Tuple[pandas.DataFrame, Dict]:
                     "supported."
                 )

+    # TODO: Should all non-integer columns be considered by the schema if not None? Currently,
+    # mappings are computed automatically if schema is not set
+
     return pandas_dataframe, dtype_mappings


-def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.ndarray, Dict]:
+def pre_process_from_pandas(
+    pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
+) -> Tuple[numpy.ndarray, Dict]:
     """Pre-process the Pandas data-frame.

     Args:
         pandas_dataframe (pandas.DataFrame): The Pandas data-frame to pre-process.
+        schema (Optional[Dict]): The input schema to consider. Defaults to None.

     Raises:
         ValueError: If the data-frame's index has not been reset (meaning the index is not a

Review thread on the TODO above:
- should we raise an error/warning if all non-integer columns from the data-frame were not covered by the given schema?
- What would happen if they are missing from the given schema?
- they are automatically computed
- I'm just wondering because it could be an easy mistake to forget to put some columns, but no error will be raised
@@ -217,7 +356,7 @@ def pre_process_from_pandas(pandas_dataframe: pandas.DataFrame) -> Tuple[numpy.ndarray, Dict]:
         )

     # Check that values are supported and build the mappings
-    q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe)
+    q_pandas_dataframe, dtype_mappings = pre_process_dtypes(pandas_dataframe, schema=schema)

     # Replace NaN values with 0
     # Remove this once NaN values are not represented by 0 anymore
@@ -1,13 +1,17 @@
 """Define the framework used for managing keys (encrypt, decrypt) for encrypted data-frames."""

 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 import pandas

 from concrete import fhe
 from concrete.ml.pandas._development import CLIENT_PATH, get_encrypt_config
-from concrete.ml.pandas._processing import post_process_to_pandas, pre_process_from_pandas
+from concrete.ml.pandas._processing import (
+    check_schema_format,
+    post_process_to_pandas,
+    pre_process_from_pandas,
+)
 from concrete.ml.pandas._utils import decrypt_elementwise, encrypt_elementwise, encrypt_value
 from concrete.ml.pandas.dataframe import EncryptedDataFrame
@@ -37,16 +41,22 @@ def keygen(self, keys_path: Optional[Union[Path, str]] = None):
         else:
             self.client.keygen(True)

-    def encrypt_from_pandas(self, pandas_dataframe: pandas.DataFrame) -> EncryptedDataFrame:
+    def encrypt_from_pandas(
+        self, pandas_dataframe: pandas.DataFrame, schema: Optional[Dict] = None
+    ) -> EncryptedDataFrame:
         """Encrypt a Pandas data-frame using the loaded client.

         Args:
             pandas_dataframe (DataFrame): The Pandas data-frame to encrypt.
+            schema (Optional[Dict]): The input schema to consider. Defaults to None.

         Returns:
             EncryptedDataFrame: The encrypted data-frame.
         """
-        pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe)
+        check_schema_format(pandas_dataframe, schema)
+
+        pandas_array, dtype_mappings = pre_process_from_pandas(pandas_dataframe, schema=schema)

         # Inputs need to be encrypted element-wise in order to be able to use a composable circuit
         # Once multi-operator is supported, better handle encryption configuration parameters

Review comment on the new 'schema' parameter:
- a schema is optional. If set, it should follow a specific format. If needed, we could also handle the output of …
Review thread:
- should we allow schema with column names that do not match the ones found in the given data-frame?
- IMO raising the error as you are doing here is the correct behavior.
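
Putting the change together, an end-to-end usage sketch (the ClientEngine construction is hypothetical and only the 'schema' keyword is what this PR adds; integer columns must be left out of the schema):

import pandas
from concrete.ml.pandas import ClientEngine  # assumed import path

df = pandas.DataFrame({
    "id": [1, 2, 3],                  # integer column: must NOT appear in the schema
    "score": [0.1, 0.5, 0.9],         # float column
    "grade": ["low", "high", "low"],  # string column
})

schema = {
    "score": {"min": 0.0, "max": 1.0},
    "grade": {"low": 1, "high": 2},
}

client = ClientEngine(keys_path="keys")  # hypothetical setup
encrypted_df = client.encrypt_from_pandas(df, schema=schema)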