R2Score Metric is incompatible with Evaluator Component #6817

Open
TomsCodingCode opened this issue May 23, 2024 · 3 comments
@TomsCodingCode


System information

  • Have I specified the code to reproduce the issue (Yes, No): Yes
  • Environment in which the code is executed (e.g., Local(Linux/MacOS/Windows),
    Interactive Notebook, Google Cloud, etc): Windows + WSL 2 (Ubuntu)
  • TensorFlow version: 2.15.1
  • TFX Version: 1.15.0
  • Python version: 3.10.12
  • Python dependencies (from pip freeze output):
absl-py==1.4.0
annotated-types==0.6.0
anyio==4.3.0
apache-beam==2.55.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
astunparse==1.6.3
async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
backcall==0.2.0
beautifulsoup4==4.12.3
bleach==6.1.0
cachetools==5.3.3
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==2.2.1
colorama==0.4.6
comm==0.2.2
coverage==7.5.0
cramjam==2.8.3
crcmod==1.7
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.1.1
distlib==0.3.8
dnspython==2.6.1
docker==4.4.4
docopt==0.6.2
docstring_parser==0.16
exceptiongroup==1.2.1
fastavro==1.9.4
fasteners==0.19
fastjsonschema==2.19.1
filelock==3.13.3
flatbuffers==24.3.25
fqdn==1.5.1
gast==0.5.4
google-api-core==2.18.0
google-api-python-client==1.12.11
google-apitools==0.5.31
google-auth==2.29.0
google-auth-httplib2==0.1.1
google-auth-oauthlib==1.2.0
google-cloud-aiplatform==1.49.0
google-cloud-bigquery==3.21.0
google-cloud-bigquery-storage==2.24.0
google-cloud-bigtable==2.23.1
google-cloud-core==2.4.1
google-cloud-datastore==2.19.0
google-cloud-dlp==3.16.0
google-cloud-language==2.13.3
google-cloud-pubsub==2.21.1
google-cloud-pubsublite==1.10.0
google-cloud-recommendations-ai==0.10.10
google-cloud-resource-manager==1.12.3
google-cloud-spanner==3.45.0
google-cloud-storage==2.16.0
google-cloud-videointelligence==2.13.3
google-cloud-vision==3.7.2
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.7.0
googleapis-common-protos==1.63.0
grpc-google-iam-v1==0.13.0
grpc-interceptor==0.15.4
grpcio==1.62.2
grpcio-status==1.48.2
h11==0.14.0
h5py==3.11.0
hdfs==2.7.3
httpcore==1.0.5
httplib2==0.22.0
httpx==0.27.0
idna==3.7
iniconfig==2.0.0
ipykernel==6.29.4
ipython==7.34.0
ipython-genutils==0.2.0
ipywidgets==7.8.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
joblib==1.4.0
Js2Py==0.74
json5==0.9.25
jsonpickle==3.0.4
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.8
jupyterlab-widgets==1.1.7
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
keras==2.15.0
keras-tuner==1.4.7
kt-legacy==1.0.5
kubernetes==12.0.1
libclang==18.1.1
lxml==5.2.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
ml-dtypes==0.3.2
ml-metadata==1.15.0
ml-pipelines-sdk==1.15.0
namex==0.0.7
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
nltk==3.8.1
notebook==7.1.3
notebook_shim==0.2.4
numpy==1.26.4
oauth2client==4.1.3
oauthlib==3.2.2
objsize==0.7.0
opt-einsum==3.3.0
optree==0.10.0
orjson==3.10.1
overrides==7.7.0
packaging==24.0
pandas==1.5.3
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.3.0
platformdirs==4.2.1
plotly==5.22.0
pluggy==1.4.0
portalocker==2.8.2
portpicker==1.6.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
proto-plus==1.23.0
protobuf==3.20.3
psutil==5.9.8
ptyprocess==0.7.0
pyarrow==10.0.1
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==2.7.1
pydantic_core==2.18.2
pydot==1.4.2
pyfarmhash==0.3.2
Pygments==2.17.2
pyjsparser==2.7.1
pymongo==4.7.0
pyparsing==3.1.2
pytest==8.1.1
pytest-timeout==2.3.1
pytest_httpserver==1.0.10
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
python-snappy==0.7.1
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.0.2
referencing==0.35.0
regex==2024.4.28
requests==2.31.0
requests-oauthlib==2.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rouge_score==0.1.2
rpds-py==0.18.0
rsa==4.9
sacrebleu==2.4.2
scikit-learn==1.4.2
scipy==1.12.0
Send2Trash==1.8.3
shapely==2.0.4
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
sqlparse==0.5.0
tabulate==0.9.0
tenacity==8.3.0
tensorboard==2.15.2
tensorboard-data-server==0.7.2
tensorflow==2.15.1
tensorflow-data-validation==1.15.1
tensorflow-estimator==2.15.0
tensorflow-hub==0.15.0
tensorflow-io-gcs-filesystem==0.36.0
tensorflow-metadata==1.15.0
tensorflow-serving-api==2.15.1
tensorflow-transform==1.15.0
tensorflow_model_analysis==0.46.0
termcolor==2.4.0
terminado==0.18.1
tfx==1.15.0
tfx-bsl==1.15.1
threadpoolctl==3.5.0
tinycss2==1.3.0
tomli==2.0.1
tornado==6.4
tqdm==4.66.2
traitlets==5.14.3
types-python-dateutil==2.9.0.20240316
typing_extensions==4.11.0
tzdata==2024.1
tzlocal==5.2
uri-template==1.3.0
uritemplate==3.0.1
urllib3==2.2.1
virtualenv==20.25.1
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.2
widgetsnbextension==3.6.6
wrapt==1.14.1
zstandard==0.22.0

Describe the current behavior
When I specify keras.metrics.R2Score as a metric when compiling the model, the downstream Evaluator component throws an exception.
This does not happen with any other metric I have tried.

Describe the expected behavior
The Evaluator evaluates the model with all standard metrics.

Standalone code to reproduce the issue

import os

import pandas as pd
import tensorflow_model_analysis as tfma
import tfx.v1 as tfx


# Download and normalize the Auto MPG dataset once; skip if it already exists.
# (Replaces a bare try/except that silently hid download or parsing errors.)
if not os.path.exists('data/data.csv'):
  url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
  column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                  'Acceleration', 'Model_Year', 'Origin']

  dataset = pd.read_csv(url, names=column_names,
                        na_values='?', comment='\t',
                        sep=' ', skipinitialspace=True)
  dataset = dataset.dropna()
  dataset -= dataset.mean()
  dataset /= dataset.std()
  os.makedirs('data', exist_ok=True)
  dataset.to_csv('data/data.csv', index=False)

with open('./trainer.py', 'w') as f:
  f.write("""

import keras
import tensorflow as tf
import tfx.v1 as tfx
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.public import tfxio

column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',
                'Acceleration', 'Model_Year', 'Origin']

def _input_fn(file_pattern: list[str],
              data_accessor: tfx.components.DataAccessor,
              schema: schema_pb2.Schema,
              batch_size: int = 200) -> tf.data.Dataset:
  # from https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple#create_a_pipeline

  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key='MPG'),
      schema=schema).repeat()

def run_fn(fn_args: tfx.components.FnArgs):
  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      tfx.utils.parse_pbtxt_file(fn_args.schema_path, schema_pb2.Schema()))

  layers = [keras.Input(shape=(1,), name=n) for n in column_names if n != 'MPG']

  linear_model = keras.layers.concatenate(layers)

  linear_model = keras.Model(inputs=layers, outputs=[keras.layers.Dense(units=1)(linear_model)])

  linear_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.1),
    loss='mean_absolute_error',
    metrics=[keras.metrics.R2Score()]
    )
  
  linear_model.fit(
    train_dataset,
    steps_per_epoch=1000,
    epochs=1)
  
  linear_model.save(fn_args.serving_model_dir, save_format='tf')

""")

example_gen = tfx.components.CsvExampleGen(input_base='data')
statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = tfx.components.SchemaGen(statistics=statistics_gen.outputs['statistics'])

trainer = tfx.components.Trainer(
  module_file='trainer.py',
  examples=example_gen.outputs['examples'],
  schema=schema_gen.outputs['schema']
)


evaluator = tfx.components.Evaluator(
  examples=example_gen.outputs['examples'],
  schema=schema_gen.outputs['schema'],
  model=trainer.outputs['model'],
  eval_config=tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='MPG')],
    slicing_specs=[tfma.SlicingSpec()],
    metrics_specs=tfma.metrics.default_regression_specs()
  )
)

pipeline = tfx.dsl.Pipeline(
  pipeline_name='min',
  pipeline_root='min',
  metadata_connection_config=tfx.orchestration.metadata
    .sqlite_metadata_connection_config('mlmd.db'),
  components=[
    example_gen, statistics_gen, schema_gen,
    trainer,
    evaluator],
  enable_cache=False)


tfx.orchestration.LocalDagRunner().run(pipeline)

Other info / logs
I went down a debugging rabbit hole myself, and I think the issue is that the metrics container used by the model does not build the metrics it contains after the model is loaded back from its SavedModel. For most metrics this is fine, but R2Score adds some weights in its build() method, and those weights are missing, as the error message shows:
You called set_weights(weights) on layer "r2_score" with a weight list of length 5, but the layer was expecting 1 weights.


@lego0901
Member

Hi @TomsCodingCode, thanks for using TFX and for reporting the issue with a concrete standalone example!

The standard Evaluator component delegates its implementation to the TFMA package. I am not a TFMA expert, but at first glance the issue arises because tf.keras.metrics.R2Score stores multiple internal variables (such as the sum of squares and the sample count), whereas TFMA usually expects metrics to have a simple, single-value state for serialization and aggregation. This mismatch causes the error you encountered.
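As a rough illustration of why R2Score cannot get by with a single running value, here is a NumPy sketch of a streaming R² over batches. StreamingR2 is a hypothetical name; it is not part of Keras or TFMA and is not the Keras implementation. Because R² = 1 - SS_res / SS_tot depends on the mean of all labels seen so far, the metric has to carry several running sums across batches; Keras's R2Score keeps comparable per-output accumulators.

```python
import numpy as np


class StreamingR2:
    """Hypothetical streaming R^2 for a single output column.

    Keeps only running sums, never the individual samples, yet still
    reproduces R^2 = 1 - SS_res / SS_tot over everything seen so far.
    """

    def __init__(self):
        self.count = 0       # number of samples seen
        self.sum_y = 0.0     # sum of y_true
        self.sum_y_sq = 0.0  # sum of y_true ** 2
        self.ss_res = 0.0    # sum of (y_true - y_pred) ** 2

    def update(self, y_true, y_pred):
        y_true = np.asarray(y_true, dtype=np.float64)
        y_pred = np.asarray(y_pred, dtype=np.float64)
        self.count += y_true.size
        self.sum_y += y_true.sum()
        self.sum_y_sq += np.square(y_true).sum()
        self.ss_res += np.square(y_true - y_pred).sum()

    def result(self):
        # Total sum of squares around the mean: sum(y^2) - (sum(y))^2 / n.
        ss_tot = self.sum_y_sq - self.sum_y ** 2 / self.count
        return 1.0 - self.ss_res / ss_tot


y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.8])

streaming = StreamingR2()
streaming.update(y_true[:2], y_pred[:2])  # first batch
streaming.update(y_true[2:], y_pred[2:])  # second batch

# Direct computation over the full arrays, for comparison.
direct = 1.0 - np.square(y_true - y_pred).sum() / np.square(y_true - y_true.mean()).sum()
print(streaming.result(), direct)
```

Four separate quantities have to survive serialization and be merged across batches here, which is the kind of multi-variable state the comment above describes.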

To resolve this, you can use a custom metric wrapper called R2ScoreWrapper. This wrapper encapsulates the complex internal state of R2Score and exposes only the final value to TFMA, making it compatible with TFMA's serialization and aggregation mechanisms.

import tensorflow as tf


class R2ScoreWrapper(tf.keras.metrics.Metric):
  """Delegates all state handling to an inner R2Score instance."""

  def __init__(self, name="r2_score_wrapper", **kwargs):
    super().__init__(name=name, **kwargs)
    # The inner metric owns the accumulators; the wrapper only forwards calls.
    self.r2_score = tf.keras.metrics.R2Score()

  def update_state(self, y_true, y_pred, sample_weight=None):
    self.r2_score.update_state(y_true, y_pred, sample_weight)

  def result(self):
    return self.r2_score.result()

  def reset_state(self):
    self.r2_score.reset_state()
...
  linear_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.1),
    loss='mean_absolute_error',
    metrics=[R2ScoreWrapper()]
    # metrics=[tf.keras.metrics.R2Score()]
    )

I found that the same issue also occurs with metrics.F1Score. I hope this works for you. Thanks!

@TomsCodingCode
Author

That works perfectly fine as a workaround, thanks!

Is this a bug in TFMA, then, or is this expected behaviour?

@janasangeetha
Contributor

Hi @TomsCodingCode
As mentioned in the comment above, this is expected behavior. We'd like to close the issue, since the workaround resolves it.
Please feel free to reach out to us if required!
