Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

root get weights: html, caching, slicing #29

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions bin/btcli

This file was deleted.

32 changes: 31 additions & 1 deletion bittensor_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,20 @@ def root_get_weights(
self,
network: Optional[str] = Options.network,
chain: Optional[str] = Options.chain,
limit_min_col: Optional[int] = typer.Option(
None,
"--limit-min-col",
"--min",
help="Limit left display of the table to this column.",
),
limit_max_col: Optional[int] = typer.Option(
None,
"--limit-max-col",
"--max",
help="Limit right display of the table to this column.",
),
reuse_last: bool = Options.reuse_last,
html_output: bool = Options.html_output,
):
"""
# root get-weights
Expand Down Expand Up @@ -1844,8 +1858,24 @@ def root_get_weights(
network. It offers transparency into how network rewards and responsibilities are allocated across different
subnets.
"""
if (reuse_last or html_output) and self.config.get("no_cache") is True:
err_console.print(
"Unable to use `--reuse-last` or `--html` when config no-cache is set."
)
raise typer.Exit()
if not reuse_last:
subtensor = self.initialize_chain(network, chain)
else:
subtensor = None
return self._run_command(
root.get_weights(self.initialize_chain(network, chain))
root.get_weights(
subtensor,
limit_min_col,
limit_max_col,
reuse_last,
html_output,
self.config.get("no_cache", False),
)
)

def root_boost(
Expand Down
149 changes: 101 additions & 48 deletions bittensor_cli/src/commands/root.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from typing import TypedDict, Optional
import json
from typing import TypedDict, Optional, cast
from rich import box
import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -30,6 +31,10 @@
get_delegates_details_from_github,
convert_weight_uids_and_vals_to_tensor,
ss58_to_vec_u8,
create_table,
render_table,
update_metadata_table,
get_metadata_table,
)
from bittensor_cli.src import Constants

Expand Down Expand Up @@ -818,68 +823,116 @@ async def set_weights(
)


async def get_weights(subtensor: SubtensorInterface):
async def get_weights(
subtensor: SubtensorInterface,
limit_min_col: Optional[int],
limit_max_col: Optional[int],
reuse_last: bool,
html_output: bool,
no_cache: bool,
):
"""Get weights for root network."""
with console.status(":satellite: Synchronizing with chain..."):
weights = await subtensor.weights(0)
if not reuse_last:
with console.status(":satellite: Synchronizing with chain..."):
weights = await subtensor.weights(0)

uid_to_weights: dict[int, dict] = {}
netuids = set()
for matrix in weights:
[uid, weights_data] = matrix
uid_to_weights: dict[int, dict] = {}
netuids = set()
for matrix in weights:
[uid, weights_data] = matrix

if not len(weights_data):
uid_to_weights[uid] = {}
normalized_weights = []
else:
normalized_weights = np.array(weights_data)[:, 1] / max(
np.sum(weights_data, axis=0)[1], 1
if not len(weights_data):
uid_to_weights[uid] = {}
normalized_weights = []
else:
normalized_weights = np.array(weights_data)[:, 1] / max(
np.sum(weights_data, axis=0)[1], 1
)

for weight_data, normalized_weight in zip(weights_data, normalized_weights):
[netuid, _] = weight_data
netuids.add(netuid)
if uid not in uid_to_weights:
uid_to_weights[uid] = {}

uid_to_weights[uid][netuid] = normalized_weight
rows: list[list[str]] = []
for uid in uid_to_weights:
row = [str(uid)]

uid_weights = uid_to_weights[uid]
for netuid in netuids:
if netuid in uid_weights:
row.append("{:0.2f}%".format(uid_weights[netuid] * 100))
else:
row.append("~")
rows.append(row)

if not no_cache:
db_cols = [("UID", "INTEGER")]
for netuid in netuids:
db_cols.append((f"_{netuid}", "TEXT"))
create_table("rootgetweights", db_cols, rows)
netuids = list(netuids)
update_metadata_table(
"rootgetweights",
{"rows": json.dumps(rows), "netuids": json.dumps(netuids)},
)
else:
metadata = get_metadata_table("rootgetweights")
rows = json.loads(metadata["rows"])
netuids = json.loads(metadata["netuids"])

for weight_data, normalized_weight in zip(weights_data, normalized_weights):
[netuid, _] = weight_data
netuids.add(netuid)
if uid not in uid_to_weights:
uid_to_weights[uid] = {}
_min_lim = limit_min_col if limit_min_col is not None else 0
_max_lim = limit_max_col + 1 if limit_max_col is not None else len(netuids)
_max_lim = min(_max_lim, len(netuids))

uid_to_weights[uid][netuid] = normalized_weight
if _min_lim is not None and _min_lim > len(netuids):
err_console.print("Minimum limit greater than number of netuids")
return

table = Table(
show_footer=True,
box=None,
pad_edge=False,
width=None,
title="[white]Root Network Weights",
)
table.add_column(
"[white]UID",
header_style="overline white",
footer_style="overline white",
style="rgb(50,163,219)",
no_wrap=True,
)
for netuid in netuids:
if not html_output:
table = Table(
show_footer=True,
box=None,
pad_edge=False,
width=None,
title="[white]Root Network Weights",
)
table.add_column(
f"[white]{netuid}",
"[white]UID",
header_style="overline white",
footer_style="overline white",
justify="right",
style="green",
style="rgb(50,163,219)",
no_wrap=True,
)

for uid in uid_to_weights:
row = [str(uid)]
for netuid in cast(list, netuids)[_min_lim:_max_lim]:
table.add_column(
f"[white]{netuid}",
header_style="overline white",
footer_style="overline white",
justify="right",
style="green",
no_wrap=True,
)

uid_weights = uid_to_weights[uid]
for netuid in netuids:
if netuid in uid_weights:
row.append("{:0.2f}%".format(uid_weights[netuid] * 100))
else:
row.append("~")
table.add_row(*row)
# Adding rows
for row in rows:
new_row = [row[0]] + row[_min_lim:_max_lim]
table.add_row(*new_row)

return console.print(table)
return console.print(table)

else:
html_cols = [{"title": "UID", "field": "UID"}]
for netuid in netuids[_min_lim:_max_lim]:
html_cols.append({"title": str(netuid), "field": f"_{netuid}"})
render_table(
"rootgetweights",
"Root Network Weights",
html_cols,
)


async def _get_my_weights(
Expand Down
6 changes: 5 additions & 1 deletion bittensor_cli/src/commands/wallets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,11 @@ async def _filter_stake_info(stake_info: StakeInfo) -> bool:


async def transfer(
wallet: Wallet, subtensor: SubtensorInterface, destination: str, amount: float, prompt: bool
wallet: Wallet,
subtensor: SubtensorInterface,
destination: str,
amount: float,
prompt: bool,
):
"""Transfer token of amount to destination."""
await transfer_extrinsic(
Expand Down
7 changes: 6 additions & 1 deletion bittensor_cli/src/subtensor_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
)
from bittensor_cli.src.bittensor.balances import Balance
from bittensor_cli.src import Constants, defaults, TYPE_REGISTRY
from bittensor_cli.src.utils import ss58_to_vec_u8, format_error_message, console, err_console
from bittensor_cli.src.utils import (
ss58_to_vec_u8,
format_error_message,
console,
err_console,
)


class ParamWithTypes(TypedDict):
Expand Down
4 changes: 3 additions & 1 deletion tests/e2e_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import pytest

from bittensor_cli.src.bittensor.async_substrate_interface import AsyncSubstrateInterface
from bittensor_cli.src.bittensor.async_substrate_interface import (
AsyncSubstrateInterface,
)


# Fixture for setting up and tearing down a localnet.sh chain between tests
Expand Down
12 changes: 7 additions & 5 deletions tests/e2e_tests/test_staking_sudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_staking(local_chain):
wallet_alice.name,
"--network",
"local",
"--no-prompt"
"--no-prompt",
],
)
assert f"✅ Registered subnetwork with netuid: {netuid}" in result.stdout
Expand All @@ -77,7 +77,7 @@ def test_staking(local_chain):
netuid,
"--chain",
"ws://127.0.0.1:9945",
"--no-prompt"
"--no-prompt",
],
)
assert "✅ Registered" in register_subnet.stdout
Expand All @@ -99,7 +99,7 @@ def test_staking(local_chain):
"ws://127.0.0.1:9945",
"--amount",
"100",
"--no-prompt"
"--no-prompt",
],
)
assert "✅ Finalized" in add_stake.stdout
Expand All @@ -120,7 +120,9 @@ def test_staking(local_chain):
],
)
# Assert correct stake is added
cleaned_stake = [re.sub(r'\s+', ' ', line) for line in show_stake.stdout.splitlines()]
cleaned_stake = [
re.sub(r"\s+", " ", line) for line in show_stake.stdout.splitlines()
]
stake_added = cleaned_stake[3].split()[4].strip("τ")
assert Balance.from_tao(100) == Balance.from_tao(float(stake_added))

Expand All @@ -146,7 +148,7 @@ def test_staking(local_chain):
"ws://127.0.0.1:9945",
"--amount",
"100",
"--no-prompt"
"--no-prompt",
],
)
assert "✅ Finalized" in remove_stake.stdout
Expand Down
8 changes: 4 additions & 4 deletions tests/e2e_tests/test_wallet_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_wallet_overview_inspect(local_chain):
wallet.name,
"--network",
"local",
"--no-prompt"
"--no-prompt",
],
)
assert f"✅ Registered subnetwork with netuid: {netuid}" in result.stdout
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_wallet_overview_inspect(local_chain):
"1",
"--chain",
"ws://127.0.0.1:9945",
"--no-prompt"
"--no-prompt",
],
)
assert "✅ Registered" in register_subnet.stdout
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_wallet_transfer(local_chain):
"local",
"--amount",
"100",
"--no-prompt"
"--no-prompt",
],
)

Expand Down Expand Up @@ -350,7 +350,7 @@ def test_wallet_transfer(local_chain):
"local",
"--amount",
"100",
"--no-prompt"
"--no-prompt",
],
)

Expand Down
Loading