forked from analogdevicesinc/msdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_riscv.c
149 lines (122 loc) · 4.69 KB
/
main_riscv.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
/******************************************************************************
*
* Copyright (C) 2022-2023 Maxim Integrated Products, Inc. (now owned by
* Analog Devices, Inc.),
* Copyright (C) 2023-2024 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
******************************************************************************/
// mnist-riscv
// Created using ai8xize.py --test-dir sdk/Examples/MAX78000/CNN --prefix mnist-riscv --checkpoint-file trained/ai85-mnist-qat8-q.pth.tar --config-file networks/mnist-chw-ai85.yaml --softmax --device MAX78000 --timer 0 --display-checkpoint --verbose --riscv --riscv-debug
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <stdio.h>
#include "mxc.h"
#include "fcr_regs.h"
#include "sema_regs.h"
#include "cnn.h"
#include "sampledata.h"
#include "sampleoutput.h"
// Stopwatch. main() sleeps while this is 0 and, when CNN_INFERENCE_TIMER is
// defined, prints it as the approximate inference time in microseconds.
// NOTE(review): nothing in this file writes it -- presumably the CNN
// completion handler in the generated CNN code does; confirm.
volatile uint32_t cnn_time;
/*
 * Report a failed test on the console and halt.
 * Never returns: after printing, execution stays in an empty loop.
 */
void fail(void)
{
    printf("\n*** FAIL ***\n\n");

    for (;;) {
        // Park here forever.
    }
}
// 1-channel 28x28 data input (784 bytes / 196 32-bit words):
// CHW 28x28, channel 0
__attribute__((section(".rvflash_section"))) static const uint32_t input_0[] = SAMPLE_INPUT_0;

/*
 * Copy the built-in sample input (input_0) to address 0x50400000.
 * Stand-in for a real data source -- replace with actual data.
 */
void load_input(void)
{
    uint32_t *const input_dst = (uint32_t *)0x50400000;

    // 196 32-bit words == 784 bytes: one 28x28 single-channel frame.
    memcpy32(input_dst, input_0, 196);
}
// Expected output of layer 4 for mnist-riscv given the sample input (known-answer test)
// Delete this function for production code
static const uint32_t sample_output[] = SAMPLE_OUTPUT;

/*
 * Known-answer test: compare memory contents against sample_output.
 *
 * sample_output is walked as a sequence of records, each laid out as
 *   address, mask, word count, <word count> expected words
 * and the walk stops at the first record whose address is 0.
 *
 * Every word read from the target address is ANDed with the record's mask
 * before being compared with the expected word.
 *
 * Returns CNN_OK when all words match; on the first mismatch prints a
 * diagnostic and returns CNN_FAIL.
 */
int check_output(void)
{
    const uint32_t *ptr = sample_output;
    volatile uint32_t *addr;

    // Go through uintptr_t so the integer-to-pointer conversion is explicit.
    while ((addr = (volatile uint32_t *)(uintptr_t)*ptr++) != 0) {
        const uint32_t mask = *ptr++;
        const uint32_t len = *ptr++;

        for (uint32_t i = 0; i < len; i++) {
            // Read the location exactly once so the value reported below is
            // the same value that failed the comparison.
            const uint32_t actual = *addr++ & mask;
            const uint32_t expected = *ptr++;

            if (actual != expected) {
                // Arguments are cast to the exact types the conversion
                // specifiers expect; uint32_t is not guaranteed to be
                // unsigned int, and a pointer must never be passed to %x.
                printf("Data mismatch (%u/%u) at address 0x%08x: Expected 0x%08x, read 0x%08x.\n",
                       (unsigned int)(i + 1), (unsigned int)len,
                       (unsigned int)(uintptr_t)(addr - 1), (unsigned int)expected,
                       (unsigned int)actual);
                return CNN_FAIL;
            }
        }
    }

    return CNN_OK;
}
// Classification layer:
static int32_t ml_data[CNN_NUM_OUTPUTS];
static q15_t ml_softmax[CNN_NUM_OUTPUTS];

/*
 * Fill ml_data via cnn_unload(), then pass it to softmax_q17p14_q15(), which
 * writes CNN_NUM_OUTPUTS results into ml_softmax.
 */
void softmax_layer(void)
{
    // Same buffer, viewed with the pointer type each callee expects.
    uint32_t *const unload_dst = (uint32_t *)ml_data;
    const q31_t *const softmax_src = (const q31_t *)ml_data;

    cnn_unload(unload_dst);
    softmax_q17p14_q15(softmax_src, CNN_NUM_OUTPUTS, ml_softmax);
}
/*
 * Run one CNN inference on the sample input and report the result.
 *
 * Sequence: set up debug and cache, enable and configure the CNN, load the
 * sample input, start the CNN and sleep until cnn_time becomes non-zero,
 * verify the output against the known answer (halting in fail() on a
 * mismatch), run the softmax, print the per-class results, and finally raise
 * the semaphore interrupt towards the Cortex-M4.
 *
 * Returns 0.
 */
int main(void)
{
    Debug_Init(); // Set up RISCV JTAG

    MXC_ICC_Enable(MXC_ICC1); // Enable cache

    // Enable peripheral, enable CNN interrupt, turn on CNN clock
    // CNN clock: APB (50 MHz) div 1
    cnn_enable(MXC_S_GCR_PCLKDIV_CNNCLKSEL_PCLK, MXC_S_GCR_PCLKDIV_CNNCLKDIV_DIV1);

    printf("\n*** CNN Inference Test ***\n");

    cnn_init(); // Bring state machine into consistent state
    cnn_load_weights(); // Load kernels
    cnn_load_bias();
    cnn_configure(); // Configure state machine
    load_input(); // Load data input
    cnn_start(); // Start CNN processing

    // Wait for CNN: sleep between interrupts until cnn_time is non-zero.
    // __asm__ is used instead of asm so this also builds in strict ISO mode.
    while (cnn_time == 0) {
        __asm__ volatile("wfi");
    }

    if (check_output() != CNN_OK) {
        fail();
    }

    softmax_layer();

    printf("\n*** PASS ***\n\n");

#ifdef CNN_INFERENCE_TIMER
    // Cast so the argument type matches %u regardless of how uint32_t is defined.
    printf("Approximate inference time: %u us\n\n", (unsigned int)cnn_time);
#endif

    cnn_disable(); // Shut down CNN clock, disable peripheral

    printf("Classification results:\n");
    for (int i = 0; i < CNN_NUM_OUTPUTS; i++) {
        // Scale the softmax value by 1000/2^15 with rounding (0x4000 is half
        // of 2^15), giving tenths of a percent.
        const int tenths = (1000 * ml_softmax[i] + 0x4000) >> 15;
        const int whole = tenths / 10;
        const int frac = tenths % 10;

        // int32_t is cast to int so the argument matches %7d.
        printf("[%7d] -> Class %d: %d.%d%%\n", (int)ml_data[i], i, whole, frac);
    }

    // Signal the Cortex-M4
    MXC_SEMA->irq0 = MXC_F_SEMA_IRQ0_EN | MXC_F_SEMA_IRQ0_CM4_IRQ;

    return 0;
}
/*
SUMMARY OF OPS
Hardware: 10,883,968 ops (10,751,808 macc; 128,576 comp; 3,584 add; 0 mul; 0 bitwise)
Layer 0: 470,400 ops (423,360 macc; 47,040 comp; 0 add; 0 mul; 0 bitwise)
Layer 1: 8,356,800 ops (8,294,400 macc; 62,400 comp; 0 add; 0 mul; 0 bitwise)
Layer 2: 1,954,304 ops (1,935,360 macc; 18,944 comp; 0 add; 0 mul; 0 bitwise)
Layer 3: 100,544 ops (96,768 macc; 192 comp; 3,584 add; 0 mul; 0 bitwise)
Layer 4: 1,920 ops (1,920 macc; 0 comp; 0 add; 0 mul; 0 bitwise)
RESOURCE USAGE
Weight memory: 71,148 bytes out of 442,368 bytes total (16%)
Bias memory: 10 bytes out of 2,048 bytes total (0%)
*/