Fix dtype_generalizer.py #543
Thanks for your contribution!
def declare_config(
    self,
    dtype_list: str,
This looks like it should be list[str].
Got it.
def sample_handled(self, rel_model_path: str) -> bool:
    return self.naive_sample_handled(
        rel_model_path, search_file_name="op_names.txt"
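This search_file_name is definitely wrong.
It means that a sample counts as handled if op_names.txt is found in the target directory. op_names.txt clearly belongs to the op_names_extractor logic, which is certainly not what your pass produces. Take another look at the code: shouldn't this be model.py?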
Got it. I'll step through the code in a debugger until I properly understand the logic of ResumableSampleMixin.
def __call__(self, model_path: str) -> List[str]:
def sample_handled(self, rel_model_path: str) -> bool:
    return self.naive_sample_handled(
        rel_model_path, search_file_name="op_names.txt"
Same here. The search_file_name is definitely wrong.
 def declare_config(
     self,
-    dtype_list: str,
+    dtype_list: list,
dtype_list is really a list[str], but the _check_config_declaration_parameters method of class SamplePass requires the type to be list, so I changed the annotation to list rather than list[str].
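A minimal sketch of why a bare `list` annotation can pass where `list[str]` does not, assuming the check compares each annotation against a fixed set of plain types. The function `check_declared_types` and the `ALLOWED_TYPES` tuple are hypothetical stand-ins for illustration, not the actual `_check_config_declaration_parameters` implementation.

```python
import inspect

# Hypothetical allow-list; the real SamplePass check may accept other types.
ALLOWED_TYPES = (int, float, bool, str, list, dict)


def check_declared_types(func) -> list[str]:
    """Return the names of parameters whose annotation is not a plain allowed type."""
    rejected = []
    for name, param in inspect.signature(func).parameters.items():
        if name == "self":
            continue
        # Identity comparison: list[str] is a types.GenericAlias, not the list class.
        if not any(param.annotation is t for t in ALLOWED_TYPES):
            rejected.append(name)
    return rejected


def declare_with_generic(self, dtype_list: list[str]): ...


def declare_with_plain(self, dtype_list: list): ...


print(check_declared_types(declare_with_generic))
print(check_declared_types(declare_with_plain))
```

Under this assumption the parameterized annotation is rejected and the bare one is accepted, which matches the behaviour described in the comment above.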
 def sample_handled(self, rel_model_path: str) -> bool:
     return self.naive_sample_handled(
-        rel_model_path, search_file_name="op_names.txt"
+        rel_model_path, search_file_name="model.py"
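The naive_sample_handled method of class ResumableSamplePassMixin(SamplePassMixin) decides whether a sample has been handled by checking whether sample_path exists. If the sample path exists, it then searches for the generated file, so the file to search for here should be model.py.
Question:
What role does the naive_sample_handled() method play in dtype_generalizer.py?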
naive_sample_handled is a basic building block; it is not coupled to dtype_generalizer.
Its internal logic is very simple: it checks whether the current sample has already been handled.
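The check described in this thread (output path exists, then look for the generated file) can be sketched roughly as follows. The function `naive_handled`, its `output_root` parameter, and the recursive search are assumptions made for illustration; the actual `naive_sample_handled` in `ResumableSamplePassMixin` may be structured differently.

```python
from pathlib import Path


def naive_handled(output_root: str, rel_model_path: str, search_file_name: str) -> bool:
    """Hypothetical stand-in: a sample counts as handled when its output
    directory exists and the expected file is found somewhere beneath it."""
    sample_path = Path(output_root) / rel_model_path
    if not sample_path.exists():
        # No output directory yet, so the sample has not been processed.
        return False
    # Assumption: the search is recursive; a direct child lookup would also work.
    return any(sample_path.rglob(search_file_name))
```

This only covers the case of one input sample producing one output sample, which is why the choice of `search_file_name` has to match the file the pass actually generates.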
        pass

    def sample_handled(self, rel_model_path: str) -> bool:
        return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
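You can't call naive_sample_handled here.
Most other sample passes are pure transformations and do not modify the original sample. This pass is different: it needs to modify graph_net.json.
First check which field of graph_net.json this class's resume method actually updates. Then, in this sample_handled function, treat the sample as already handled if that field in graph_net.json already has a value.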
Got it.
def __call__(self, model_path: str) -> List[str]:
def sample_handled(self, rel_model_path: str) -> bool:
    return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
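You still can't call naive_sample_handled here.
naive_sample_handled only handles the case where one input sample maps to one output sample, but here you need to handle one input sample mapping to multiple output samples.
Take a look at the SubgraphGenerator.sample_handled method in torch/sample_pass/subgraph_generator: the _has_enough_subgraphs it calls is roughly the logic you need to write, though you can't copy it verbatim.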
Got it.
 def sample_handled(self, rel_model_path: str) -> bool:
-    return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
+    dst_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
Code logic: a sample is considered handled if data_type_generalization_passes in GraphNet/samples/*/graph_net.json has a value.
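A rough sketch of that check, assuming graph_net.json holds a top-level JSON object. The helper name `json_field_handled` is hypothetical; only the file name and the field name come from the comment above, and the code in the PR may differ.

```python
import json
from pathlib import Path

# Field name taken from the review comment above.
FIELD = "data_type_generalization_passes"


def json_field_handled(model_dir: Path, field: str = FIELD) -> bool:
    """Hypothetical helper: True when graph_net.json exists in model_dir
    and the given field holds a non-empty value."""
    json_path = model_dir / "graph_net.json"
    if not json_path.is_file():
        return False
    with json_path.open(encoding="utf-8") as f:
        metadata = json.load(f)  # assumed to be a JSON object (dict)
    # A missing key, None, or an empty list all count as "not handled yet".
    return bool(metadata.get(field))
```

Treating an empty list as unhandled means a sample whose field was initialized but never filled would be processed again, which seems the safer default for a resumable pass.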
 def sample_handled(self, rel_model_path: str) -> bool:
-    return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
+    model_path = Path(self.config["model_path_prefix"]) / rel_model_path
Code logic: for example, if resnet18_float16 and resnet18_bfloat16 are generated under /tmp/dtype_gen_samples, the sample is considered handled when the number of those directories that contain model.py equals the number of dtype_pass_names.
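The counting logic above might look roughly like this. The helper `all_dtype_outputs_exist` and the `<model_name>_<suffix>` directory naming are assumptions inferred from the resnet18 example, not the code in the PR.

```python
from pathlib import Path


def all_dtype_outputs_exist(
    output_root: Path, model_name: str, dtype_pass_names: list[str]
) -> bool:
    """Hypothetical helper: True when the number of output directories that
    contain model.py equals the number of dtype passes."""
    # Assumption: output directories are named "<model_name>_<suffix>".
    produced = [
        d
        for d in output_root.glob(f"{model_name}_*")
        if d.is_dir() and (d / "model.py").is_file()
    ]
    return len(produced) == len(dtype_pass_names)
```

Note that a glob on `<model_name>_*` would also match a differently named model that shares the prefix, so a real implementation would likely build the expected directory names explicitly from the dtype list.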
PR Category
other
Description
Fix dtype_generalizer as a subclass of SamplePass.