-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmatrix.rkt
150 lines (131 loc) · 5.72 KB
/
matrix.rkt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#lang typed/racket
(provide (except-out (all-defined-out) det-2))
(require "tuples.rkt")
;; A Matrix is an immutable vector of rows, each row being an immutable vector
;; of Floats. The type itself does not force the rows to share a length; the
;; `mat` constructor in this file performs that check.
(define-type Matrix (Immutable-Vectorof (Immutable-Vectorof Float)))
;; mat : validate a 2D immutable vector as an m-by-n Matrix.
;;   m    - expected number of rows
;;   n    - expected length of every row
;;   rows - the row data
;; Returns `rows` itself when it has exactly m rows, each of length n;
;; otherwise raises an error that carries the offending value.
(: mat
   (-> Exact-Nonnegative-Integer
       Exact-Nonnegative-Integer
       (Immutable-Vectorof (Immutable-Vectorof Float))
       Matrix))
(define (mat m n rows)
  (define shape-ok?
    (and (= (vector-length rows) m)
         (for/and: : Boolean ([r (in-vector rows)])
           (= (vector-length r) n))))
  (cond
    [shape-ok? rows]
    [else (error "Illegal operation: input not m by n 2D immutable vector" rows)]))
;; mat-m : number of rows of a matrix, i.e. the length of the outer vector.
(: mat-m (-> Matrix Exact-Nonnegative-Integer))
(define (mat-m mat)
(vector-length mat))
;; mat-n : number of columns of a matrix.
;; The column count is read off the first row. A matrix with no rows has no
;; first row to inspect, so it is reported as having 0 columns instead of
;; letting `vector-ref` fail on index 0.
(: mat-n (-> Matrix Exact-Nonnegative-Integer))
(define (mat-n mat)
  (if (zero? (vector-length mat))
      0
      (vector-length (vector-ref mat 0))))
;; mat-entry : fetch the element at row `m`, column `n`.
;; Raises an error when either index lies outside the matrix.
(: mat-entry (-> Matrix Exact-Nonnegative-Integer Exact-Nonnegative-Integer Float))
(define (mat-entry mat m n)
  (cond
    ;; The row index is tested first; the column count is only consulted
    ;; when the row index is in range.
    [(and (< m (mat-m mat)) (< n (mat-n mat)))
     (vector-ref (vector-ref mat m) n)]
    [else
     (error "Illegal operation: access matrix element out of bounds")]))
;; mat-row : return row `m` of the matrix (element `m` of the outer vector).
;; No explicit bounds check is made here; an out-of-range `m` is left to
;; `vector-ref` to reject.
(: mat-row (-> Matrix Exact-Nonnegative-Integer (Immutable-Vectorof Float)))
(define (mat-row mat m)
(vector-ref mat m))
;; mat-col : return column `n` of the matrix as a new immutable vector.
;; Collects element `n` of every row, in row order.
(: mat-col (-> Matrix Exact-Nonnegative-Integer (Immutable-Vectorof Float)))
(define (mat-col mat n)
(vector->immutable-vector
(for/vector: : (Mutable-Vectorof Float)
;; one slot per row of the matrix
#:length (mat-m mat)
([row (in-vector mat)])
;; no explicit bounds check: an out-of-range `n` is left to `vector-ref`
(vector-ref row n))))
;; mat= : approximate equality of two matrices.
;; Two matrices are equal when they have the same number of rows, every pair
;; of corresponding rows has the same length, and every pair of corresponding
;; entries satisfies `f=`.
;; The length checks are required: a parallel `for/and:` stops as soon as the
;; shorter sequence runs out, so without them matrices of different shapes
;; would compare equal whenever their overlapping part matched.
(: mat= (-> Matrix Matrix Boolean))
(define (mat= m1 m2)
  (and (= (vector-length m1) (vector-length m2))
       (for/and: : Boolean ([row1 (in-vector m1)] [row2 (in-vector m2)])
         (and (= (vector-length row1) (vector-length row2))
              (for/and: : Boolean ([col1 (in-vector row1)] [col2 (in-vector row2)])
                (f= col1 col2))))))
;; build-matrix : create an m-by-n matrix whose entry at (row, col) is
;; (f row col).
;;   m - number of rows
;;   n - number of columns
;;   f - function from a row index and a column index to the entry value
(: build-matrix
(-> Exact-Nonnegative-Integer
Exact-Nonnegative-Integer
(-> Exact-Nonnegative-Integer Exact-Nonnegative-Integer Float)
Matrix))
(define (build-matrix m n f)
;; The outer vector->immutable-vector is instantiated at the row type, and
;; its result is then `cast` to Matrix.
(cast ((inst vector->immutable-vector (Immutable-Vectorof Float))
(build-vector
m
(lambda ([row : Exact-Nonnegative-Integer])
;; each row is built with build-vector, then converted to an immutable vector
(vector->immutable-vector
(build-vector n (lambda ([col : Exact-Nonnegative-Integer]) (f row col))))))) Matrix))
;; mat* : matrix product mat1 x mat2.
;; The column count of mat1 must equal the row count of mat2; otherwise an
;; error carrying both matrices is raised. Entry (row, col) of the result is
;; the dot product of row `row` of mat1 with column `col` of mat2.
(: mat* (-> Matrix Matrix Matrix))
(define (mat* mat1 mat2)
  ;; Sum of pairwise products of two vectors.
  (: pairwise-sum (-> (Immutable-Vectorof Float) (Immutable-Vectorof Float) Float))
  (define (pairwise-sum v1 v2)
    (for/fold ([acc 0.]) ([a (in-vector v1)] [b (in-vector v2)])
      (+ acc (* a b))))
  ;; All four dimensions are read up front, before the compatibility check.
  (define rows1 (mat-m mat1))
  (define cols1 (mat-n mat1))
  (define rows2 (mat-m mat2))
  (define cols2 (mat-n mat2))
  (unless (= cols1 rows2)
    (error "Illegal operation: multiply matrices with incompatible sizes" mat1 mat2))
  (build-matrix rows1
                cols2
                (lambda ([row : Exact-Nonnegative-Integer] [col : Exact-Nonnegative-Integer])
                  (pairwise-sum (mat-row mat1 row) (mat-col mat2 col)))))
;; mat-t* : multiply a matrix by a tuple.
;; Each of rows 0-3 of `m` is turned into a tuple from its first four entries
;; and combined with `t` by summing the x, y, z and w component products; the
;; four sums are then handed to `adaptive-tuple`.
;; Only rows 0-3 and entries 0-3 of each row are read, and nothing here checks
;; that they exist. Presumably callers pass a 4x4 matrix; TODO confirm.
;; NOTE(review): `tuple`, the `tuple-x/y/z/w` accessors and `adaptive-tuple`
;; are not defined in this file (presumably they come from "tuples.rkt");
;; their exact behavior is not visible here.
(: mat-t* (-> Matrix Tuple Tuple))
(define (mat-t* m t)
;; component-wise dot product of two tuples, written as a macro
(define-syntax-rule (dot* t1 t2)
(+ (* (tuple-x t1) (tuple-x t2))
(* (tuple-y t1) (tuple-y t2))
(* (tuple-z t1) (tuple-z t2))
(* (tuple-w t1) (tuple-w t2))))
;; builds a tuple from entries 0-3 of a row
(: row->tuple (-> (Immutable-Vectorof Float) Tuple))
(define (row->tuple row)
(tuple (vector-ref row 0) (vector-ref row 1) (vector-ref row 2) (vector-ref row 3)))
(let ([x (dot* (row->tuple (mat-row m 0)) t)]
[y (dot* (row->tuple (mat-row m 1)) t)]
[z (dot* (row->tuple (mat-row m 2)) t)]
[w (dot* (row->tuple (mat-row m 3)) t)])
(adaptive-tuple x y z w)))
;; id-mat : the n-by-n identity matrix (1. on the diagonal, 0. elsewhere).
(: id-mat (-> Exact-Nonnegative-Integer Matrix))
(define (id-mat n)
  ;; Entry generator: 1. where the row and column indices coincide, else 0.
  (: diagonal-entry (-> Exact-Nonnegative-Integer Exact-Nonnegative-Integer Float))
  (define (diagonal-entry i j)
    (cond
      [(= i j) 1.]
      [else 0.]))
  (build-matrix n n diagonal-entry))
;; id-mat-4 : the 4x4 identity matrix, written out literally and passed
;; through `mat` with dimensions 4 and 4.
(: id-mat-4 Matrix)
(define id-mat-4
(mat 4 4 #[#[1. 0. 0. 0.] #[0. 1. 0. 0.] #[0. 0. 1. 0.] #[0. 0. 0. 1.]]))
;; transpose : swap rows and columns.
;; Row y of the result is column y of `mat`, so the result has (mat-n mat)
;; rows. The assembled vector is made immutable and `cast` to Matrix.
(: transpose (-> Matrix Matrix))
(define (transpose mat)
(cast ((inst vector->immutable-vector (Immutable-Vectorof Float))
(build-vector (mat-n mat) (lambda ([y : Exact-Nonnegative-Integer]) (mat-col mat y)))) Matrix))
;; submat : copy of `mat` with row `row` and column `col` removed.
;; No explicit bounds check is made here; out-of-range indices are left to
;; `vector-take` / `vector-drop` to reject.
(: submat (-> Matrix Exact-Nonnegative-Integer Exact-Nonnegative-Integer Matrix))
(define (submat mat row col)
;; remove the unwanted row: everything before it plus everything after it
(let ([rows (vector-append (vector-take mat row) (vector-drop mat (add1 row)))])
(cast
(vector->immutable-vector
(for/vector ([y (in-vector rows)])
;; remove the unwanted column from each remaining row the same way
(vector->immutable-vector (vector-append (vector-take y col) (vector-drop y (add1 col))))))
Matrix)))
;; det-2 : determinant of a 2x2 matrix, a*d - b*c for the layout
;;   [a b]
;;   [c d]
;; Entries are read through `mat-entry`, which raises on out-of-range indices.
(: det-2 (-> Matrix Float))
(define (det-2 mat)
  (let ([a (mat-entry mat 0 0)]
        [d (mat-entry mat 1 1)]
        [b (mat-entry mat 0 1)]
        [c (mat-entry mat 1 0)])
    (- (* a d) (* b c))))
;; det : determinant of a square matrix, by cofactor expansion along row 0.
;;   - a matrix with no rows is treated as the empty (0x0) matrix, whose
;;     determinant is 1 (the empty product); this is also what makes the
;;     cofactor of a 1x1 matrix come out as 1
;;   - a 1x1 matrix has its single entry as determinant
;;   - a 2x2 matrix is delegated to `det-2`
;;   - larger matrices sum entry * signed minor over the first row
;; Raises an error for non-square input, for which the determinant is
;; undefined.
(: det (-> Matrix Float))
(define (det mat)
  (define size (mat-m mat))
  (cond
    ;; Checked first, before any column count is requested, because a matrix
    ;; with no rows has no first row to measure.
    [(= size 0) 1.]
    [(not (= size (mat-n mat)))
     (error "Illegal operation: determinant of non-square matrix" mat)]
    [(= size 1) (mat-entry mat 0 0)]
    [(= size 2) (det-2 mat)]
    [else
     ;; `col` tracks the column index of `elem`; signs alternate + - + - ...
     (for/fold ([sum : Float 0.] [col : Exact-Nonnegative-Integer 0] #:result sum)
               ([elem (in-vector (mat-row mat 0))])
       (values
        (+ sum (* elem ((if (even? col) identity -) (det (submat mat 0 col)))))
        (add1 col)))]))
;; cofactor : signed minor of `mat` at (row, col).
;; The minor is the determinant of the matrix with that row and column
;; removed; it is negated when row + col is odd.
;; Raises an error when either index lies outside the matrix.
(: cofactor (-> Matrix Exact-Nonnegative-Integer Exact-Nonnegative-Integer Float))
(define (cofactor mat row col)
  (cond
    [(and (< row (mat-m mat)) (< col (mat-n mat)))
     (let ([minor (det (submat mat row col))])
       (if (odd? (+ row col))
           (- minor)
           minor))]
    [else
     (error "Illegal operation: calculate cofactor out of bounds" mat row col)]))
;; inverse : inverse of a square matrix with non-zero determinant.
;; Builds the matrix of cofactors divided by the determinant and transposes
;; it, so entry (col, row) of the result is cofactor(row, col) / det.
;; Raises an error when the matrix is not square or its determinant is 0.
(: inverse (-> Matrix Matrix))
(define (inverse mat)
  (let ([m (mat-m mat)] [n (mat-n mat)])
    ;; Squareness is checked before the determinant is computed, so that a
    ;; non-square input is rejected with this function's own message rather
    ;; than failing somewhere inside `det`.
    (if (not (= m n))
        (error "Illegal operation: matrix cannot be inverted" mat)
        (let ([determinant (det mat)])
          (if (= 0. determinant)
              (error "Illegal operation: matrix cannot be inverted" mat)
              (transpose
               (build-matrix
                n
                n
                (lambda ([row : Exact-Nonnegative-Integer] [col : Exact-Nonnegative-Integer])
                  (/ (cofactor mat row col) determinant)))))))))