@@ -72,16 +72,18 @@ def __init__(
         # set the observation and action space here
         self._vocab_size = self.tokenizer.vocab_size

-        self.observation_space = DictSpace({
-            "input_encoded_pt": spaces.Box(
-                low=0,
-                high=self._vocab_size,
-                shape=(self._max_text_length + self.max_steps,),
-            ),
-            "input_attention_mask_pt": spaces.Box(
-                low=0, high=1, shape=(self._max_text_length + self.max_steps,)
-            ),
-        })
+        self.observation_space = DictSpace(
+            {
+                "input_encoded_pt": spaces.Box(
+                    low=0,
+                    high=self._vocab_size,
+                    shape=(self._max_text_length + self.max_steps,),
+                ),
+                "input_attention_mask_pt": spaces.Box(
+                    low=0, high=1, shape=(self._max_text_length + self.max_steps,)
+                ),
+            }
+        )
         self.action_space = Discrete(n=self._vocab_size)
         # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency

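The comment on the last context line above references rounding the vocabulary size up to the nearest power of 2, but the rounding itself is not shown in this hunk. A minimal sketch of what such a helper could look like; the name round_up_to_power_of_2 and its use in the action space are assumptions, not part of the patch:

def round_up_to_power_of_2(n: int) -> int:
    # Hypothetical helper, not in the patch: smallest power of 2 >= n,
    # e.g. GPT-2's 50257-token vocabulary rounds up to 65536.
    return 1 << (n - 1).bit_length()

# The padded size would then feed the action space, e.g.:
# self.action_space = Discrete(n=round_up_to_power_of_2(self._vocab_size))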
@@ -112,7 +114,6 @@ def __init__(
         self.reward_function = None

     def set_reward(self, reward_fn=None):
-
         self.reward_function = reward_fn

     def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
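Outside the class, the spaces touched by the first hunk can be reproduced standalone. A runnable sketch, assuming gym supplies DictSpace, Box, and Discrete (the diff does not show the imports) and using illustrative sizes:

from gym import spaces
from gym.spaces.dict import Dict as DictSpace
from gym.spaces.discrete import Discrete

# Illustrative values; in the env they come from the tokenizer and config.
vocab_size, max_text_length, max_steps = 50257, 512, 100

observation_space = DictSpace(
    {
        # token ids of the prompt plus the generated continuation
        "input_encoded_pt": spaces.Box(
            low=0, high=vocab_size, shape=(max_text_length + max_steps,)
        ),
        # 0/1 mask marking which positions hold real tokens
        "input_attention_mask_pt": spaces.Box(
            low=0, high=1, shape=(max_text_length + max_steps,)
        ),
    }
)
action_space = Discrete(n=vocab_size)  # one discrete action per vocab token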
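And a hypothetical usage sketch of the two methods in the second hunk. The environment class name, its constructor arguments, and the reward function's signature are all assumptions; only set_reward and step_word (with its gym-style 4-tuple return) appear in the diff:

# Hypothetical usage; only set_reward and step_word are shown in this commit.
def constant_reward(*args, **kwargs):  # reward signature is an assumption
    return 1.0

env = TextGenEnv(tokenizer)      # class name and args are assumptions
env.set_reward(constant_reward)  # install the reward callback
obs, reward, done, info = env.step_word("hello")  # advance by one word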