| // Declarations for the adam_atan2 extension. | |
| namespace adam_atan2 { | |
| /// Declared here; the definition is not in this view (the `_cuda_impl_` | |
| /// suffix suggests a CUDA translation unit — TODO confirm). | |
| /// | |
| /// By name this looks like an Adam-style optimizer step using atan2, and the | |
| /// trailing underscore follows the PyTorch convention for in-place ops — | |
| /// NOTE(review): both are inferred from names; verify against the definition. | |
| /// | |
| /// This declaration must match the definition's signature exactly: e.g. | |
| /// switching the vectors from by-value to `const&` changes the function's | |
| /// type and would break linking unless the definition is changed too. The | |
| /// top-level `const` on the by-value double parameters is not part of the | |
| /// function type, so it has no effect on callers. | |
| /// | |
| /// @param params       tensors being optimized (presumably updated in place). | |
| /// @param grads        gradients; presumably parallel to `params` — TODO confirm. | |
| /// @param exp_avgs     optimizer state; presumably parallel to `params`. | |
| /// @param exp_avg_sqs  optimizer state; presumably parallel to `params`. | |
| /// @param state_steps  per-parameter step state; presumably parallel to `params`. | |
| /// @param lr           learning rate (by name). | |
| /// @param beta1        presumably the decay rate for `exp_avgs` — confirm. | |
| /// @param beta2        presumably the decay rate for `exp_avg_sqs` — confirm. | |
| /// @param weight_decay weight-decay coefficient (by name). | |
| void adam_atan2_cuda_impl_( | |
| std::vector<at::Tensor> params, | |
| std::vector<at::Tensor> grads, | |
| std::vector<at::Tensor> exp_avgs, | |
| std::vector<at::Tensor> exp_avg_sqs, | |
| std::vector<at::Tensor> state_steps, | |
| const double lr, | |
| const double beta1, | |
| const double beta2, | |
| const double weight_decay); | |
| }  // namespace adam_atan2 | |
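| // NOTE(review): the commented-out declaration below is the same signature | |
| // with the default std::allocator written out (it reads like a demangled | |
| // symbol); std::vector<T, std::allocator<T> > is the same type as | |
| // std::vector<T>. It sits after the closing brace of namespace adam_atan2, | |
| // so re-enabling it would declare a separate function outside that namespace | |
| // rather than adam_atan2::adam_atan2_cuda_impl_. | |
| // void adam_atan2_cuda_impl_( | |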
| // std::vector<at::Tensor, std::allocator<at::Tensor> > params, | |
| // std::vector<at::Tensor, std::allocator<at::Tensor> > grads, | |
| // std::vector<at::Tensor, std::allocator<at::Tensor> > exp_avgs, | |
| // std::vector<at::Tensor, std::allocator<at::Tensor> > exp_avg_sqs, | |
| // std::vector<at::Tensor, std::allocator<at::Tensor> > state_steps, | |
| // const double lr, | |
| // const double beta1, | |
| // const double beta2, | |
| // const double weight_decay); |