###########################################################################
# program: Bessel_lowpass_filter.py
# author: Tom Irvine
# version: 1.1
# date: September 12, 2013
# description:  Bessel two-pole lowpass filter
#               The input file must have two columns: time(sec) & amplitude
###########################################################################

from __future__ import print_function

from tompy import read_two_columns,signal_stats,sample_rate_check
from tompy import GetInteger2,GetInteger_n,WriteData2

from tompy import enter_float
from tompy import time_history_plot

from scipy.signal import lfilter
from numpy import zeros
from sys import stdin
from math import pi,tan,atan2,sqrt,log,ceil

import matplotlib.pyplot as plt

#*******************************************************************************

def Bessel_lowpass_filter_coefficients(fc,dt,scale):
    """
    Return the recursive coefficients of a two-pole Bessel lowpass filter.

    The analog prototype H(s) = 3/(s^2 + 3s + 3) is mapped to the z-domain
    with the bilinear transform, prewarped through tan(pi*fc*dt/scale).

    fc    : lowpass frequency (Hz)
    dt    : time step (sec)
    scale : frequency normalization of the prototype (see caller)

    Returns (a1, a2, b0, b1, b2) for the difference equation
        y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2]
    i.e. numerator [b0, b1, b2] and denominator [1, a1, a2].
    """
    omega = tan(pi*fc*dt/scale)
    omega_sq = omega**2

    # common normalizer so the leading denominator coefficient is 1
    norm = 1+3*omega+3*omega_sq

    # numerator has the form b0*(1 + 2 z^-1 + z^-2)
    b_zero = 3*omega_sq/norm

    a_one = 2*(-1+3*omega_sq)/norm
    a_two = (1-3*omega+3*omega_sq)/norm

    return a_one, a_two, b_zero, 2*b_zero, b_zero
    
#*******************************************************************************

def Bessel_transfer_function(fc,scale):
    """
    Evaluate the two-pole Bessel lowpass transfer function
        H(s) = 3/(s^2 + 3s + 3),   s = j*scale*f/fc
    on a logarithmic frequency grid.

    The grid has 48 points per octave. It starts at min(fc/100, 1) Hz and
    extends up to just below 5*fc.

    fc    : lowpass frequency (Hz); must be > 0
    scale : frequency normalization; with scale = sqrt(3*(sqrt(5)-1)/2)
            |H| = 1/sqrt(2) (-3 dB) at f = fc

    Returns (f, H_mag, H_phase): frequency (Hz), magnitude, and phase (deg)
    as float64 arrays of equal length.

    Raises ValueError if fc is not positive.
    """
    if not fc > 0:
        # fc <= 0 (or NaN) would divide by zero below or build a grid of
        # non-positive frequencies
        raise ValueError("lowpass frequency fc must be > 0")

    fmin = min(fc/100., 1.)
    fmax = 5.*fc

    # 48 points per octave between fmin and fmax
    nf = int(ceil(48.*log(fmax/fmin)/log(2.)))

    print ("\n nf = %d " % nf)

    # float64 storage; each frequency is computed directly from its index
    # rather than by repeated multiplication, so rounding does not accumulate
    f = zeros(nf)
    H_mag = zeros(nf)
    H_phase = zeros(nf)

    for i in range(nf):
        f[i] = fmin*2.**(i/48.)
        s = 1j*(scale*f[i]/fc)
        H = 3./(s**2 + 3.*s + 3.)
        H_mag[i] = abs(H)
        H_phase[i] = (180./pi)*atan2(H.imag, H.real)

    return f,H_mag,H_phase
    
#*******************************************************************************
      
print ("The file must have two columns: time(sec) & amplitude")

ttime,y,num =read_two_columns()


print ("\nSelect Y-axis Label ")
print ("  1=Accel (G) ")
print ("  2=Pressure (psi) ")
print ("  3=none ")
print ("  4=other ")

# presumably GetInteger_n(4) only returns 1..4 -- confirm in tompy
alab=GetInteger_n(4)

if alab == 1:
    yaxislabel='Accel (G)'
elif alab == 2:
    yaxislabel='Pressure (psi)'
elif alab == 4:
    print ("Enter label")
    # readline() keeps the trailing newline; strip it so it does not end up
    # in the plotted axis label
    yaxislabel = stdin.readline().rstrip('\n')
else:
    # 3 = none; also a fallback so yaxislabel is always defined
    yaxislabel=' '


# signal_stats yields (sr, dt, mean, sd, rms, skew, kurtosis, dur);
# sample_rate_check may then revise sr and dt
sr,dt,mean,sd,rms,skew,kurtosis,dur=signal_stats(ttime,y,num)

sr,dt=sample_rate_check(ttime,y,num,sr,dt)


   
title_label='Lowpass Filtered Time History'

# Prompt until the lowpass frequency is positive and below Nyquist.
# fc <= 0 would break the transfer-function grid (division by zero / a
# non-positive frequency range) and make the filter coefficients meaningless.
while True:
    print("  ")
    print(" The transfer function is -3 dB at the lowpass frequency. ")
    print("  ")
    print(" Enter lowpass frequency (Hz) ")
    fc=enter_float()
    if 0 < fc < 0.5*sr:
        break
    if fc <= 0:
        print("\n error: cutoff frequency must be > 0 \n")
    else:
        print("\n error: cutoff frequency must be < Nyquist frequency \n")
        
        
#*******************************************************************************        
 
# Prototype frequency at which |3/(s^2+3s+3)| = 1/sqrt(2); scaling by it
# places the -3 dB point of the analog prototype at fc.
scale = sqrt(3*(sqrt(5)-1)/2)

print ("\n Calculate transfer function ")

f, H_mag, H_phase = Bessel_transfer_function(fc, scale)

print ("\n Calculate filter coefficient ")

a1, a2, b0, b1, b2 = Bessel_lowpass_filter_coefficients(fc, dt, scale)

print ("\n Apply filter ")

# numerator / denominator of the second-order recursive filter
bc = [b0, b1, b2]
ac = [1, a1, a2]

yf = lfilter(bc, ac, y)

#*******************************************************************************

print ("\n Filter signal statistics ")
sr,dt,mean,sd,rms,skew,kurtosis,dur=signal_stats(ttime,yf,num)

print(" ")
print(" Write filtered data to file?  1=yes 2=no ")

iacc = GetInteger2()

if iacc == 1:
    print (" ")
    print ("Enter the output filename: ")
    ns=len(y)
    output_file_path = stdin.readline()
    # readline() keeps the trailing newline; pass the stripped name so the
    # output file is not created with a newline in its filename
    output_file = output_file_path.rstrip('\n')
    WriteData2(ns,ttime,yf,output_file)
    
#*******************************************************************************

print (" ")
print (" view plots ")
print (" ")

# Figures 1 and 2: input and filtered time histories
time_history_plot(ttime,y,1,'Time(sec)',yaxislabel,'Input Time History','input_time_history')

title_label= 'Filtered Time History   fc='+str(fc)+' Hz'
time_history_plot(ttime,yf,2,'Time(sec)',yaxislabel,title_label,'filtered_time_history')

# Figure 3: magnitude of the analog Bessel transfer function.
# f is log-spaced but is deliberately drawn on linear axes here.
plt.figure(3)
plt.plot(f, H_mag)
title_string= ' Transfer Magnitude  fc='+str(fc)+' Hz'
plt.title(title_string)
plt.xlabel('Frequency (Hz) ')
plt.ylabel('Magnitude')
plt.grid(True, which="both")
plt.xscale('linear')
plt.yscale('linear')
plt.yticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0])
plt.draw()

# Figure 4: phase of the analog Bessel transfer function (degrees)
plt.figure(4)
plt.plot(f, H_phase)
title_string= ' Transfer Phase  fc='+str(fc)+' Hz'
plt.title(title_string)
plt.xlabel('Frequency (Hz) ')
plt.ylabel('Phase(deg)')
plt.grid(True, which="both")
plt.xscale('linear')
plt.draw()

#*******************************************************************************

plt.show()