Skip to content

Commit 31e82d6

Browse files
author
twitter-team
committedApr 28, 2023
improvements from external prs
-fix corner case where dr converter failed when initializing Closes twitter#550
1 parent 23fa75d commit 31e82d6

31 files changed

+304
-243
lines changed
 

‎navi/dr_transform/src/all_config.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@ pub struct RenamedFeatures {
4444
}
4545

4646
pub fn parse(json_str: &str) -> Result<AllConfig, Error> {
47-
serde_json::from_str(json_str)
47+
let all_config: AllConfig = serde_json::from_str(json_str)?;
48+
Ok(all_config)
4849
}

‎navi/dr_transform/src/converter.rs

+146-116
Large diffs are not rendered by default.

‎navi/navi/proto/tensorflow/core/framework/full_type.proto

+5-5
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ enum FullTypeId {
122122
// TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
123123
// is a Tensor of int32 element type and unknown shape.
124124
//
125-
// TODO: Define TFT_SHAPE and add more examples.
125+
// TODO(mdan): Define TFT_SHAPE and add more examples.
126126
TFT_TENSOR = 1000;
127127

128128
// Array (or tensorflow::TensorList in the variant type registry).
@@ -178,7 +178,7 @@ enum FullTypeId {
178178
// object (for now).
179179

180180
// The bool element type.
181-
// TODO
181+
// TODO(mdan): Quantized types, legacy representations (e.g. ref)
182182
TFT_BOOL = 200;
183183
// Integer element types.
184184
TFT_UINT8 = 201;
@@ -195,7 +195,7 @@ enum FullTypeId {
195195
TFT_DOUBLE = 211;
196196
TFT_BFLOAT16 = 215;
197197
// Complex element types.
198-
// TODO: Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
198+
// TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
199199
TFT_COMPLEX64 = 212;
200200
TFT_COMPLEX128 = 213;
201201
// The string element type.
@@ -240,7 +240,7 @@ enum FullTypeId {
240240
// ownership is in the true sense: "the op argument representing the lock is
241241
// available".
242242
// Mutex locks are the dynamic counterpart of control dependencies.
243-
// TODO: Properly document this thing.
243+
// TODO(mdan): Properly document this thing.
244244
//
245245
// Parametrization: TFT_MUTEX_LOCK[].
246246
TFT_MUTEX_LOCK = 10202;
@@ -271,6 +271,6 @@ message FullTypeDef {
271271
oneof attr {
272272
string s = 3;
273273
int64 i = 4;
274-
// TODO: list/tensor, map? Need to reconcile with TFT_RECORD, etc.
274+
// TODO(mdan): list/tensor, map? Need to reconcile with TFT_RECORD, etc.
275275
}
276276
}

‎navi/navi/proto/tensorflow/core/framework/function.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ message FunctionDefLibrary {
2323
// with a value. When a GraphDef has a call to a function, it must
2424
// have binding for every attr defined in the signature.
2525
//
26-
// TODO:
26+
// TODO(zhifengc):
2727
// * device spec, etc.
2828
message FunctionDef {
2929
// The definition of the function's name, arguments, return values,

‎navi/navi/proto/tensorflow/core/framework/node_def.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ message NodeDef {
6161
// one of the names from the corresponding OpDef's attr field).
6262
// The values must have a type matching the corresponding OpDef
6363
// attr's type field.
64-
// TODO: Add some examples here showing best practices.
64+
// TODO(josh11b): Add some examples here showing best practices.
6565
map<string, AttrValue> attr = 5;
6666

6767
message ExperimentalDebugInfo {

‎navi/navi/proto/tensorflow/core/framework/op_def.proto

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ message OpDef {
9696
// Human-readable description.
9797
string description = 4;
9898

99-
// TODO: bool is_optional?
99+
// TODO(josh11b): bool is_optional?
100100

101101
// --- Constraints ---
102102
// These constraints are only in effect if specified. Default is no
@@ -139,7 +139,7 @@ message OpDef {
139139
// taking input from multiple devices with a tree of aggregate ops
140140
// that aggregate locally within each device (and possibly within
141141
// groups of nearby devices) before communicating.
142-
// TODO: Implement that optimization.
142+
// TODO(josh11b): Implement that optimization.
143143
bool is_aggregate = 16; // for things like add
144144

145145
// Other optimizations go here, like

‎navi/navi/proto/tensorflow/core/framework/step_stats.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ message MemoryStats {
5353

5454
// Time/size stats recorded for a single execution of a graph node.
5555
message NodeExecStats {
56-
// TODO: Use some more compact form of node identity than
56+
// TODO(tucker): Use some more compact form of node identity than
5757
// the full string name. Either all processes should agree on a
5858
// global id (cost_id?) for each node, or we should use a hash of
5959
// the name.

‎navi/navi/proto/tensorflow/core/framework/tensor.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framewo
1616
message TensorProto {
1717
DataType dtype = 1;
1818

19-
// Shape of the tensor. TODO: sort out the 0-rank issues.
19+
// Shape of the tensor. TODO(touts): sort out the 0-rank issues.
2020
TensorShapeProto tensor_shape = 2;
2121

2222
// Only one of the representations below is set, one of "tensor_contents" and

‎navi/navi/proto/tensorflow/core/protobuf/config.proto

+4-4
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ message ConfigProto {
532532

533533
// We removed the flag client_handles_error_formatting. Marking the tag
534534
// number as reserved.
535-
// TODO: Should we just remove this tag so that it can be
535+
// TODO(shikharagarwal): Should we just remove this tag so that it can be
536536
// used in future for other purpose?
537537
reserved 2;
538538

@@ -576,7 +576,7 @@ message ConfigProto {
576576
// - If isolate_session_state is true, session states are isolated.
577577
// - If isolate_session_state is false, session states are shared.
578578
//
579-
// TODO: Add a single API that consistently treats
579+
// TODO(b/129330037): Add a single API that consistently treats
580580
// isolate_session_state and ClusterSpec propagation.
581581
bool share_session_state_in_clusterspec_propagation = 8;
582582

@@ -704,7 +704,7 @@ message ConfigProto {
704704

705705
// Options for a single Run() call.
706706
message RunOptions {
707-
// TODO Turn this into a TraceOptions proto which allows
707+
// TODO(pbar) Turn this into a TraceOptions proto which allows
708708
// tracing to be controlled in a more orthogonal manner?
709709
enum TraceLevel {
710710
NO_TRACE = 0;
@@ -781,7 +781,7 @@ message RunMetadata {
781781
repeated GraphDef partition_graphs = 3;
782782

783783
message FunctionGraphs {
784-
// TODO: Include some sort of function/cache-key identifier?
784+
// TODO(nareshmodi): Include some sort of function/cache-key identifier?
785785
repeated GraphDef partition_graphs = 1;
786786

787787
GraphDef pre_optimization_graph = 2;

‎navi/navi/proto/tensorflow/core/protobuf/coordination_service.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ service CoordinationService {
194194

195195
// Report error to the task. RPC sets the receiving instance of coordination
196196
// service agent to error state permanently.
197-
// TODO: Consider splitting this into a different RPC service.
197+
// TODO(b/195990880): Consider splitting this into a different RPC service.
198198
rpc ReportErrorToAgent(ReportErrorToAgentRequest)
199199
returns (ReportErrorToAgentResponse);
200200

‎navi/navi/proto/tensorflow/core/protobuf/debug.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ message DebugTensorWatch {
4646
// are to be debugged, the callers of Session::Run() must use distinct
4747
// debug_urls to make sure that the streamed or dumped events do not overlap
4848
// among the invocations.
49-
// TODO: More visible documentation of this in g3docs.
49+
// TODO(cais): More visible documentation of this in g3docs.
5050
repeated string debug_urls = 4;
5151

5252
// Do not error out if debug op creation fails (e.g., due to dtype

‎navi/navi/proto/tensorflow/core/protobuf/debug_event.proto

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ option java_package = "org.tensorflow.util";
1212
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
1313

1414
// Available modes for extracting debugging information from a Tensor.
15-
// TODO: Document the detailed column names and semantics in a separate
15+
// TODO(cais): Document the detailed column names and semantics in a separate
1616
// markdown file once the implementation settles.
1717
enum TensorDebugMode {
1818
UNSPECIFIED = 0;
@@ -223,7 +223,7 @@ message DebuggedDevice {
223223
// A debugger-generated ID for the device. Guaranteed to be unique within
224224
// the scope of the debugged TensorFlow program, including single-host and
225225
// multi-host settings.
226-
// TODO: Test the uniqueness guarantee in multi-host settings.
226+
// TODO(cais): Test the uniqueness guarantee in multi-host settings.
227227
int32 device_id = 2;
228228
}
229229

@@ -264,7 +264,7 @@ message Execution {
264264
// field with the DebuggedDevice messages.
265265
repeated int32 output_tensor_device_ids = 9;
266266

267-
// TODO support, add more fields
267+
// TODO(cais): When backporting to V1 Session.run() support, add more fields
268268
// such as fetches and feeds.
269269
}
270270

‎navi/navi/proto/tensorflow/core/protobuf/distributed_runtime_payloads.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
77

88
// Used to serialize and transmit tensorflow::Status payloads through
99
// grpc::Status `error_details` since grpc::Status lacks payload API.
10-
// TODO: Use GRPC API once supported.
10+
// TODO(b/204231601): Use GRPC API once supported.
1111
message GrpcPayloadContainer {
1212
map<string, bytes> payloads = 1;
1313
}

‎navi/navi/proto/tensorflow/core/protobuf/eager_service.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ message WaitQueueDoneRequest {
172172
}
173173

174174
message WaitQueueDoneResponse {
175-
// TODO: Consider adding NodeExecStats here to be able to
175+
// TODO(nareshmodi): Consider adding NodeExecStats here to be able to
176176
// propagate some stats.
177177
}
178178

‎navi/navi/proto/tensorflow/core/protobuf/master.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ message ExtendSessionRequest {
9494
}
9595

9696
message ExtendSessionResponse {
97-
// TODO: Return something about the operation?
97+
// TODO(mrry): Return something about the operation?
9898

9999
// The new version number for the extended graph, to be used in the next call
100100
// to ExtendSession.

‎navi/navi/proto/tensorflow/core/protobuf/saved_object_graph.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ message SavedBareConcreteFunction {
176176
// allows the ConcreteFunction to be called with nest structure inputs. This
177177
// field may not be populated. If this field is absent, the concrete function
178178
// can only be called with flat inputs.
179-
// TODO: support calling saved ConcreteFunction with structured
179+
// TODO(b/169361281): support calling saved ConcreteFunction with structured
180180
// inputs in C++ SavedModel API.
181181
FunctionSpec function_spec = 4;
182182
}

‎navi/navi/proto/tensorflow/core/protobuf/tensor_bundle.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
1717

1818
// Special header that is associated with a bundle.
1919
//
20-
// TODO: maybe in the future, we can add information about
20+
// TODO(zongheng,zhifengc): maybe in the future, we can add information about
2121
// which binary produced this checkpoint, timestamp, etc. Sometime, these can be
2222
// valuable debugging information. And if needed, these can be used as defensive
2323
// information ensuring reader (binary version) of the checkpoint and the writer

‎navi/navi/proto/tensorflow/core/protobuf/worker.proto

+2-2
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ message DeregisterGraphRequest {
188188
}
189189

190190
message DeregisterGraphResponse {
191-
// TODO: Optionally add summary stats for the graph.
191+
// TODO(mrry): Optionally add summary stats for the graph.
192192
}
193193

194194
////////////////////////////////////////////////////////////////////////////////
@@ -294,7 +294,7 @@ message RunGraphResponse {
294294

295295
// If the request asked for execution stats, the cost graph, or the partition
296296
// graphs, these are returned here.
297-
// TODO: Package these in a RunMetadata instead.
297+
// TODO(suharshs): Package these in a RunMetadata instead.
298298
StepStats step_stats = 2;
299299
CostGraphDef cost_graph = 3;
300300
repeated GraphDef partition_graph = 4;

‎navi/navi/proto/tensorflow_serving/apis/logging.proto

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@ message LogMetadata {
1313
SamplingConfig sampling_config = 2;
1414
// List of tags used to load the relevant MetaGraphDef from SavedModel.
1515
repeated string saved_model_tags = 3;
16-
// TODO: Add more metadata as mentioned in the bug.
16+
// TODO(b/33279154): Add more metadata as mentioned in the bug.
1717
}

‎navi/navi/proto/tensorflow_serving/config/file_system_storage_path_source.proto

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ message FileSystemStoragePathSourceConfig {
5858

5959
// A single servable name/base_path pair to monitor.
6060
// DEPRECATED: Use 'servables' instead.
61-
// TODO: Stop using these fields, and ultimately remove them here.
61+
// TODO(b/30898016): Stop using these fields, and ultimately remove them here.
6262
string servable_name = 1 [deprecated = true];
6363
string base_path = 2 [deprecated = true];
6464

@@ -76,7 +76,7 @@ message FileSystemStoragePathSourceConfig {
7676
// check for a version to appear later.)
7777
// DEPRECATED: Use 'servable_versions_always_present' instead, which includes
7878
// this behavior.
79-
// TODO: Remove 2019-10-31 or later.
79+
// TODO(b/30898016): Remove 2019-10-31 or later.
8080
bool fail_if_zero_versions_at_startup = 4 [deprecated = true];
8181

8282
// If true, the servable is always expected to exist on the underlying

‎navi/navi/proto/tensorflow_serving/config/model_server_config.proto

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import "tensorflow_serving/config/logging_config.proto";
99
option cc_enable_arenas = true;
1010

1111
// The type of model.
12-
// TODO: DEPRECATED.
12+
// TODO(b/31336131): DEPRECATED.
1313
enum ModelType {
1414
MODEL_TYPE_UNSPECIFIED = 0 [deprecated = true];
1515
TENSORFLOW = 1 [deprecated = true];
@@ -31,7 +31,7 @@ message ModelConfig {
3131
string base_path = 2;
3232

3333
// Type of model.
34-
// TODO: DEPRECATED. Please use 'model_platform' instead.
34+
// TODO(b/31336131): DEPRECATED. Please use 'model_platform' instead.
3535
ModelType model_type = 3 [deprecated = true];
3636

3737
// Type of model (e.g. "tensorflow").

‎navi/navi/src/bootstrap.rs

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use anyhow::Result;
22
use log::{info, warn};
3+
use x509_parser::{prelude::{parse_x509_pem}, parse_x509_certificate};
34
use std::collections::HashMap;
45
use tokio::time::Instant;
56
use tonic::{
@@ -27,6 +28,7 @@ use crate::cli_args::{ARGS, INPUTS, OUTPUTS};
2728
use crate::metrics::{
2829
NAVI_VERSION, NUM_PREDICTIONS, NUM_REQUESTS_FAILED, NUM_REQUESTS_FAILED_BY_MODEL,
2930
NUM_REQUESTS_RECEIVED, NUM_REQUESTS_RECEIVED_BY_MODEL, RESPONSE_TIME_COLLECTOR,
31+
CERT_EXPIRY_EPOCH
3032
};
3133
use crate::predict_service::{Model, PredictService};
3234
use crate::tf_proto::tensorflow_serving::model_spec::VersionChoice::Version;
@@ -233,6 +235,12 @@ impl<T: Model> PredictionService for PredictService<T> {
233235
}
234236
}
235237

238+
// A function that takes a timestamp as input and returns a ticker stream
239+
fn report_expiry(expiry_time: i64) {
240+
info!("Certificate expires at epoch: {:?}", expiry_time);
241+
CERT_EXPIRY_EPOCH.set(expiry_time as i64);
242+
}
243+
236244
pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
237245
info!("package: {}, version: {}, args: {:?}", NAME, VERSION, *ARGS);
238246
//we follow SemVer. So here we assume MAJOR.MINOR.PATCH
@@ -249,6 +257,7 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
249257
);
250258
}
251259

260+
252261
tokio::runtime::Builder::new_multi_thread()
253262
.thread_name("async worker")
254263
.worker_threads(ARGS.num_worker_threads)
@@ -266,6 +275,21 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
266275
let mut builder = if ARGS.ssl_dir.is_empty() {
267276
Server::builder()
268277
} else {
278+
// Read the pem file as a string
279+
let pem_str = std::fs::read_to_string(format!("{}/server.crt", ARGS.ssl_dir)).unwrap();
280+
let res = parse_x509_pem(&pem_str.as_bytes());
281+
match res {
282+
Ok((rem, pem_2)) => {
283+
assert!(rem.is_empty());
284+
assert_eq!(pem_2.label, String::from("CERTIFICATE"));
285+
let res_x509 = parse_x509_certificate(&pem_2.contents);
286+
info!("Certificate label: {}", pem_2.label);
287+
assert!(res_x509.is_ok());
288+
report_expiry(res_x509.unwrap().1.validity().not_after.timestamp());
289+
},
290+
_ => panic!("PEM parsing failed: {:?}", res),
291+
}
292+
269293
let key = tokio::fs::read(format!("{}/server.key", ARGS.ssl_dir))
270294
.await
271295
.expect("can't find key file");
@@ -281,7 +305,7 @@ pub fn bootstrap<T: Model>(model_factory: ModelFactory<T>) -> Result<()> {
281305
let identity = Identity::from_pem(pem.clone(), key);
282306
let client_ca_cert = Certificate::from_pem(pem.clone());
283307
let tls = ServerTlsConfig::new()
284-
.identity(identity)
308+
.identity(identity)
285309
.client_ca_root(client_ca_cert);
286310
Server::builder()
287311
.tls_config(tls)

‎navi/navi/src/metrics.rs

+7
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ lazy_static! {
171171
&["model_name"]
172172
)
173173
.expect("metric can be created");
174+
pub static ref CERT_EXPIRY_EPOCH: IntGauge =
175+
IntGauge::new(":navi:cert_expiry_epoch", "Timestamp when the current cert expires")
176+
.expect("metric can be created");
174177
}
175178

176179
pub fn register_custom_metrics() {
@@ -249,6 +252,10 @@ pub fn register_custom_metrics() {
249252
REGISTRY
250253
.register(Box::new(CONVERTER_TIME_COLLECTOR.clone()))
251254
.expect("collector can be registered");
255+
REGISTRY
256+
.register(Box::new(CERT_EXPIRY_EPOCH.clone()))
257+
.expect("collector can be registered");
258+
252259
}
253260

254261
pub fn register_dynamic_metrics(c: &HistogramVec) {

‎navi/navi/src/onnx_model.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ pub mod onnx {
189189
&version,
190190
reporting_feature_ids,
191191
Some(metrics::register_dynamic_metrics),
192-
)),
192+
)?),
193193
};
194194
onnx_model.warmup()?;
195195
Ok(onnx_model)

‎navi/navi/src/predict_service.rs

+17-21
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use serde_json::{self, Value};
2424

2525
pub trait Model: Send + Sync + Display + Debug + 'static {
2626
fn warmup(&self) -> Result<()>;
27-
//TODO: refactor this to return Vec<Vec<TensorScores>>, i.e.
27+
//TODO: refactor this to return vec<vec<TensorScores>>, i.e.
2828
//we have the underlying runtime impl to split the response to each client.
2929
//It will eliminate some inefficient memory copy in onnx_model.rs as well as simplify code
3030
fn do_predict(
@@ -222,8 +222,8 @@ impl<T: Model> PredictService<T> {
222222
.map(|b| b.parse().unwrap())
223223
.collect::<Vec<u64>>();
224224
let no_msg_wait_millis = *batch_time_out_millis.iter().min().unwrap();
225-
let mut all_model_predictors =
226-
ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS>::new();
225+
let mut all_model_predictors: ArrayVec::<ArrayVec<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>, MAX_NUM_MODELS> =
226+
(0 ..MAX_NUM_MODELS).map( |_| ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new()).collect();
227227
loop {
228228
let msg = rx.try_recv();
229229
let no_more_msg = match msg {
@@ -272,27 +272,23 @@ impl<T: Model> PredictService<T> {
272272
queue_reset_ts: Instant::now(),
273273
queue_earliest_rq_ts: Instant::now(),
274274
};
275-
if idx < all_model_predictors.len() {
276-
metrics::NEW_MODEL_SNAPSHOT
277-
.with_label_values(&[&MODEL_SPECS[idx]])
278-
.inc();
275+
assert!(idx < all_model_predictors.len());
276+
metrics::NEW_MODEL_SNAPSHOT
277+
.with_label_values(&[&MODEL_SPECS[idx]])
278+
.inc();
279279

280+
//we can do this since the vector is small
281+
let predictors = &mut all_model_predictors[idx];
282+
if predictors.len() == 0 {
283+
info!("now we serve new model: {}", predictor.model);
284+
}
285+
else {
280286
info!("now we serve updated model: {}", predictor.model);
281-
//we can do this since the vector is small
282-
let predictors = &mut all_model_predictors[idx];
283-
if predictors.len() == ARGS.versions_per_model {
284-
predictors.remove(predictors.len() - 1);
285-
}
286-
predictors.insert(0, predictor);
287-
} else {
288-
info!("now we serve new model: {:}", predictor.model);
289-
let mut predictors =
290-
ArrayVec::<BatchPredictor<T>, MAX_VERSIONS_PER_MODEL>::new();
291-
predictors.push(predictor);
292-
all_model_predictors.push(predictors);
293-
//check the invariant that we always push the last model to the end
294-
assert_eq!(all_model_predictors.len(), idx + 1)
295287
}
288+
if predictors.len() == ARGS.versions_per_model {
289+
predictors.remove(predictors.len() - 1);
290+
}
291+
predictors.insert(0, predictor);
296292
false
297293
}
298294
Err(TryRecvError::Empty) => true,

‎navi/segdense/src/error.rs

+33-23
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,49 @@ use std::fmt::Display;
55
*/
66
#[derive(Debug)]
77
pub enum SegDenseError {
8-
IoError(std::io::Error),
9-
Json(serde_json::Error),
10-
JsonMissingRoot,
11-
JsonMissingObject,
12-
JsonMissingArray,
13-
JsonArraySize,
14-
JsonMissingInputFeature,
8+
IoError(std::io::Error),
9+
Json(serde_json::Error),
10+
JsonMissingRoot,
11+
JsonMissingObject,
12+
JsonMissingArray,
13+
JsonArraySize,
14+
JsonMissingInputFeature,
1515
}
1616

1717
impl Display for SegDenseError {
18-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19-
match self {
20-
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
21-
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
22-
SegDenseError::JsonMissingRoot => write!(f, "{}", "SegDense JSON: Root Node note found!"),
23-
SegDenseError::JsonMissingObject => write!(f, "{}", "SegDense JSON: Object note found!"),
24-
SegDenseError::JsonMissingArray => write!(f, "{}", "SegDense JSON: Array Node note found!"),
25-
SegDenseError::JsonArraySize => write!(f, "{}", "SegDense JSON: Array size not as expected!"),
26-
SegDenseError::JsonMissingInputFeature => write!(f, "{}", "SegDense JSON: Missing input feature!"),
18+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19+
match self {
20+
SegDenseError::IoError(io_error) => write!(f, "{}", io_error),
21+
SegDenseError::Json(serde_json) => write!(f, "{}", serde_json),
22+
SegDenseError::JsonMissingRoot => {
23+
write!(f, "{}", "SegDense JSON: Root Node note found!")
24+
}
25+
SegDenseError::JsonMissingObject => {
26+
write!(f, "{}", "SegDense JSON: Object note found!")
27+
}
28+
SegDenseError::JsonMissingArray => {
29+
write!(f, "{}", "SegDense JSON: Array Node note found!")
30+
}
31+
SegDenseError::JsonArraySize => {
32+
write!(f, "{}", "SegDense JSON: Array size not as expected!")
33+
}
34+
SegDenseError::JsonMissingInputFeature => {
35+
write!(f, "{}", "SegDense JSON: Missing input feature!")
36+
}
37+
}
2738
}
28-
}
2939
}
3040

3141
impl std::error::Error for SegDenseError {}
3242

3343
impl From<std::io::Error> for SegDenseError {
34-
fn from(err: std::io::Error) -> Self {
35-
SegDenseError::IoError(err)
36-
}
44+
fn from(err: std::io::Error) -> Self {
45+
SegDenseError::IoError(err)
46+
}
3747
}
3848

3949
impl From<serde_json::Error> for SegDenseError {
40-
fn from(err: serde_json::Error) -> Self {
41-
SegDenseError::Json(err)
42-
}
50+
fn from(err: serde_json::Error) -> Self {
51+
SegDenseError::Json(err)
52+
}
4353
}

‎navi/segdense/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pub mod error;
2-
pub mod segdense_transform_spec_home_recap_2022;
32
pub mod mapper;
4-
pub mod util;
3+
pub mod segdense_transform_spec_home_recap_2022;
4+
pub mod util;

‎navi/segdense/src/main.rs

+11-12
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@ use segdense::error::SegDenseError;
55
use segdense::util;
66

77
fn main() -> Result<(), SegDenseError> {
8-
env_logger::init();
9-
let args: Vec<String> = env::args().collect();
10-
11-
let schema_file_name: &str = if args.len() == 1 {
12-
"json/compact.json"
13-
} else {
14-
&args[1]
15-
};
8+
env_logger::init();
9+
let args: Vec<String> = env::args().collect();
1610

17-
let json_str = fs::read_to_string(schema_file_name)?;
11+
let schema_file_name: &str = if args.len() == 1 {
12+
"json/compact.json"
13+
} else {
14+
&args[1]
15+
};
1816

19-
util::safe_load_config(&json_str)?;
17+
let json_str = fs::read_to_string(schema_file_name)?;
2018

21-
Ok(())
22-
}
19+
util::safe_load_config(&json_str)?;
2320

21+
Ok(())
22+
}

‎navi/segdense/src/mapper.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ pub struct FeatureMapper {
1919
impl FeatureMapper {
2020
pub fn new() -> FeatureMapper {
2121
FeatureMapper {
22-
map: HashMap::new()
22+
map: HashMap::new(),
2323
}
2424
}
2525
}
2626

2727
pub trait MapWriter {
28-
fn set(&mut self, feature_id: i64, info: FeatureInfo);
28+
fn set(&mut self, feature_id: i64, info: FeatureInfo);
2929
}
3030

3131
pub trait MapReader {

‎navi/segdense/src/segdense_transform_spec_home_recap_2022.rs

-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ pub struct ComplexFeatureTypeTransformSpec {
164164
pub tensor_shape: Vec<i64>,
165165
}
166166

167-
168167
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
169168
#[serde(rename_all = "camelCase")]
170169
pub struct InputFeatureMapRecord {

‎navi/segdense/src/util.rs

+26-31
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1+
use log::debug;
12
use std::fs;
2-
use log::{debug};
33

4-
use serde_json::{Value, Map};
4+
use serde_json::{Map, Value};
55

66
use crate::error::SegDenseError;
7-
use crate::mapper::{FeatureMapper, FeatureInfo, MapWriter};
7+
use crate::mapper::{FeatureInfo, FeatureMapper, MapWriter};
88
use crate::segdense_transform_spec_home_recap_2022::{self as seg_dense, InputFeature};
99

10-
pub fn load_config(file_name: &str) -> seg_dense::Root {
11-
let json_str = fs::read_to_string(file_name).expect(
12-
&format!("Unable to load segdense file {}", file_name));
13-
let seg_dense_config = parse(&json_str).expect(
14-
&format!("Unable to parse segdense file {}", file_name));
15-
return seg_dense_config;
10+
pub fn load_config(file_name: &str) -> Result<seg_dense::Root, SegDenseError> {
11+
let json_str = fs::read_to_string(file_name)?;
12+
// &format!("Unable to load segdense file {}", file_name));
13+
let seg_dense_config = parse(&json_str)?;
14+
// &format!("Unable to parse segdense file {}", file_name));
15+
Ok(seg_dense_config)
1616
}
1717

1818
pub fn parse(json_str: &str) -> Result<seg_dense::Root, SegDenseError> {
1919
let root: seg_dense::Root = serde_json::from_str(json_str)?;
20-
return Ok(root);
20+
Ok(root)
2121
}
2222

2323
/**
@@ -44,15 +44,8 @@ pub fn safe_load_config(json_str: &str) -> Result<FeatureMapper, SegDenseError>
4444
load_from_parsed_config(root)
4545
}
4646

47-
pub fn load_from_parsed_config_ref(root: &seg_dense::Root) -> FeatureMapper {
48-
load_from_parsed_config(root.clone()).unwrap_or_else(
49-
|error| panic!("Error loading all_config.json - {}", error))
50-
}
51-
5247
// Perf note : make 'root' un-owned
53-
pub fn load_from_parsed_config(root: seg_dense::Root) ->
54-
Result<FeatureMapper, SegDenseError> {
55-
48+
pub fn load_from_parsed_config(root: seg_dense::Root) -> Result<FeatureMapper, SegDenseError> {
5649
let v = root.input_features_map;
5750

5851
// Do error check
@@ -86,27 +79,30 @@ pub fn load_from_parsed_config(root: seg_dense::Root) ->
8679
Some(info) => {
8780
debug!("{:?}", info);
8881
fm.set(feature_id, info)
89-
},
82+
}
9083
None => (),
9184
}
9285
}
9386

9487
Ok(fm)
9588
}
9689
#[allow(dead_code)]
97-
fn add_feature_info_to_mapper(feature_mapper: &mut FeatureMapper, input_features: &Vec<InputFeature>) {
90+
fn add_feature_info_to_mapper(
91+
feature_mapper: &mut FeatureMapper,
92+
input_features: &Vec<InputFeature>,
93+
) {
9894
for input_feature in input_features.iter() {
99-
let feature_id = input_feature.feature_id;
100-
let feature_info = to_feature_info(input_feature);
101-
102-
match feature_info {
103-
Some(info) => {
104-
debug!("{:?}", info);
105-
feature_mapper.set(feature_id, info)
106-
},
107-
None => (),
95+
let feature_id = input_feature.feature_id;
96+
let feature_info = to_feature_info(input_feature);
97+
98+
match feature_info {
99+
Some(info) => {
100+
debug!("{:?}", info);
101+
feature_mapper.set(feature_id, info)
108102
}
103+
None => (),
109104
}
105+
}
110106
}
111107

112108
pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<FeatureInfo> {
@@ -139,7 +135,7 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
139135
2 => 0,
140136
3 => 2,
141137
_ => -1,
142-
}
138+
},
143139
};
144140

145141
if input_feature.index < 0 {
@@ -156,4 +152,3 @@ pub fn to_feature_info(input_feature: &seg_dense::InputFeature) -> Option<Featur
156152
index_within_tensor: input_feature.index,
157153
})
158154
}
159-

0 commit comments

Comments
 (0)
Please sign in to comment.