/**
  ****************************(C) COPYRIGHT 2019 DJI****************************
  * @file       pid.c/h
  * @brief      PID controller (positional, incremental/delta, and positional with filtered derivative)
  * @note       
  * @history
  *  Version    Date            Author          Modification
  *  V1.0.0     Dec-26-2018     RM              1. done
  *
  @verbatim
  ==============================================================================

  ==============================================================================
  @endverbatim
  ****************************(C) COPYRIGHT 2019 DJI****************************
  */

#include "pid.h"
#include "main.h"
#include "math.h"

/* Clamp `input` in place to the symmetric range [-(max), max].
 * `input` must be a modifiable lvalue and `max` should be non-negative
 * (a negative bound does not clamp, it flips the sign of the result).
 * Both arguments may be evaluated more than once, so neither should have
 * side effects. The do { } while (0) wrapper makes the macro behave as a
 * single statement, so it is safe inside an unbraced if/else. */
#define LimitMax(input, max)       \
    do                             \
    {                              \
        if ((input) > (max))       \
        {                          \
            (input) = (max);       \
        }                          \
        else if ((input) < -(max)) \
        {                          \
            (input) = -(max);      \
        }                          \
    } while (0)


/**
  * @brief          Initialise a PID controller instance.
  * @param[out]     pid:   controller state to initialise; must not be NULL.
  * @param[in]      mode:  PID_POSITION, PID_DELTA or PID_POSITION_D.
  * @param[in]      param: gains (p, i, d) and limits (i_limit, out_limit).
  *                        Only the pointer is stored, so the pointed-to
  *                        struct must outlive the controller; PID_calc()
  *                        reads the gains through it on every call.
  * @retval         0 on success; -1 if pid or param is NULL, mode is not
  *                 one of the supported modes, a gain/limit is not finite,
  *                 or a limit is negative. On failure *pid is not modified.
  */
int8_t PID_init(pid_type_def *pid, uint8_t mode, const pid_param_t *param)
{
    if (pid == NULL || param == NULL)
    {
        return -1;
    }

    /* Only the modes handled by PID_calc() are accepted; any other value
       would make PID_calc() never produce an output. */
    if (mode != PID_POSITION && mode != PID_DELTA && mode != PID_POSITION_D)
    {
        return -1;
    }

    if (!isfinite(param->p) || !isfinite(param->i) || !isfinite(param->d) ||
        !isfinite(param->i_limit) || !isfinite(param->out_limit))
    {
        return -1;
    }

    /* LimitMax() clamps to [-max, max]; with a negative max it would force
       every value to +|max| or -|max| instead of clamping it. */
    if (param->i_limit < 0.0f || param->out_limit < 0.0f)
    {
        return -1;
    }

    pid->mode = mode;
    pid->param = param;
    pid->set = pid->fdb = 0.0f;
    pid->Dbuf[0] = pid->Dbuf[1] = pid->Dbuf[2] = 0.0f;
    pid->error[0] = pid->error[1] = pid->error[2] = 0.0f;
    pid->Pout = pid->Iout = pid->Dout = pid->out = 0.0f;

    return 0;
}

/**
  * @brief          Run one PID update step.
  * @param[in,out]  pid: controller state initialised by PID_init().
  * @param[in]      ref: feedback (measured) value.
  * @param[in]      set: set point.
  * @retval         Controller output, clamped to [-out_limit, out_limit].
  *                 Returns 0 if pid or its parameter pointer is NULL, or if
  *                 this step produced a NaN/Inf term (the whole state is
  *                 then cleared with PID_clear()).
  * @note           e[k] = set - ref. Per mode:
  *                 PID_POSITION:   out = Kp*e[k] + I + Kd*(e[k]-e[k-1]),
  *                                 I += Ki*e[k], I clamped to [-i_limit, i_limit].
  *                 PID_POSITION_D: as PID_POSITION, but the D term is smoothed:
  *                                 D = a*Kd*(e[k]-e[k-1]) + (1-a)*D_prev, a = 0.3.
  *                 PID_DELTA:      out += Kp*(e[k]-e[k-1]) + Ki*e[k]
  *                                      + Kd*(e[k]-2*e[k-1]+e[k-2]);
  *                                 i_limit is not applied in this mode.
  *                 Any other mode only shifts the error history and returns
  *                 the previous output.
  */
fp32 PID_calc(pid_type_def *pid, fp32 ref, fp32 set)
{
    /* Smoothing factor of the derivative low-pass filter used by
       PID_POSITION_D: 1.0 means no filtering, smaller means heavier. */
    const fp32 d_filter_alpha = 0.3f;

    if (pid == NULL || pid->param == NULL)
    {
        return 0.0f;
    }

    pid->error[2] = pid->error[1];
    pid->error[1] = pid->error[0];
    pid->set = set;
    pid->fdb = ref;
    pid->error[0] = set - ref;

    if (pid->mode == PID_POSITION || pid->mode == PID_POSITION_D)
    {
        pid->Pout = pid->param->p * pid->error[0];
        pid->Iout += pid->param->i * pid->error[0];
        pid->Dbuf[2] = pid->Dbuf[1];
        pid->Dbuf[1] = pid->Dbuf[0];
        pid->Dbuf[0] = (pid->error[0] - pid->error[1]);

        if (pid->mode == PID_POSITION_D)
        {
            /* First-order low-pass on the derivative term to reduce noise. */
            pid->Dout = d_filter_alpha * (pid->param->d * pid->Dbuf[0]) +
                        (1.0f - d_filter_alpha) * pid->Dout;
        }
        else
        {
            pid->Dout = pid->param->d * pid->Dbuf[0];
        }

        LimitMax(pid->Iout, pid->param->i_limit);
        pid->out = pid->Pout + pid->Iout + pid->Dout;
        LimitMax(pid->out, pid->param->out_limit);
    }
    else if (pid->mode == PID_DELTA)
    {
        pid->Pout = pid->param->p * (pid->error[0] - pid->error[1]);
        pid->Iout = pid->param->i * pid->error[0];
        pid->Dbuf[2] = pid->Dbuf[1];
        pid->Dbuf[1] = pid->Dbuf[0];
        pid->Dbuf[0] = (pid->error[0] - 2.0f * pid->error[1] + pid->error[2]);
        pid->Dout = pid->param->d * pid->Dbuf[0];
        pid->out += pid->Pout + pid->Iout + pid->Dout;
        LimitMax(pid->out, pid->param->out_limit);
    }

    /* An Inf term is clamped away in `out` but would otherwise persist in the
       state: the filtered Dout would latch at +/-Inf, and the Inf error
       history would kick the output to the opposite limit on the next step.
       Treat any non-finite term like NaN and restart from a clean state. */
    if (!isfinite(pid->Pout) || !isfinite(pid->Iout) ||
        !isfinite(pid->Dout) || !isfinite(pid->out))
    {
        PID_clear(pid);
    }

    return pid->out;
}

/**
  * @brief          Reset the dynamic state of a PID controller.
  * @param[out]     pid: controller to reset; a NULL pointer is ignored.
  * @retval         none
  * @note           Zeroes the three-sample error history and derivative
  *                 buffer, every output term (P, I, D and total), and the
  *                 stored set point and feedback. The mode and parameter
  *                 pointer configured by PID_init() are left untouched.
  */
void PID_clear(pid_type_def *pid)
{
    int k;

    if (pid != NULL)
    {
        for (k = 0; k < 3; k++)
        {
            pid->error[k] = 0.0f;
            pid->Dbuf[k] = 0.0f;
        }

        pid->Pout = 0.0f;
        pid->Iout = 0.0f;
        pid->Dout = 0.0f;
        pid->out = 0.0f;

        pid->set = 0.0f;
        pid->fdb = 0.0f;
    }
}