2525
2626class FourierFTLayer (BaseTunerLayer ):
2727 # All names of layers that may contain (trainable) adapter weights
28- adapter_layer_names = ("fourierft_spectrum" ,)
28+ adapter_layer_names = ("fourierft_spectrum" , "fourierft_scaling" )
2929 # All names of other parameters that may contain adapter-related parameters
30- other_param_names = ("fourierft_n_frequency" , "fourierft_scaling" , " fourierft_random_loc_seed" )
30+ other_param_names = ("fourierft_n_frequency" , "fourierft_random_loc_seed" )
3131
32- def __init__ (self , base_layer : nn .Module , alpha , ** kwargs ) -> None :
32+ def __init__ (self , base_layer : nn .Module , ** kwargs ) -> None :
3333 self .base_layer = base_layer
3434 self .fourierft_n_frequency = {}
35- self .fourierft_scaling = {}
35+ self .fourierft_scaling = nn . ParameterDict ({})
3636 self .fourierft_spectrum = nn .ParameterDict ({})
3737 self .indices = {}
3838 self .fourierft_random_loc_seed = {}
@@ -55,7 +55,7 @@ def __init__(self, base_layer: nn.Module, alpha, **kwargs) -> None:
5555 raise ValueError (f"Unsupported layer type { type (base_layer )} " )
5656
5757 def update_layer (
58- self , adapter_name , n_frequency , scaling , init_weights , random_loc_seed , inference_mode : bool = False , ** kwargs
58+ self , adapter_name , n_frequency , scaling , init_weights , random_loc_seed , dynamic_scaling , inference_mode : bool = False , ** kwargs
5959 ):
6060 if n_frequency <= 0 :
6161 raise ValueError (f"`n_frequency` should be a positive integer value but the value passed is { n_frequency } " )
@@ -73,7 +73,8 @@ def update_layer(
7373 self .indices [adapter_name ] = torch .stack (
7474 [self .indices [adapter_name ] // self .in_features , self .indices [adapter_name ] % self .in_features ], dim = 0
7575 )
76- self .fourierft_scaling [adapter_name ] = scaling
76+ self .fourierft_scaling [adapter_name ] = nn .Parameter (torch .tensor (scaling , dtype = torch .float32 ))
77+ self .fourierft_scaling [adapter_name ].requires_grad = dynamic_scaling
7778 # Actual trainable parameters
7879 self .fourierft_spectrum [adapter_name ] = nn .Parameter (torch .randn (n_frequency ), requires_grad = True )
7980
@@ -107,21 +108,22 @@ def __init__(
107108 n_frequency : int = 1000 ,
108109 alpha : float = None ,
109110 scaling : float = 150.0 ,
111+ dynamic_scaling : bool = False ,
110112 fan_in_fan_out : bool = False , # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
111113 init_weights : Union [bool , str ] = False ,
112114 random_loc_seed : int = 777 ,
113115 ** kwargs ,
114116 ) -> None :
115117 super ().__init__ ()
116- FourierFTLayer .__init__ (self , base_layer , alpha , ** kwargs )
118+ FourierFTLayer .__init__ (self , base_layer , ** kwargs )
117119
118120 # apply alpha patch
119121 if alpha :
120122 n_frequency = int (alpha * self .in_features * self .out_features )
121123
122124 self .fan_in_fan_out = fan_in_fan_out
123125 self ._active_adapter = adapter_name
124- self .update_layer (adapter_name , n_frequency , scaling , init_weights , random_loc_seed )
126+ self .update_layer (adapter_name , n_frequency , scaling , init_weights , random_loc_seed , dynamic_scaling )
125127
126128 def merge (self , safe_merge : bool = False , adapter_names : Optional [list [str ]] = None ) -> None :
127129 """
@@ -210,29 +212,30 @@ def __init__(
210212 n_frequency : int = 1000 ,
211213 alpha : float = None ,
212214 scaling : float = 150.0 ,
215+ dynamic_scaling : bool = False ,
213216 fan_in_fan_out : bool = False , # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
214217 init_weights : Union [bool , str ] = False ,
215218 random_loc_seed : int = 777 ,
216219 ** kwargs ,
217220 ) -> None :
218221 super ().__init__ ()
219- FourierFTLayer .__init__ (self , base_layer , alpha , ** kwargs )
222+ FourierFTLayer .__init__ (self , base_layer , ** kwargs )
220223
221224 # apply alpha patch
222225 if alpha :
223226 n_frequency = int (alpha * self .in_features * self .out_features )
224-
227+
225228 self .fan_in_fan_out = fan_in_fan_out
226229 self ._active_adapter = adapter_name
227230 self .kW = base_layer .kernel_size [0 ]
228231 self .kH = base_layer .kernel_size [1 ]
229232 self .stride = base_layer .stride
230233 self .padding = base_layer .padding
231- self .update_layer (adapter_name , n_frequency , scaling , init_weights , random_loc_seed )
234+ self .update_layer (adapter_name , n_frequency , scaling , init_weights , random_loc_seed , dynamic_scaling )
232235
233236
234237 def update_layer (
235- self , adapter_name , n_frequency , scaling , init_weights , random_loc_seed , inference_mode : bool = False , ** kwargs
238+ self , adapter_name , n_frequency , scaling , init_weights , random_loc_seed , dynamic_scaling , inference_mode : bool = False , ** kwargs
236239 ):
237240 if n_frequency <= 0 :
238241 raise ValueError (f"`n_frequency` should be a positive integer value but the value passed is { n_frequency } " )
@@ -241,6 +244,7 @@ def update_layer(
241244 f"`n_frequency` should be less than or equal to the product of the input and output dimensions "
242245 f"but the value passed is { n_frequency } and the product is { self .in_features * self .out_features } "
243246 )
247+
244248 self .fourierft_n_frequency [adapter_name ] = n_frequency
245249 self .fourierft_random_loc_seed [adapter_name ] = random_loc_seed
246250 self .indices [adapter_name ] = torch .randperm (
@@ -250,7 +254,8 @@ def update_layer(
250254 self .indices [adapter_name ] = torch .stack (
251255 [self .indices [adapter_name ] // self .in_features , self .indices [adapter_name ] % self .in_features ], dim = 0
252256 )
253- self .fourierft_scaling [adapter_name ] = scaling
257+ self .fourierft_scaling [adapter_name ] = nn .Parameter (torch .tensor (scaling , dtype = torch .float32 ))
258+ self .fourierft_scaling [adapter_name ].requires_grad = dynamic_scaling
254259 # Actual trainable parameters
255260 self .fourierft_spectrum [adapter_name ] = nn .Parameter (torch .randn (n_frequency , self .kW , self .kH ), requires_grad = True )
256261
0 commit comments