Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new models #427

Merged
merged 4 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions common/src/models/hps_model/hps_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,51 @@
#include "models/hps_model/hps_model_2021_09_20_tiled.h"
#include "models/hps_model/hps_model_2022_01_05_74ops.h"
#include "models/hps_model/hps_model_2022_01_05_89ops.h"
#include "models/hps_model/presence_0_320_240_1_20220117_140437_201k_26k_96ops.h"
#include "models/hps_model/second_0_320_240_1_20220117_135512_201k_26k_96ops.h"
#include "tflite.h"

namespace {

int loaded_model = 0;

// Initialize models
// Initialize model
void do_init_09_20(void) {
puts("Loading HPS 09_20 model");
tflite_load_model(hps_model_2021_09_20_tiled, hps_model_2021_09_20_tiled_len);
loaded_model = 0;
}

// Initialize models
// Initialize model
void do_init_01_05_74ops(void) {
puts("Loading HPS 01_05_74ops model");
tflite_load_model(hps_model_2022_01_05_74ops, hps_model_2022_01_05_74ops_len);
loaded_model = 1;
}

// Initialize models
// Initialize model
void do_init_01_05_89ops(void) {
puts("Loading HPS 01_05_89ops model");
tflite_load_model(hps_model_2022_01_05_89ops, hps_model_2022_01_05_89ops_len);
loaded_model = 2;
}

// Initialize model
void do_init_presence_2022017_96ops(void) {
puts("Loading Presence 20220117 96ops model");
tflite_load_model(presence_0_320_240_1_20220117_140437_201k_26k_96ops,
presence_0_320_240_1_20220117_140437_201k_26k_96ops_len);
loaded_model = 3;
}

// Initialize model
void do_init_second_2022017_96ops(void) {
puts("Loading Second 20220117 96ops model");
tflite_load_model(second_0_320_240_1_20220117_135512_201k_26k_96ops,
second_0_320_240_1_20220117_135512_201k_26k_96ops_len);
loaded_model = 4;
}

// Run classification and interpret results
int32_t classify() {
tflite_classify();
Expand Down Expand Up @@ -84,13 +102,13 @@ void do_classify_zeros() { printf("Result is %ld\n", classify_zeros()); }
struct GoldenTest {
int32_t (*fn)();
const char* name;
int32_t expected[3];
int32_t expected[5];
};

GoldenTest golden_tests[4] = {
{classify_cat, "cat", {-77, -47, -47}},
{classify_diagram, "diagram", {-124, -123, -123}},
{classify_zeros, "zeroes", {-126, -128, -128}},
{classify_cat, "cat", {-77, -47, -47, -101, -117}},
{classify_diagram, "diagram", {-124, -123, -123, -124, -127}},
{classify_zeros, "zeroes", {-126, -128, -128, -128, -128}},
{nullptr, "", 0},
};

Expand Down Expand Up @@ -130,6 +148,10 @@ struct Menu MENU = {
do_init_01_05_74ops),
MENU_ITEM('2', "Reinitialize with 01_05_89ops model",
do_init_01_05_89ops),
MENU_ITEM('3', "Reinitialize with presence_2022017_96ops model",
do_init_presence_2022017_96ops),
MENU_ITEM('4', "Reinitialize with second_2022017_96ops model",
do_init_second_2022017_96ops),
MENU_END,
},
};
Expand Down
Binary file not shown.
Binary file not shown.
14 changes: 5 additions & 9 deletions proj/hps_accel/gateware/gen2/hps_cfu.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,6 @@ def elab(self, m):
m.d.sync += self.filter_output.valid.eq(0)
m.d.sync += self.post_process_params.valid.eq(0)

# post process parameters
pp_bias = Signal(signed(16))
pp_shift = Signal(unsigned(4))

# All sets take exactly one cycle
m.d.sync += self.done.eq(0)

Expand Down Expand Up @@ -182,13 +178,13 @@ def elab(self, m):
with m.Case(Constants.REG_NUM_OUTPUT_VALUES):
m.d.sync += self.config.num_output_values.eq(self.in0)
with m.Case(Constants.REG_POST_PROCESS_BIAS):
m.d.sync += pp_bias.eq(self.in0s)
m.d.sync += self.post_process_params.payload.bias.eq(
self.in0s)
with m.Case(Constants.REG_POST_PROCESS_SHIFT):
m.d.sync += pp_shift.eq(self.in0)
m.d.sync += self.post_process_params.payload.shift.eq(
self.in0)
with m.Case(Constants.REG_POST_PROCESS_MULTIPLIER):
m.d.sync += [
self.post_process_params.payload.bias.eq(pp_bias),
self.post_process_params.payload.shift.eq(pp_shift),
self.post_process_params.payload.multiplier.eq(
self.in0s),
self.post_process_params.valid.eq(1),
Expand Down Expand Up @@ -220,7 +216,7 @@ def max_(word0, word1):
m.d.sync += self.done.eq(0)
with m.If(self.start):
this2 = Signal(32)
m.d.comb += this2.eq(max_(self.in0,self.in1))
m.d.comb += this2.eq(max_(self.in0, self.in1))
m.d.sync += self.output.eq(max_(this2, last2))
m.d.sync += last2.eq(this2)
m.d.sync += self.done.eq(1)
Expand Down
2 changes: 1 addition & 1 deletion proj/hps_accel/gateware/gen2/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def delay(m, signal, cycles):


POST_PROCESS_PARAMS = [
('bias', signed(16)),
('bias', signed(18)),
('multiplier', signed(32)),
('shift', unsigned(4)),
]
Expand Down
Loading