@@ -103,24 +103,22 @@ void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out)
     const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
     scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();
 
-    // at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
-    //                  {
-    #pragma omp parallel for
-    for (auto b = 0; b < B; b++)
-    {
-        auto out_offset = b * (T + order) + order;
-        auto a_offset = b * T * order;
-        for (int64_t t = 0; t < T; t++)
+    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
+                     {
+        for (auto b = start; b < end; b++)
         {
-            scalar_t y = out_ptr[out_offset + t];
-            for (int64_t i = 0; i < order; i++)
+            auto out_offset = out_ptr + b * (T + order) + order;
+            auto a_offset = a_ptr + b * T * order;
+            for (int64_t t = 0; t < T; t++)
             {
-                y -= a_ptr[a_offset + t * order + i] *
-                     out_ptr[out_offset + t - i - 1];
+                scalar_t y = out_offset[t];
+                for (int64_t i = 0; i < order; i++)
+                {
+                    y -= a_offset[t * order + i] * out_offset[t - i - 1];
+                }
+                out_offset[t] = y;
             }
-            out_ptr[out_offset + t] = y;
-        }
-    };
+        }; });
 }
 
 at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
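For context, the change above replaces the OpenMP pragma with ATen's `at::parallel_for(begin, end, grain_size, fn)`, which splits the batch range over PyTorch's intra-op thread pool, and turns the per-batch integer offsets into base pointers. Below is a minimal standalone sketch of that pattern, not the file's actual code; the kernel name `run_lpc_batch` and its parameters are illustrative only.

// Minimal sketch of the at::parallel_for pattern adopted above
// (illustrative names; not part of the commit).
#include <ATen/Parallel.h>
#include <cstdint>

template <typename scalar_t>
void run_lpc_batch(const scalar_t *a_ptr, scalar_t *out_ptr,
                   int64_t B, int64_t T, int64_t order)
{
    // Batch items are independent, so the outer loop is split across the
    // intra-op thread pool; each worker receives a [start, end) sub-range.
    at::parallel_for(0, B, 1, [&](int64_t start, int64_t end)
                     {
        for (auto b = start; b < end; b++)
        {
            // Per-batch base pointers replace the index arithmetic of the
            // old OpenMP version; the leading `order` samples of each padded
            // row act as the recursion's initial state.
            scalar_t *out = out_ptr + b * (T + order) + order;
            const scalar_t *a = a_ptr + b * T * order;
            for (int64_t t = 0; t < T; t++)
            {
                scalar_t y = out[t];
                for (int64_t i = 0; i < order; i++)
                    y -= a[t * order + i] * out[t - i - 1];
                out[t] = y;
            }
        } });
}

A grain size of 1 lets ATen hand out individual batch items to workers, which fits this kernel because each item's recursion over `t` is inherently sequential and cannot itself be parallelized.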