Skip to content

Commit

Permalink
Typing improvements + cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
einarf committed Nov 23, 2024
1 parent bf4655b commit 6dc4527
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 90 deletions.
33 changes: 19 additions & 14 deletions moderngl_window/loaders/scene/gltf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def load_glb(self):

version = struct.unpack("<I", fd.read(4))[0]
if version != 2:
raise ValueError(
f"{self.path} has unsupported version {version}"
)
raise ValueError(f"{self.path} has unsupported version {version}")

# Total file size including headers
_ = struct.unpack("<I", fd.read(4))[0] # noqa
Expand Down Expand Up @@ -500,9 +498,11 @@ def load(self, materials):
self.name,
vao=vao,
attributes=attributes,
material=materials[primitive.material]
if primitive.material is not None
else None,
material=(
materials[primitive.material]
if primitive.material is not None
else None
),
bbox_min=bbox_min,
bbox_max=bbox_max,
)
Expand Down Expand Up @@ -667,7 +667,8 @@ def __init__(self, view_id, data):

def read(self, byte_offset=0, dtype=None, count=0):
data = self.buffer.read(
byte_offset=byte_offset + self.byteOffset, byte_length=self.byteLength,
byte_offset=byte_offset + self.byteOffset,
byte_length=self.byteLength,
)
vbo = numpy.frombuffer(data, count=count, dtype=dtype)
return vbo
Expand Down Expand Up @@ -712,15 +713,15 @@ def open(self):
return

if self.has_data_uri:
self.data = base64.b64decode(self.uri[self.uri.find(",") + 1:])
self.data = base64.b64decode(self.uri[self.uri.find(",") + 1 :])
return

with open(str(self.path / self.uri), "rb") as fd:
self.data = fd.read()

def read(self, byte_offset=0, byte_length=0):
self.open()
return self.data[byte_offset:byte_offset + byte_length]
return self.data[byte_offset : byte_offset + byte_length]


class GLTFScene:
Expand All @@ -746,16 +747,16 @@ def __init__(self, data):
self.matrix = glm.mat4()

if self.translation is not None:
self.matrix = self.matrix * glm.translate(self.translation)
self.matrix = self.matrix * glm.translate(glm.vec3(*self.translation))

if self.rotation is not None:
quat = quaternion.create(
quat = glm.quat(
x=self.rotation[0],
y=self.rotation[1],
z=self.rotation[2],
w=self.rotation[3],
)
self.matrix = self.matrix * glm.transpose(glm.mat4(quat))
self.matrix = self.matrix * glm.mat4(quat)

if self.scale is not None:
self.matrix = self.matrix * glm.scale(self.scale)
Expand Down Expand Up @@ -803,7 +804,7 @@ def load(self, path):
image = Image.open(io.BytesIO(self.bufferView.read_raw()))
# Image is embedded
elif self.uri and self.uri.startswith("data:"):
data = self.uri[self.uri.find(",") + 1:]
data = self.uri[self.uri.find(",") + 1 :]
image = Image.open(io.BytesIO(base64.b64decode(data)))
logger.info("Loading embedded image")
else:
Expand All @@ -813,7 +814,11 @@ def load(self, path):

texture = t2d.Loader(
TextureDescription(
label="gltf", image=image, flip=False, mipmap=True, anisotropy=16.0,
label="gltf",
image=image,
flip=False,
mipmap=True,
anisotropy=16.0,
)
).load()

Expand Down
52 changes: 32 additions & 20 deletions moderngl_window/scene/node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""
Wrapper for a loaded mesh / vao with properties
"""

from __future__ import annotations

from typing import List, TYPE_CHECKING
import numpy

import glm
import moderngl

if TYPE_CHECKING:
from moderngl_window.scene import Camera, Mesh
Expand All @@ -15,7 +19,13 @@ class Node:
represents the scene tree.
"""

def __init__(self, name=None, camera=None, mesh=None, matrix=None):
def __init__(
self,
name: str | None = None,
camera: glm.mat4 | None = None,
mesh: Mesh | None = None,
matrix: glm.mat4 | None = None,
):
"""Create a node.
Keyword Args:
Expand All @@ -28,19 +38,19 @@ def __init__(self, name=None, camera=None, mesh=None, matrix=None):
self._camera = camera
self._mesh = mesh
# Local node matrix
self._matrix = matrix if matrix is not None else None
self._matrix = matrix
# Global matrix
self._matrix_global = None

self._children = []
self._children: list["Node"] = []

@property
def name(self) -> str:
"""str: Get or set the node name"""
return self._name

@name.setter
def name(self, value: str):
def name(self, value: str) -> None:
self._name = value

@property
Expand All @@ -49,7 +59,7 @@ def mesh(self) -> "Mesh":
return self._mesh

@mesh.setter
def mesh(self, value: "Mesh"):
def mesh(self, value: "Mesh") -> None:
self._mesh = value

@property
Expand All @@ -58,47 +68,47 @@ def camera(self) -> "Camera":
return self._camera

@camera.setter
def camera(self, value):
def camera(self, value: "Camera") -> None:
self._camera = value

@property
def matrix(self) -> numpy.ndarray:
def matrix(self) -> glm.mat4:
"""numpy.ndarray: Note matrix (local)"""
return self._matrix

@matrix.setter
def matrix(self, value):
def matrix(self, value: glm.mat4) -> None:
self._matrix = value

@property
def matrix_global(self) -> numpy.ndarray:
def matrix_global(self) -> glm.mat4:
"""numpy.ndarray: The global node matrix containing transformations from parent nodes"""
return self._matrix_global

@matrix_global.setter
def matrix_global(self, value):
def matrix_global(self, value: glm.mat4) -> None:
self._matrix_global = value

@property
def children(self) -> List["Node"]:
"""list: List of children"""
return self._children

def add_child(self, node):
def add_child(self, node: "Node") -> None:
"""Add a child to this node
Args:
node (Node): Node to add as a child
"""
self._children.append(node)

def draw(self, projection_matrix=None, camera_matrix=None, time=0):
def draw(self, projection_matrix: glm.mat4, camera_matrix: glm.mat4, time=0):
"""Draw node and children.
Keyword Args:
projection_matrix (bytes): projection matrix
camera_matrix (bytes): camera_matrix
time (float): The current time
projection_matrix: projection matrix
camera_matrix: camera_matrix
time: The current time
"""
if self._mesh:
self._mesh.draw(
Expand All @@ -115,7 +125,9 @@ def draw(self, projection_matrix=None, camera_matrix=None, time=0):
time=time,
)

def draw_bbox(self, projection_matrix, camera_matrix, program, vao):
def draw_bbox(
self, projection_matrix: glm.mat4, camera_matrix: glm.mat4, program: mod, vao
):
"""Draw bounding box around the node and children.
Keyword Args:
Expand Down Expand Up @@ -146,7 +158,7 @@ def draw_wireframe(self, projection_matrix, camera_matrix, program):
for child in self.children:
child.draw_wireframe(projection_matrix, self._matrix_global, program)

def calc_global_bbox(self, view_matrix, bbox_min, bbox_max):
def calc_global_bbox(self, view_matrix: glm.mat4, bbox_min, bbox_max) -> tuple:
"""Recursive calculation of scene bbox.
Keyword Args:
Expand All @@ -167,7 +179,7 @@ def calc_global_bbox(self, view_matrix, bbox_min, bbox_max):

return bbox_min, bbox_max

def calc_model_mat(self, model_matrix):
def calc_model_mat(self, model_matrix: glm.mat4) -> None:
"""Calculate the model matrix related to all parents.
Args:
Expand All @@ -184,5 +196,5 @@ def calc_model_mat(self, model_matrix):
for child in self._children:
child.calc_model_mat(model_matrix)

def __repr__(self):
def __repr__(self) -> str:
return "<Node name={}>".format(self.name)
Loading

0 comments on commit 6dc4527

Please sign in to comment.