#include "macros.h"
#include <string.h>
/* Like strsep(), but skips empty tokens.
 *
 * Returns the next non-empty token of *state (a maximal run of characters not
 * in delims), or NULL once the input is exhausted, after which *state is NULL.
 * The token is NUL-terminated in place and *state advances past it. Every
 * delimiter consumed, including leading ones, is overwritten with '\0'.
 * This matches a loop of strsep() calls that discards empty results.
 *
 * Written with ISO C strspn/strcspn because strsep() is a BSD/POSIX extension.
 * It is unavailable with MSVC and not declared by glibc under strict -std=c11. */
char *strsep2(char** state, const char* delims) {
  char* token = *state;
  if (token == NULL)
    return NULL;

  /* Leading delimiters would have produced empty tokens; consume and zero them. */
  size_t skip = strspn(token, delims);
  memset(token, '\0', skip);
  token += skip;

  if (*token == '\0') {
    *state = NULL;
    return NULL;
  }

  size_t len = strcspn(token, delims);
  if (token[len] == '\0') {
    *state = NULL;              /* last token: no delimiter follows */
  } else {
    token[len] = '\0';
    *state = token + len + 1;
  }
  return token;
}

#include "organ.h"
#include "plucksynth.h"
#include "reverb.h"
#include "reverb2.h"
#include "reverb3.h"
#include "chorus.h"
#include "formanter.h"
#include "simplesynth.h"
#include "resobass.h"

#include "wavwriter.h"
#include <portaudio.h>
#include <portmidi.h>
#include <stdio.h>
#include <math.h>
#include "interact.h"
#include "tunings.h"
#include <stdlib.h>
#include <ctype.h>

/* Sample rate in Hz, used for the WAV writers, synth init() and the PortAudio
 * stream opened in main(). */
#define SAMPLERATE 48000

/* State shared between main(), the cmd_* handlers and the audio callback. */
struct state {
  // subcomponents

  /* Processing chain, run in index order by pa_callback. The list ends at the
   * first entry whose synthdesc is NULL; main() writes that terminator. */
  struct {
    /* Descriptor providing size, init, process and finalize, plus optional
     * hooks (keydown, keyup, mod, vol, pitchbend, retune) tested for NULL
     * before use. */
    struct synthdesc* synthdesc;
    /* Per-synth state block, malloc'd in main() with synthdesc->size bytes. */
    void* synthstate;
  } chain[4];

  // state
  PmStream *pmstream;   // PortMidi input stream, polled for events inside pa_callback
  double lpfstate;      // set to 0 in main(); not otherwise used in the code shown here — NOTE(review): possibly dead
  int counter;          // set to 0 in main(); not otherwise used in the code shown here — NOTE(review): possibly dead
  FILE* output;         // from wavwriter_begin(); pa_callback appends rendered frames here when non-NULL
  FILE* mic_output;     // from wavwriter_begin(); pa_callback appends raw input frames here when non-NULL

  /* Scratch audio rows. pa_callback uses rows 0..3 as ping-pong stereo
   * input/output pairs for the chain. Each row must hold one callback's worth
   * of frames; main() requests 256. */
  float buffer[8][512];
};

void settuning(struct state* state, float freqtable[128]) {
  for(int i=0;state->chain[i].synthdesc;i++) {
    if(state->chain[i].synthdesc->retune) {
      state->chain[i].synthdesc->retune(state->chain[i].synthstate,freqtable);
    }
  }
}

/* Prototypes for the command handlers referenced by the commands[] table.
 * cmd_set and cmd_enter are only declared. Their table entries are commented
 * out and no definition for them appears in the code shown here. */
cmd_action cmd_set, cmd_enter, cmd_meantone, cmd_just, cmd_cents, cmd_freqs;

/* "freqs <file>": read up to 128 whitespace-separated frequencies from <file>
 * into a 128-entry table and retune the chain with it. The table is presumably
 * indexed by MIDI note number (verify against the synths' retune()).
 *
 * Entries the file does not supply stay 0, and a warning is printed.
 * Surrounding whitespace, including a trailing newline, is stripped from the
 * path before opening it.
 *
 * Returns 0 on success and on error, like the other cmd_* handlers. The
 * original success path fell off the end without a return value (UB). */
int cmd_freqs(void* actionparam, void* vstate, char* line) {
  CAST(struct state*,state,vstate);
  (void)actionparam;

  while (line != NULL && isspace((unsigned char)*line))
    line++;
  if (line == NULL || *line == '\0') {
    fprintf(stderr, "freqs: expected a file name\n");
    return 0;
  }
  char* end = line + strlen(line);
  while (end > line && isspace((unsigned char)end[-1]))
    *--end = '\0';

  FILE* f = fopen(line, "r");  /* "r": "rt" is not an ISO C mode string */
  if (!f) {
    fprintf(stderr,"Failed to open file: %s\n", line);
    return 0;
  }

  float freqtable[128] = {0};
  int count = 0;
  /* Stop at the first token that is not a number; later reads would fail at
   * the same spot anyway. */
  while (count < 128 && fscanf(f, "%f", &freqtable[count]) == 1)
    count++;
  fclose(f);

  if (count < 128)
    fprintf(stderr, "freqs: read only %d of 128 frequencies from %s; the rest are 0\n",
            count, line);

  settuning(state,freqtable);
  return 0;
}

/* "just r1 r2 ... r12": tune a 12-degree octave from ratios.
 *
 * Each token is parsed with parseratio() into n/d and fills the next of the
 * 12 octave degrees. Missing, unparsable or non-positive ratios leave the
 * degree at 1.0. The table is expanded across the keyboard by makeoctaves()
 * with an octave ratio of 2.0, then pushed to the chain with settuning().
 *
 * Non-positive ratios are rejected: d == 0 would divide by zero and a negative
 * ratio would give a negative frequency. Always returns 0. */
int cmd_just(void* actionparam, void* vstate, char* line) {
  CAST(struct state*,state,vstate);
  (void)actionparam;
  float freqtable[128];
  float octavetable[12];
  for (int i = 0; i < 12; i++)
    octavetable[i] = 1.0;

  /* Cast to unsigned char: isspace() on a negative char value is UB. */
  while (line != NULL && isspace((unsigned char)*line))
    line++;

  const char* note = NULL;
  for (int i = 0; i < 12 && (note = strsep2(&line, " \r\n\t")); i++) {
    int n, d;
    if (!parseratio(note, &n, &d))
      continue;                 /* unparsable: keep 1/1 for this degree */
    if (n <= 0 || d <= 0) {
      fprintf(stderr, "Ignoring non-positive ratio: %s\n", note);
      continue;
    }
    octavetable[i] = (float)n/(float)d;
  }
  makeoctaves(octavetable,2.0,freqtable);
  settuning(state,freqtable);
  return 0;
}
/* "cents c1 c2 ... c12": tune a 12-degree octave given in cents.
 *
 * Each token c becomes the ratio 2^(c/1200) for the next of the 12 octave
 * degrees. Missing or unparsable tokens leave the degree at 1.0. The table is
 * expanded with makeoctaves() (octave ratio 2.0) and pushed with settuning().
 *
 * sscanf("%lf") accepts "inf" and "nan", and extreme values overflow or
 * underflow pow(). Such ratios (non-finite or not > 0) are rejected so no
 * NaN, infinite or zero frequency reaches the synths. Always returns 0. */
int cmd_cents(void* actionparam, void* vstate, char* line) {
  CAST(struct state*,state,vstate);
  (void)actionparam;
  float freqtable[128];
  float octavetable[12];
  for (int i = 0; i < 12; i++)
    octavetable[i] = 1.0;

  /* Cast to unsigned char: isspace() on a negative char value is UB. */
  while (line != NULL && isspace((unsigned char)*line))
    line++;

  const char* note = NULL;
  for (int i = 0; i < 12 && (note = strsep2(&line, " \r\n\t")); i++) {
    double cents = 0;
    if (sscanf(note, "%lf", &cents) != 1)
      continue;                 /* unparsable: keep 1.0 for this degree */
    double ratio = pow(2.0, cents/1200.0);
    if (!isfinite(ratio) || !(ratio > 0.0)) {
      fprintf(stderr, "Ignoring out-of-range cents value: %s\n", note);
      continue;
    }
    octavetable[i] = (float)ratio;
  }
  makeoctaves(octavetable,2.0,freqtable);
  settuning(state,freqtable);
  return 0;
}
/* "meantone <octave-cents> <fifth-cents> [n1 ... n12]": tune with
 * makemeantone() from an octave size and a fifth size in cents.
 *
 * With no further tokens, fifthslayout_normal is used. Otherwise up to 12
 * tokens are parsed with parsenote() into a custom layout, and missing
 * entries are 0. The result is pushed to the chain with settuning().
 *
 * If either size is missing or unparsable, an error is printed and nothing
 * changes. Previously a missing argument made strsep2() return NULL, and that
 * NULL was handed to sscanf() (UB, typically a crash). Always returns 0. */
int cmd_meantone(void* actionparam, void* vstate, char* line) {
  CAST(struct state*,state,vstate);
  (void)actionparam;

  /* NULL-safe, and cast to unsigned char: isspace() on a negative char is UB. */
  while (line != NULL && isspace((unsigned char)*line))
    line++;
  char* octavecentsstr = strsep2(&line," \n\r\t");
  char* fifthcentsstr = strsep2(&line," \n\r\t");
  double octavecents=1200.0;
  double fifthcents=700.0;
  if (octavecentsstr == NULL || fifthcentsstr == NULL
      || sscanf(octavecentsstr,"%lf",&octavecents) != 1
      || sscanf(fifthcentsstr,"%lf",&fifthcents) != 1) {
    printf("error, expected: octave-cents fifth-cents\n");
    return 0;
  }

  float freqtable[128];
  while (line != NULL && isspace((unsigned char)*line))
    line++;
  if (line == NULL || *line == 0) {
    printf("Default keyboard layout\n");
    makemeantone(octavecents,fifthcents,fifthslayout_normal,freqtable);
  }
  else {
    printf("Custom keyboard layout\n");
    int layout[12];
    int i = 0;
    const char* note;
    while (i < 12 && (note = strsep2(&line," \r\n\t")))
      layout[i++] = parsenote(note);
    while (i < 12)
      layout[i++] = 0;          /* pad unspecified entries */
    makemeantone(octavecents,fifthcents,layout,freqtable);
  }

  settuning(state,freqtable);
  return 0;
}

/* Command table handed to runscript() and interact() in main().
 * Each entry holds the command word, its handler, the handler's actionparam
 * ("help" receives the table itself) and a one-line help text. The trailing
 * { NULL, NULL } entry presumably marks the end of the table for those
 * consumers. */
static struct command commands[] = {
  { "help", cmd_help, commands, "prints list of available commands" },
  { "exit", cmd_exit, NULL, "quits program" },
  { "just", cmd_just, NULL, "tunes into just tuning given as ratios" },
  { "freqs", cmd_freqs, NULL, "reads a list of 128 frequencies and tunes the keyboard to them" },
  { "meantone", cmd_meantone, NULL, "tunes into meantone or schismatic given octave cents and fifth cents" },
  { "cents", cmd_cents, NULL, "tunes to a 12-note octave scale specified in cents" },
  //  { "set", cmd_set, NULL, "sets a parameter to a value" },
  //  { "cd", cmd_enter, NULL, "enter synth settings" },
  { NULL, NULL }
};

static int pa_callback(const void *inputBuffer, void *outputBuffer,
		       unsigned long framesPerBuffer,
		       const PaStreamCallbackTimeInfo* timeInfo,
		       PaStreamCallbackFlags statusFlags,
		       void *userData )
{
  /* Cast data passed through stream to our structure. */
  struct state* state = (struct state*)userData;
  short const *in = (short const*)inputBuffer;
  if (state->mic_output) {
    fwrite(in,framesPerBuffer*2*2,1,state->mic_output);
  }  
  short *out = (short*)outputBuffer;

  while(Pm_Poll(state->pmstream) == TRUE) {
    PmEvent event;
    int channel = Pm_MessageStatus(event.message)&0x0F;
    Pm_Read(state->pmstream,&event,1);
    switch(Pm_MessageStatus(event.message)&0xF0) {
    case 0x80: // Note off
      for(int i=0;state->chain[i].synthdesc;i++) {
        if(state->chain[i].synthdesc->keyup) {
          state->chain[i].synthdesc->keyup(state->chain[i].synthstate,Pm_MessageData1(event.message)); break;
        }
      }
      break;
    case 0x90: // Note on
      if(Pm_MessageData2(event.message) != 0) {
        for(int i=0;state->chain[i].synthdesc;i++) {
          if(state->chain[i].synthdesc->keydown) {
            state->chain[i].synthdesc->keydown(state->chain[i].synthstate,Pm_MessageData1(event.message),Pm_MessageData2(event.message)/127.0); break;
          }
        }
      }
      else {
        for(int i=0;state->chain[i].synthdesc;i++) {
          if(state->chain[i].synthdesc->keyup) {
            state->chain[i].synthdesc->keyup(state->chain[i].synthstate,Pm_MessageData1(event.message)); break;
          }
        }
      }
      break;
    case 0xB0: // CC
      switch(Pm_MessageData1(event.message)) {
        
      case 1: // mod
        for(int i=0;state->chain[i].synthdesc;i++) {
          if(state->chain[i].synthdesc->mod) {
            state->chain[i].synthdesc->mod(state->chain[i].synthstate,Pm_MessageData2(event.message)/127.0); break;
          }
        }
        break;
      case 7: // vol
        for(int i=0;state->chain[i].synthdesc;i++) {
          if(state->chain[i].synthdesc->vol) {
            state->chain[i].synthdesc->vol(state->chain[i].synthstate,Pm_MessageData2(event.message)/127.0); break;
          }
        }
        break;
      default:
        printf("CC %i\n",(int)Pm_MessageData1(event.message));
        break;
      }
      break;
    case 0xE0: // pitch bend
      {
        int bend14 = (int)Pm_MessageData1(event.message)
          +128*(int)Pm_MessageData2(event.message);
        double bend = bend14*(2.0f/16384.0f)-1.0;
        double cents = bend*200;
        
        for(int i=0;state->chain[i].synthdesc;i++) {
          if(state->chain[i].synthdesc->pitchbend) {
            state->chain[i].synthdesc->pitchbend(state->chain[i].synthstate,cents); break;
          }
        }
        break;
      }
      /*
        switch(channel) {
        case 0:
        plucksynth_pitchbend(&state->plucksynth,bend);
        break;
        }
      */
    };
  }

  float* restrict ins[2]={state->buffer[0],state->buffer[1]};
  float* restrict outs[2]={state->buffer[2],state->buffer[3]};

  for(int i=0;state->chain[i].synthdesc;i++) {
    state->chain[i].synthdesc->process(state->chain[i].synthstate,framesPerBuffer,ins,outs);
    float* restrict tmp;
    tmp=ins[0];ins[0]=outs[0];outs[0]=tmp;
    tmp=ins[1];ins[1]=outs[1];outs[1]=tmp;
  }
  
  for(int i=0; i<framesPerBuffer; i++ ) {
    float l = ins[0][i];
    float r = ins[1][i];
    
    if(l > 1.0) l = 1.0;
    if(l < -1.0) l = -1.0;
    if(r > 1.0) r = 1.0;
    if(r < -1.0) r = -1.0;
    out[2*i+0] = (short)(l*32767-rand()*(1.0/RAND_MAX));
    out[2*i+1] = (short)(r*32767-rand()*(1.0/RAND_MAX));
  }

  if (state->output) {
    fwrite(out,framesPerBuffer*2*2,1,state->output);
  }
  return 0;
}


/* Write PortAudio's last host-API error to stderr, prefixed with `context`
 * (the name of the call that failed).
 * NOTE(review): this reports only the host error record, not the PaError the
 * failing call returned. For non-host errors the code and text may be empty;
 * consider also printing Pa_GetErrorText(err). */
void print_pa_error(const char* context) {
  const PaHostErrorInfo *host = Pa_GetLastHostErrorInfo();
  const long code = host->errorCode;
  const char *text = host->errorText;
  fprintf(stderr, "%s: Portaudio error %li: %s\n", context, code, text);
}
/* Write PortMidi's last host error text to stderr, followed by a newline.
 * NOTE(review): the PmError returned by the failing call is not available
 * here. Non-host errors (e.g. a bad device id) may print an empty line;
 * consider also printing Pm_GetErrorText(err). */
void print_pm_error() {
  char text[2000];
  Pm_GetHostErrorText(text, sizeof text);
  fputs(text, stderr);
  fputc('\n', stderr);
}

int main() {
  struct state state;
  PmError pmerr;
  PaError paerr;
  PaStream *pastream;
  float freqtable[128];

  float octavecents = 1200;
  float fifthcents = 696.58; // 1/4 comma meantone
  makemeantone(octavecents, fifthcents, fifthslayout_normal, freqtable);

  paerr = Pa_Initialize();
  if(paerr != paNoError) { print_pa_error("Pa_Initialize"); return 1; }

  pmerr = Pm_Initialize();
  if(pmerr != pmNoError) { print_pm_error(); return 1; }

  state.counter = 0;
  state.lpfstate = 0;
  state.output = wavwriter_begin("out.wav",SAMPLERATE);
  state.mic_output = wavwriter_begin("out_mic.wav",SAMPLERATE);
  int p = 0;
  //state.chain[p++].synthdesc = &resobassdesc;
  state.chain[p++].synthdesc = &organdesc;
  //state.chain[p++].synthdesc = &plucksynthdesc;
  //state.chain[p++].synthdesc = &simplesynthdesc;
  state.chain[p++].synthdesc = &reverb3desc;
  //state.chain[p++].synthdesc = &chorusdesc;
  state.chain[p++].synthdesc = NULL;
  for(int i=0;i<state.chain[i].synthdesc;i++) {
    state.chain[i].synthstate = malloc(state.chain[i].synthdesc->size);
    state.chain[i].synthdesc->init(state.chain[i].synthstate, SAMPLERATE, freqtable);
  }

  const char* filename = "startup.txt";
  FILE* f = fopen(filename,"rt");
  if (!f) {
    printf("Unable to run script file: %s\n", filename);
  }
  else {
    runscript(&state,commands,f);
    fclose(f);
    f=NULL;
  }

  // list midi devices

  {
    int device_count = Pm_CountDevices();
    for(int i=0;i<device_count;i++) {
      PmDeviceInfo const* info = Pm_GetDeviceInfo(i);
      if(info != NULL) {
        printf("Portmidi device %i:\n",i);
        printf("    Interface: %s\n",info->interf);
        printf("    Name: %s\n",info->name);
        printf("    In/out/opened: %i/%i/%i\n", info->input, info->output, info->opened);
      }
    }
    int default_midi = Pm_GetDefaultInputDeviceID();
    printf("Default midi: %i\n",default_midi);
  }

  // open midi device

  pmerr = Pm_OpenInput(&state.pmstream,
		       5,
		       NULL,
		       1000,
		       0, /* latency (ms) */
		       NULL);
  if(pmerr != pmNoError) { print_pm_error(); return 1; }

  paerr = Pa_StartStream(pastream);

  // list portaudio apis

  {
    int api_count = 0;
    const PaHostApiInfo* api;
    while(api = Pa_GetHostApiInfo(api_count++)) {
      printf("Host API: %s\n",api->name);
      printf("  type: %i\n", api->type);
      printf("  device count: %i\n",api->deviceCount);
      printf("  default input device: %i\n",api->defaultInputDevice);
      printf("  default output device: %i\n",api->defaultOutputDevice);
    }
  }

  // list audio devices

  {
    int device_count = Pa_GetDeviceCount();
    for(int i=0;i<device_count;i++) {
      PaDeviceInfo const* info = Pa_GetDeviceInfo(i);
      printf("Portaudio device %i:\n",i);
      printf("    Name: %s\n",info->name);
      PaHostApiInfo* api = Pa_GetHostApiInfo(info->hostApi);
      printf("    Host API: %s\n",api->name);
      printf("    In/Out/Samplerate: %i/%i/%lf\n",info->maxInputChannels,info->maxOutputChannels,info->defaultSampleRate);
    }
  }

  // open audio device

  PaStreamParameters inputParameters={
    .device=8,
    .channelCount=2,
    .sampleFormat=paInt16,
    .suggestedLatency=0.030,
    .hostApiSpecificStreamInfo=NULL,
  };
  PaStreamParameters outputParameters={
    .device=8,
    .channelCount=2,
    .sampleFormat=paInt16,
    .suggestedLatency=0.030,
    .hostApiSpecificStreamInfo=NULL,
  };
  PaStreamFlags streamFlags = paNoFlag;

  /* Open an audio I/O stream. */
  paerr = Pa_OpenStream(&pastream,
                        &inputParameters,
                        &outputParameters,
                        SAMPLERATE,
                        256,        /* frames per buffer, i.e. the number
                                       of sample frames that PortAudio will
                                       request from the callback */
                        streamFlags,
                        pa_callback, /* this is your callback function */
                        &state ); /*This is a pointer that will be passed to
                                    your callback*/
  if(paerr != paNoError) {
    print_pa_error("Pa_OpenStream");
    return 1;
  }

  paerr = Pa_StartStream(pastream);
  if(paerr != paNoError) { print_pa_error("Pa_StartStream"); return 1; }

  interact(&state,commands,"main> ");

  paerr = Pa_StopStream(pastream);
  if(paerr != paNoError) { print_pa_error("Pa_StopStream"); return 1; }

  paerr = Pa_CloseStream(pastream);
  pastream = NULL;
  if(paerr != paNoError) { print_pa_error("Pa_CloseStream"); return 1; }

  wavwriter_end(state.output);
  state.output = NULL;
  wavwriter_end(state.mic_output);
  state.mic_output = NULL;

  pmerr = Pm_Close(state.pmstream);
  state.pmstream = NULL;
  if(pmerr != pmNoError) { print_pm_error(); return 1; }

  paerr = Pa_Terminate();
  if(paerr != paNoError) { print_pa_error("Pa_Terminate"); return 1; }
  pmerr = Pm_Terminate();
  if(pmerr != pmNoError) { print_pm_error(); return 1; }

  for(int i=0;state.chain[i].synthdesc;i++) {
    state.chain[i].synthdesc->finalize(state.chain[i].synthstate);
    state.chain[i].synthstate = NULL;
  }

  return 0;
}
