Skip to content

Commit

Permalink
refactor Christoffel_matrix function
Browse files Browse the repository at this point in the history
- now wavevectors can be of arrays of shape (n,3)
  • Loading branch information
marcoalopez committed Jun 21, 2024
1 parent 078b650 commit fb496ca
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 58 deletions.
45 changes: 28 additions & 17 deletions src/christoffel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Filename: christoffel.py #
# Description: TODO #
# #
# Copyright (c) 2023 #
# Copyright (c) 2024 #
# #
# PyRockWave is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published #
Expand Down Expand Up @@ -193,47 +193,58 @@ def _normalize_vector(vector: np.ndarray):
return vector / magnitude


def _christoffel_matrix(wavevector: np.ndarray, Cijkl: np.ndarray):
def _christoffel_matrix(wavevectors: np.ndarray, Cijkl: np.ndarray) -> np.ndarray:
"""Calculate the Christoffel matrix for a given wave vector
and elastic tensor Cij.
Parameters
----------
wave_vector : numpy.ndarray
The wave vector as a 1D NumPy array of length 3.
wavevectors : numpy.ndarray
The wave vectors normalized to lie on the unit sphere as a
1D or 2D NumPy array.
If 1D, shape must be (3,).
If 2D, shape must be (n, 3).
Cijkl : numpy.ndarray
The elastic tensor as a 4D NumPy array of shape (3, 3, 3, 3).
Returns
-------
numpy.ndarray
The Christoffel matrix as a 2D NumPy array of shape (3, 3).
Raises
------
ValueError
If wave_vector is not a 1D NumPy array of length 3, or
if Cij is not a 4D NumPy array of shape (3, 3, 3, 3).
The Christoffel matrix as a 2D NumPy array of shape (n, 3, 3).
Notes
-----
The Christoffel matrix is calculated using the formula
M = k @ Cijkl @ k, where M is the Christoffel matrix, k is the
M = k @ Cijkl @ k, where M is the Christoffel matrix, k is a
wave vector, and Cijkl is the elastic tensor (stiffness matrix).
"""

# Validate input parameters
if not isinstance(wavevector, np.ndarray) or wavevector.shape != (3,):
raise ValueError("wave_vector should be a 1D NumPy array of length 3.")
if not isinstance(wavevectors, np.ndarray):
raise ValueError("wavevectors should be a NumPy array.")
if wavevectors.ndim == 1 and wavevectors.shape[0] == 3:
wavevectors = wavevectors.reshape(1, 3)
elif wavevectors.ndim == 2 and wavevectors.shape[1] == 3:
pass
else:
raise ValueError(
"wavevectors should be a NumPy array of shape (3,) if 1D or (n, 3) if 2D."
)

if not isinstance(Cijkl, np.ndarray) or Cijkl.shape != (3, 3, 3, 3):
raise ValueError("Cijkl should be a 4D NumPy array of shape (3, 3, 3, 3).")

# normalize wavevector to lie on unit sphere
wave_vector = _normalize_vector(wavevector)
# get the number of wave vectors
n = wavevectors.shape[0]

# initialize array (pre-allocate)
Mij = np.zeros((n, 3, 3))

for i in range(n):
Mij[i, :, :] = np.dot(wavevectors[i, :], np.dot(wavevectors[i, :], Cijkl))

return np.dot(wave_vector, np.dot(wave_vector, Cijkl))
return Mij


def _calc_eigen(Mij: np.ndarray):
Expand Down
Loading

0 comments on commit fb496ca

Please sign in to comment.