Skip to content

Commit

Permalink
cleaned up userdb
Browse files Browse the repository at this point in the history
  • Loading branch information
nfearnley committed May 21, 2024
1 parent 672dba3 commit 47a60b1
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 51 deletions.
5 changes: 3 additions & 2 deletions sizebot/cogs/loop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import cast
from sizebot.lib.types import BotContext

import logging

Expand All @@ -7,11 +8,11 @@
import discord
from discord.ext import commands

from sizebot.lib.speed import MoveType
from sizebot.lib.stats import StatBox
from sizebot.lib.units import SV, TV
from sizebot.lib import userdb
from sizebot.lib.constants import emojis
from sizebot.lib.types import BotContext
from sizebot.lib.utils import pretty_time_delta
import sizebot.lib.language as lang

Expand Down Expand Up @@ -63,7 +64,7 @@ async def start(self, ctx: BotContext, action: str, stop: TV | None = None):
return

# Fix typing now that we've checked it
action = cast(userdb.MoveTypeStr, action)
action = cast(MoveType, action)

userdata = userdb.load(ctx.guild.id, ctx.author.id)

Expand Down
9 changes: 9 additions & 0 deletions sizebot/lib/macrovision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import sizebot.data
from sizebot.conf import conf
from sizebot.lib import errors
from sizebot.lib.digidecimal import Decimal
from sizebot.lib.stats import StatBox
from sizebot.lib.units import SV
Expand Down Expand Up @@ -132,3 +133,11 @@ def get_url(entities: list[MacrovisionEntity]) -> str:
base64_string = base64_bytes.decode("ascii")
url = f"https://macrovision.crux.sexy/?scene={base64_string}"
return url


def is_model(model: str) -> bool:
return model in model_heights


def is_modelview(model: str, view: str) -> bool:
return model in model_heights and view in model_heights[model]
5 changes: 5 additions & 0 deletions sizebot/lib/speed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import get_args, Literal

import math

from sizebot.lib.constants import emojis
Expand All @@ -7,6 +9,9 @@
from sizebot.lib.units import SV
from sizebot.lib.utils import pretty_time_delta

MoveType = Literal["walk", "run", "climb", "crawl", "swim"]
MOVETYPES = get_args(MoveType)


class Movement:
def __init__(self, key: str, dist: SV, *, stats: StatBox, showspeed: bool, steps: str | None = None, always_active: bool = False):
Expand Down
95 changes: 46 additions & 49 deletions sizebot/lib/userdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,25 @@
import discord

import sizebot.data
from sizebot.lib import errors, paths
from sizebot.lib import errors, macrovision, paths
from sizebot.lib.digidecimal import Decimal
from sizebot.lib.diff import Diff
from sizebot.lib.fakeplayer import FakePlayer
from sizebot.lib.gender import Gender
from sizebot.lib.speed import MoveType
from sizebot.lib.units import SV, TV, WV
from sizebot.lib.unitsystem import UnitSystem
from sizebot.lib.utils import truncate
from sizebot.lib.stats import AVERAGE_HEIGHT, AVERAGE_WEIGHT, PlayerStats

BASICALLY_ZERO = Decimal("1E-27")

modelJSON = json.loads(pkg_resources.read_text(sizebot.data, "models.json"))

MoveTypeStr = Literal["walk", "run", "climb", "crawl", "swim"]
MOVETYPES = get_args(MoveTypeStr)

MemberOrFake = discord.Member | FakePlayer
MemberOrFakeOrSize = MemberOrFake | SV


def str_or_none(v: Any) -> str | None:
def _str_or_none(v: Any) -> str | None:
if v is None:
return None
return str(v)
Expand Down Expand Up @@ -79,7 +76,7 @@ def __init__(self, guildid: int, id: int, nickname: str):
self.currentscalestep: Diff | None = None
self.currentscaletalk: Diff | None = None
self.scaletalklock: bool = False
self.currentmovetype: MoveTypeStr | None = None
self.currentmovetype: MoveType | None = None
self.movestarted: Arrow | None = None
self.movestop: TV | None = None
self.triggers: dict[str, Diff] = {}
Expand Down Expand Up @@ -178,24 +175,24 @@ def scale(self, scale: Decimal):
def stats(self) -> PlayerStats:
"""A bit of a patchwork solution for transitioning to BetterStats."""
return {
"height": str_or_none(self.baseheight),
"weight": str_or_none(self.baseweight),
"footlength": str_or_none(self.footlength),
"height": _str_or_none(self.baseheight),
"weight": _str_or_none(self.baseweight),
"footlength": _str_or_none(self.footlength),
"pawtoggle": self.pawtoggle,
"furtoggle": self.furtoggle,
"hairlength": str_or_none(self.hairlength),
"taillength": str_or_none(self.taillength),
"earheight": str_or_none(self.earheight),
"liftstrength": str_or_none(self.liftstrength),
"walkperhour": str_or_none(self.walkperhour),
"swimperhour": str_or_none(self.swimperhour),
"runperhour": str_or_none(self.runperhour),
"gender": str_or_none(self.gender), # TODO: Should this be autogender?
"scale": str_or_none(self.scale),
"nickname": str_or_none(self.nickname),
"id": str_or_none(self.id),
"macrovision_model": str_or_none(self.macrovision_model),
"macrovision_view": str_or_none(self.macrovision_view)
"hairlength": _str_or_none(self.hairlength),
"taillength": _str_or_none(self.taillength),
"earheight": _str_or_none(self.earheight),
"liftstrength": _str_or_none(self.liftstrength),
"walkperhour": _str_or_none(self.walkperhour),
"swimperhour": _str_or_none(self.swimperhour),
"runperhour": _str_or_none(self.runperhour),
"gender": _str_or_none(self.gender), # TODO: Should this be autogender?
"scale": _str_or_none(self.scale),
"nickname": _str_or_none(self.nickname),
"id": _str_or_none(self.id),
"macrovision_model": _str_or_none(self.macrovision_model),
"macrovision_view": _str_or_none(self.macrovision_view)
}

@property
Expand Down Expand Up @@ -228,7 +225,7 @@ def macrovision_model(self) -> str:

@macrovision_model.setter
def macrovision_model(self, value: str):
if value not in modelJSON.keys():
if not macrovision.is_model(self.macrovision_model):
raise errors.InvalidMacrovisionModelException(value)
self._macrovision_model = value

Expand All @@ -245,7 +242,7 @@ def macrovision_view(self) -> str:

@macrovision_view.setter
def macrovision_view(self, value: str):
if value not in modelJSON[self.macrovision_model].keys():
if not macrovision.is_modelview(self.macrovision_model, value):
raise errors.InvalidMacrovisionViewException(self.macrovision_model, value)
self._macrovision_view = value

Expand Down Expand Up @@ -293,34 +290,34 @@ def toJSON(self) -> Any:
# Create a new object from a python dictionary imported using json
@classmethod
def fromJSON(cls, jsondata: dict[str, Any]) -> User:
jsondata = migrate_json(jsondata)
jsondata = _migrate_json(jsondata)
userdata = User(int(jsondata["guildid"]), int(jsondata["id"]), cast(str, jsondata["nickname"]))
userdata.lastactive = optional_parse(arrow.get, jsondata["lastactive"])
userdata.lastactive = _optional_parse(arrow.get, jsondata["lastactive"])
userdata.picture_url = cast(str, jsondata["picture_url"])
userdata.description = cast(str, jsondata["description"])
userdata.gender = cast(Gender | None, jsondata["gender"])
userdata.display = cast(bool, jsondata["display"])
userdata.height = SV(jsondata["height"])
userdata.baseheight = SV(jsondata["baseheight"])
userdata.baseweight = WV(jsondata["baseweight"])
userdata.footlength = optional_parse(SV, jsondata["footlength"])
userdata.footlength = _optional_parse(SV, jsondata["footlength"])
userdata.pawtoggle = cast(bool, jsondata["pawtoggle"])
userdata.furtoggle = cast(bool, jsondata["furtoggle"])
userdata.hairlength = optional_parse(SV, jsondata["hairlength"])
userdata.taillength = optional_parse(SV, jsondata["taillength"])
userdata.earheight = optional_parse(SV, jsondata["earheight"])
userdata.liftstrength = optional_parse(WV, jsondata["liftstrength"])
userdata.walkperhour = optional_parse(SV, jsondata["walkperhour"])
userdata.runperhour = optional_parse(SV, jsondata["runperhour"])
userdata.swimperhour = optional_parse(SV, jsondata["swimperhour"])
userdata.currentscalestep = optional_parse(Diff.fromJSON, jsondata["currentscalestep"])
userdata.currentscaletalk = optional_parse(Diff.fromJSON, jsondata["currentscaletalk"])
userdata.hairlength = _optional_parse(SV, jsondata["hairlength"])
userdata.taillength = _optional_parse(SV, jsondata["taillength"])
userdata.earheight = _optional_parse(SV, jsondata["earheight"])
userdata.liftstrength = _optional_parse(WV, jsondata["liftstrength"])
userdata.walkperhour = _optional_parse(SV, jsondata["walkperhour"])
userdata.runperhour = _optional_parse(SV, jsondata["runperhour"])
userdata.swimperhour = _optional_parse(SV, jsondata["swimperhour"])
userdata.currentscalestep = _optional_parse(Diff.fromJSON, jsondata["currentscalestep"])
userdata.currentscaletalk = _optional_parse(Diff.fromJSON, jsondata["currentscaletalk"])
userdata.scaletalklock = cast(bool, jsondata["scaletalklock"])
userdata.currentmovetype = cast(MoveTypeStr | None, jsondata["currentmovetype"])
userdata.movestarted = optional_parse(arrow.get, jsondata["movestarted"])
userdata.movestop = optional_parse(TV, jsondata["movestop"])
userdata.currentmovetype = cast(MoveType | None, jsondata["currentmovetype"])
userdata.movestarted = _optional_parse(arrow.get, jsondata["movestarted"])
userdata.movestop = _optional_parse(TV, jsondata["movestop"])
userdata.triggers = {k: Diff.fromJSON(v) for k, v in cast(dict[str, str], jsondata["triggers"]).items()}
userdata.button = optional_parse(Diff.fromJSON, jsondata["button"])
userdata.button = _optional_parse(Diff.fromJSON, jsondata["button"])
userdata.tra_reports = cast(int, jsondata["tra_reports"])
userdata.unitsystem = cast(UnitSystem, jsondata["unitsystem"])
userdata.species = cast(str, jsondata["species"])
Expand Down Expand Up @@ -368,28 +365,28 @@ def from_height(cls, height: SV) -> User:
return User.from_fake(FakePlayer(height=height))


def get_guild_users_path(guildid: int) -> Path:
def _get_guild_users_path(guildid: int) -> Path:
return paths.guilddbpath / f"{guildid}" / "users"


def get_user_path(guildid: int, userid: int) -> Path:
return get_guild_users_path(guildid) / f"{userid}.json"
def _get_user_path(guildid: int, userid: int) -> Path:
return _get_guild_users_path(guildid) / f"{userid}.json"


def save(userdata: User):
guildid = userdata.guildid
userid = userdata.id
if guildid is None or userid is None:
raise errors.CannotSaveWithoutIDException
path = get_user_path(guildid, userid)
path = _get_user_path(guildid, userid)
path.parent.mkdir(exist_ok = True, parents = True)
jsondata = userdata.toJSON()
with open(path, "w") as f:
json.dump(jsondata, f, indent = 4)


def load(guildid: int, userid: int, *, member: discord.Member | None = None, allow_unreg: bool = False) -> User:
path = get_user_path(guildid, userid)
path = _get_user_path(guildid, userid)
try:
with open(path, "r") as f:
jsondata = json.load(f)
Expand All @@ -408,7 +405,7 @@ def load(guildid: int, userid: int, *, member: discord.Member | None = None, all


def delete(guildid: int, userid: int):
path = get_user_path(guildid, userid)
path = _get_user_path(guildid, userid)
path.unlink(missing_ok = True)


Expand Down Expand Up @@ -473,13 +470,13 @@ def load_or_fake_weight(arg: MemberOrFake | WV, *, allow_unreg: bool = False) ->
T = TypeVar("T")


def optional_parse(parser: Callable[[str], T], val: str | None) -> T | None:
def _optional_parse(parser: Callable[[str], T], val: str | None) -> T | None:
if val is None:
return None
return parser(val)


def migrate_json(jsondata: dict[str, Any]) -> dict[str, Any]:
def _migrate_json(jsondata: dict[str, Any]) -> dict[str, Any]:
if "allowchangefromothers" not in jsondata:
jsondata["allowchangefromothers"] = False
if "tra_reports" not in jsondata:
Expand Down

0 comments on commit 47a60b1

Please sign in to comment.