Conversation

@WHoutstanding
Contributor

PR Category

other

Description

Fix dtype_generalizer as a subclass of SamplePass.

@paddle-bot

paddle-bot bot commented Jan 12, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Jan 12, 2026

def declare_config(
    self,
    dtype_list: str,
Collaborator

This looks like it should be list[str].

Contributor Author

Got it.


def sample_handled(self, rel_model_path: str) -> bool:
    return self.naive_sample_handled(
        rel_model_path, search_file_name="op_names.txt"
Collaborator

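This search_file_name is definitely wrong.

It means that if an op_names.txt file is found in the target directory, the sample is treated as already handled. op_names is clearly the logic of op_names_extractor; what your pass produces is certainly not that. Take another look at the code: should this be model.py?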

Contributor Author

@WHoutstanding Jan 13, 2026

Got it. I will step through the ResumableSampleMixin code in a debugger until I fully understand its logic.

def __call__(self, model_path: str) -> List[str]:
def sample_handled(self, rel_model_path: str) -> bool:
    return self.naive_sample_handled(
        rel_model_path, search_file_name="op_names.txt"
Collaborator

Same here. The search_file_name is definitely wrong.

 def declare_config(
     self,
-    dtype_list: str,
+    dtype_list: list,
Contributor Author

dtype_list is a list[str], but the _check_config_declaration_parameters method of class SamplePass requires the declared type to be list, so I changed the annotation to list rather than list[str].
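
A minimal sketch of why a strict annotation check can reject list[str] while accepting list. The real _check_config_declaration_parameters implementation is not shown in this thread, so ALLOWED_TYPES and both helper functions below are hypothetical; the sketch only assumes that the check compares each annotation against a fixed set of plain types (Python 3.9+).

```python
import inspect
from typing import get_origin

# Hypothetical allow-list; the real SamplePass check may differ.
ALLOWED_TYPES = (int, float, bool, str, list, dict)


def annotation_is_allowed(annotation) -> bool:
    """Strict check: the annotation must be one of the plain types itself."""
    return annotation in ALLOWED_TYPES


def normalized_annotation(annotation):
    """Map a parameterized generic such as list[str] back to its origin (list)."""
    return get_origin(annotation) or annotation


def declare_config(self, dtype_list: list[str]):
    pass


param = inspect.signature(declare_config).parameters["dtype_list"]

# list[str] is a types.GenericAlias, which does not compare equal to list.
strict_ok = annotation_is_allowed(param.annotation)

# Normalizing through typing.get_origin recovers the plain list type.
relaxed_ok = annotation_is_allowed(normalized_annotation(param.annotation))
```

Under these assumptions the strict check rejects list[str], which is consistent with annotating the parameter as plain list; a checker that normalized annotations first could accept both forms.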

 def sample_handled(self, rel_model_path: str) -> bool:
     return self.naive_sample_handled(
-        rel_model_path, search_file_name="op_names.txt"
+        rel_model_path, search_file_name="model.py"
Contributor Author

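The naive_sample_handled method of class ResumableSamplePassMixin(SamplePassMixin) decides whether a sample has been handled by checking whether the sample path exists; if the sample path exists, it then searches for the generated file, so the file to search for here should be model.py.
Question:
What role does naive_sample_handled() play in dtype_generalizer.py?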

Collaborator

naive_sample_handled is a base component; it is not coupled to dtype_generalizer.
Its internal logic is very narrow: it checks whether the current sample has already been handled.
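
As a rough illustration of the behavior described in this thread (the sample counts as handled when its output path exists and contains the searched file), here is a standalone sketch. It is not the ResumableSamplePassMixin implementation: the function name, the output_dir parameter, and the recursive search are all assumptions.

```python
import tempfile
from pathlib import Path


def naive_sample_handled_sketch(
    output_dir: str, rel_model_path: str, search_file_name: str
) -> bool:
    """Return True if the sample's output directory exists and contains the file."""
    sample_path = Path(output_dir) / rel_model_path
    if not sample_path.exists():
        return False
    # Assumption: the search is recursive; the real method may look in one place only.
    return any(sample_path.rglob(search_file_name))


with tempfile.TemporaryDirectory() as out:
    # Nothing has been generated yet.
    before = naive_sample_handled_sketch(out, "resnet18", "model.py")

    # Simulate a pass writing its output for this sample.
    (Path(out) / "resnet18").mkdir()
    (Path(out) / "resnet18" / "model.py").write_text("# generated\n")
    after = naive_sample_handled_sketch(out, "resnet18", "model.py")
```

A check of this shape only fits passes that write exactly one output sample per input sample.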

pass

def sample_handled(self, rel_model_path: str) -> bool:
    return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
Collaborator

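naive_sample_handled cannot be called here.
Most other sample passes are conversion logic and do not modify the original sample. This pass is different: it needs to modify graph_net.json.
First check exactly which fields of graph_net.json this class's resume method updates. Then, in the current sample_handled function, treat the sample as already handled if those fields in graph_net.json already have values.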

Contributor Author

Got it.


def __call__(self, model_path: str) -> List[str]:
def sample_handled(self, rel_model_path: str) -> bool:
    return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
Collaborator

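naive_sample_handled still cannot be called here.

naive_sample_handled only handles the case where one input sample maps to one output sample. Here you need to handle one input sample mapping to multiple output samples.
Take a look at the SubgraphGenerator.sample_handled method in torch/sample_pass/subgraph_generator. The _has_enough_subgraphs method it calls is roughly the logic you need to write, although it cannot be copied verbatim.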

Contributor Author

Got it.


 def sample_handled(self, rel_model_path: str) -> bool:
-    return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
+    dst_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
Contributor Author

Code logic: a sample is considered handled if the data_type_generalization_passes field in GraphNet/samples/*/graph_net.json has a value.
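
The check described above could look roughly like the following standalone sketch. It assumes graph_net.json holds a JSON object at the top level and that an empty or missing data_type_generalization_passes field means "not handled"; the function name and the "example_pass" value are hypothetical, and the merged code may differ.

```python
import json
import tempfile
from pathlib import Path

FIELD = "data_type_generalization_passes"


def sample_handled_by_json(model_path_prefix: str, rel_model_path: str) -> bool:
    """Return True if graph_net.json exists and FIELD holds a non-empty value."""
    json_path = Path(model_path_prefix) / rel_model_path / "graph_net.json"
    if not json_path.is_file():
        return False
    with json_path.open(encoding="utf-8") as f:
        metadata = json.load(f)
    return bool(metadata.get(FIELD))


with tempfile.TemporaryDirectory() as prefix:
    sample_dir = Path(prefix) / "resnet18"
    sample_dir.mkdir()
    json_path = sample_dir / "graph_net.json"

    # Metadata without the field: the sample still needs processing.
    json_path.write_text(json.dumps({}))
    unhandled = sample_handled_by_json(prefix, "resnet18")

    # Metadata with a non-empty field: the sample is skipped on resume.
    json_path.write_text(json.dumps({FIELD: ["example_pass"]}))
    handled = sample_handled_by_json(prefix, "resnet18")
```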


 def sample_handled(self, rel_model_path: str) -> bool:
-    return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
+    model_path = Path(self.config["model_path_prefix"]) / rel_model_path
Contributor Author

Code logic: for example, if resnet18_float16 and resnet18_bfloat16 have been generated under /tmp/dtype_gen_samples, the sample is considered handled when the number of those directories that contain model.py equals the number of dtype_pass_names.
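
A hedged sketch of that counting logic for the one-input, many-outputs case. The `<model>_<dtype>` directory naming is taken from the example above; the function name, the glob pattern, and the pass names are assumptions for illustration only.

```python
import tempfile
from pathlib import Path


def has_all_dtype_outputs(
    output_dir: str, rel_model_path: str, dtype_pass_names: list
) -> bool:
    """Return True if every expected dtype variant directory contains model.py."""
    rel = Path(rel_model_path)
    parent = Path(output_dir) / rel.parent
    if not parent.is_dir():
        return False
    # Count sibling directories named "<model>_<suffix>" that hold a model.py.
    generated = [
        d for d in parent.glob(f"{rel.name}_*") if (d / "model.py").is_file()
    ]
    return len(generated) == len(dtype_pass_names)


pass_names = ["to_float16", "to_bfloat16"]  # hypothetical pass names

with tempfile.TemporaryDirectory() as out:
    first = Path(out) / "resnet18_float16"
    first.mkdir()
    (first / "model.py").write_text("# generated\n")
    # Only one of the two expected variants exists so far.
    partially_done = has_all_dtype_outputs(out, "resnet18", pass_names)

    second = Path(out) / "resnet18_bfloat16"
    second.mkdir()
    (second / "model.py").write_text("# generated\n")
    fully_done = has_all_dtype_outputs(out, "resnet18", pass_names)
```

A prefix glob like this would also match unrelated models whose names start with the same prefix, so a production check would probably build the expected directory names explicitly.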

@lixinqi lixinqi merged commit 7e0ff43 into PaddlePaddle:develop Jan 14, 2026
2 checks passed
@WHoutstanding WHoutstanding deleted the fix_dtype_generalizer branch January 14, 2026 15:06