diff --git a/graph_net/tools/generate_subgraph_dataset.sh b/graph_net/tools/generate_subgraph_dataset.sh index 3106ffa84..a56ecaaf6 100755 --- a/graph_net/tools/generate_subgraph_dataset.sh +++ b/graph_net/tools/generate_subgraph_dataset.sh @@ -54,7 +54,7 @@ function generate_op_names() { --model-path-list $model_list \ --handler-config=$(base64 -w 0 <>> [2] Generate split points for samples in ${model_list}." echo ">>> MIN_SEQ_OPS: ${MIN_SEQ_OPS}, MAX_SEQ_OPS: ${MAX_SEQ_OPS}" echo ">>>" - python3 -m graph_net.torch.typical_sequence_split_points \ - --model-list "$model_list" \ - --op-names-path-prefix "${OP_NAMES_OUTPUT_DIR}" \ - --device "cuda" \ - --window-size 64 \ - --fold-policy default \ - --fold-times 16 \ - --min-seq-ops ${MIN_SEQ_OPS} \ - --max-seq-ops ${MAX_SEQ_OPS} \ - --output-dir "$DECOMPOSE_WORKSPACE" \ - --subgraph-ranges-file-name "typical_subgraph_ranges.json" \ - --subgraph-ranges-json "$DECOMPOSE_WORKSPACE/subgraph_ranges_${OP_RANGE}ops.json" \ - --output-json "$DECOMPOSE_WORKSPACE/split_results_${OP_RANGE}ops.json" + python3 -m graph_net.apply_sample_pass \ + --model-path-list $model_list \ + --sample-pass-file-path $GRAPH_NET_ROOT/graph_net/torch/sample_pass/typical_sequence_split_points.py \ + --sample-pass-class-name TypicalSequenceSplitPointsGenerator \ + --sample-pass-config=$(base64 -w 0 <