#include "filterbankset.h"

#include "../structures/date.h"

#include "../util/progress/progresslistener.h"

#include "../lua/telescopefile.h"

#include <aocommon/system.h>

#include <fstream>

namespace imagesets {

/**
 * Opens the filterbank file at @p location and parses its header.
 *
 * The header must start with the keyword 'HEADER_START' and is read keyword by
 * keyword until 'HEADER_END' is found. When the header does not provide a
 * (non-zero) 'nsamples' value, the number of samples is derived from the size
 * of the data that follows the header. Finally, the number of intervals the
 * set is split into is chosen from the image size and the total system
 * memory.
 *
 * @param location Path of the filterbank file.
 * @throws std::runtime_error when the file can't be opened, when the header
 * does not start with 'HEADER_START' or is not terminated by 'HEADER_END',
 * when nIFs is not 1 or 4, or when the sample count has to be derived from
 * the file size while nchans or nbits is zero.
 */
FilterBankSet::FilterBankSet(const std::string& location)
    : location_(location),
      time_of_sample_(0.0),
      start_time_(0.0),
      bank1_centre_frequency(0.0),
      bank_channel_bandwidth_(0.0),
      channel_count_(0),
      if_count_(0),
      bit_count_(0),
      sample_count_(0),
      n_beams_(0),
      i_beam_(0),
      machine_id_(0),
      // The telescope id is logged below even when the header lacks the
      // 'telescope_id' keyword, so it needs a defined value.
      telescope_id_(0),
      interval_count_(0),
      header_end_(0) {
  std::ifstream file(location_.c_str());
  if (!file.good())
    throw std::runtime_error(std::string("Error opening filterbank file ") +
                             location_);

  std::string keyword = ReadString(file);
  if (keyword != "HEADER_START")
    throw std::runtime_error(
        "Filterbank file does not start with 'HEADER_START' -- corrupt file?");

  // Keywords that are not listed below are skipped without consuming a value.
  while (file.good() && keyword != "HEADER_END") {
    keyword = ReadString(file);

    if (keyword == "tsamp")
      time_of_sample_ = ReadDouble(file);
    else if (keyword == "tstart")
      start_time_ = ReadDouble(file);
    else if (keyword == "fch1")
      bank1_centre_frequency = ReadDouble(file);
    else if (keyword == "foff")
      bank_channel_bandwidth_ = ReadDouble(file);
    else if (keyword == "nchans")
      channel_count_ = ReadInt(file);
    else if (keyword == "nifs")
      if_count_ = ReadInt(file);
    else if (keyword == "nbits")
      bit_count_ = ReadInt(file);
    else if (keyword == "nsamples")
      sample_count_ = ReadInt(file);
    else if (keyword == "machine_id")
      machine_id_ = ReadInt(file);
    else if (keyword == "telescope_id")
      telescope_id_ = ReadInt(file);
    else if (keyword == "nbeams")
      n_beams_ = ReadInt(file);
    else if (keyword == "ibeam")
      i_beam_ = ReadInt(file);
    else if (keyword == "src_raj" || keyword == "src_dej" ||
             keyword == "az_start" || keyword == "za_start" ||
             keyword == "refdm" || keyword == "period")
      ReadDouble(file);  // Value is read to advance the stream; it is unused.
    else if (keyword == "data_type" || keyword == "barycentric" ||
             keyword == "pulsarcentric")
      ReadInt(file);  // Value is read to advance the stream; it is unused.
  }
  // The loop also ends when the stream goes bad; in that case the header was
  // truncated and the position of the data can't be trusted.
  if (keyword != "HEADER_END")
    throw std::runtime_error(
        "Filterbank file header does not end with 'HEADER_END' -- corrupt "
        "file?");

  // Validate nIFs before it is used as a divisor below.
  if (if_count_ != 1 && if_count_ != 4)
    throw std::runtime_error("Unsupported value for nIFs: " +
                             std::to_string(if_count_));

  header_end_ = file.tellg();
  if (sample_count_ == 0) {
    // No sample count in the header: derive it from the size of the data.
    if (channel_count_ == 0 || bit_count_ == 0)
      throw std::runtime_error(
          "Filterbank header has no nsamples, and nchans or nbits is zero: "
          "can't determine the number of samples");
    file.seekg(0, std::ios::end);
    const std::streampos endPos = file.tellg();
    const size_t dataSize = endPos - header_end_;
    sample_count_ = (dataSize * 8) / channel_count_ / if_count_ / bit_count_;
  }
  Logger::Debug << "tsamp=" << time_of_sample_ << ", tstart=" << start_time_
                << ", fch1=" << bank1_centre_frequency
                << ", foff=" << bank_channel_bandwidth_ << '\n'
                << "nChans=" << channel_count_ << ", nIFs=" << if_count_
                << ", nBits=" << bit_count_ << ", nSamples=" << sample_count_
                << "\nmachine_ID=" << machine_id_
                << ", telescope_ID=" << telescope_id_ << '\n';

  start_time_ = Date::MJDToAipsMJD(start_time_);

  // Images are held as floats. Use enough intervals that one image takes at
  // most 1/16th of the total memory...
  const double sizeOfImage =
      double(channel_count_) * sample_count_ * if_count_ * sizeof(float);
  const double memSize = aocommon::system::TotalMemory();
  interval_count_ = std::ceil(sizeOfImage / (memSize / 16.0));
  // ...but keep at least 8 samples per interval...
  if (interval_count_ * 8 > sample_count_) interval_count_ = sample_count_ / 8;
  // ...and always at least one interval. This clamp has to come last: with
  // fewer than 8 samples the previous line yields zero, and the interval
  // count is used as a divisor when reading and writing.
  if (interval_count_ < 1) interval_count_ = 1;
  Logger::Debug << std::round(sizeOfImage * 1e-8) * 0.1
                << " GB/image required of total of "
                << std::round(memSize * 1e-8) * 0.1
                << " GB of mem, splitting in " << interval_count_
                << " intervals\n";
}

/**
 * Copy constructor: copies the file location and all values that were parsed
 * or derived from the header of @p source.
 *
 * The request and result queues (requests_ and baselines_) are not listed in
 * the initializer list and are therefore not copied from @p source.
 */
FilterBankSet::FilterBankSet(const FilterBankSet& source)
    : location_(source.location_),
      time_of_sample_(source.time_of_sample_),
      start_time_(source.start_time_),
      bank1_centre_frequency(source.bank1_centre_frequency),
      bank_channel_bandwidth_(source.bank_channel_bandwidth_),
      channel_count_(source.channel_count_),
      if_count_(source.if_count_),
      bit_count_(source.bit_count_),
      sample_count_(source.sample_count_),
      n_beams_(source.n_beams_),
      i_beam_(source.i_beam_),
      machine_id_(source.machine_id_),
      telescope_id_(source.telescope_id_),
      interval_count_(source.interval_count_),
      header_end_(source.header_end_) {}

void FilterBankSet::AddReadRequest(const ImageSetIndex& index) {
  requests_.push_back(new BaselineData(index));
}

std::unique_ptr<BaselineData> FilterBankSet::GetNextRequested() {
  std::unique_ptr<BaselineData> result = std::move(baselines_.front());
  baselines_.pop();
  return result;
}

void FilterBankSet::AddWriteFlagsTask(const ImageSetIndex& index,
                                      std::vector<Mask2DCPtr>& flags) {
  if (bit_count_ != 32 && bit_count_ != 8)
    throw std::runtime_error("This filterbank set uses " +
                             std::to_string(bit_count_) +
                             " bits. "
                             "Only support for 8 or 32-bit filterbank sets has "
                             "been added as of yet.");

  const size_t intervalIndex = index.Value();

  const size_t startIndex = (sample_count_ * intervalIndex) / interval_count_;
  const size_t endIndex =
      (sample_count_ * (intervalIndex + 1)) / interval_count_;

  std::fstream file(location_.c_str());
  file.seekg(header_end_ + std::streampos(startIndex * (bit_count_ / 8) *
                                          channel_count_ * if_count_));

  if (bit_count_ == 32) {
    std::vector<float> buffer(channel_count_ * if_count_);
    for (size_t x = 0; x != endIndex - startIndex; ++x) {
      const std::streampos pos = file.tellg();
      file.read(reinterpret_cast<char*>(buffer.data()),
                channel_count_ * if_count_ * sizeof(float));
      float* buffer_ptr = buffer.data();
      for (size_t p = 0; p != if_count_; ++p) {
        for (size_t y = 0; y != channel_count_; ++y) {
          if (flags[0]->Value(x, y))
            *buffer_ptr = std::numeric_limits<float>::quiet_NaN();
          ++buffer_ptr;
        }
      }
      file.seekp(pos);
      file.write(reinterpret_cast<char*>(buffer.data()),
                 channel_count_ * if_count_ * sizeof(float));
    }
  } else {  // 8 bits
    std::vector<unsigned char> buffer(channel_count_ * if_count_);
    for (size_t x = 0; x != endIndex - startIndex; ++x) {
      const std::streampos pos = file.tellg();
      file.read(reinterpret_cast<char*>(buffer.data()),
                channel_count_ * if_count_);
      unsigned char* buffer_ptr = buffer.data();
      for (size_t p = 0; p != if_count_; ++p) {
        for (size_t y = 0; y != channel_count_; ++y) {
          if (flags[0]->Value(x, y)) *buffer_ptr = -127;
          ++buffer_ptr;
        }
      }
      file.seekp(pos);
      file.write(reinterpret_cast<char*>(buffer.data()),
                 channel_count_ * if_count_);
    }
  }
}

/**
 * Returns the name that TelescopeFile associates with its generic telescope
 * entry; no information from the filterbank file is used.
 */
std::string FilterBankSet::TelescopeName() {
  const auto genericTelescope = TelescopeFile::GENERIC_TELESCOPE;
  return TelescopeFile::TelescopeName(genericTelescope);
}

/** No-op: this function performs no work for a filterbank set. */
void FilterBankSet::Initialize() {}

/**
 * Processes all queued read requests in order.
 *
 * For every request, the samples of the requested interval are read from the
 * file into one image and one mask per IF, meta data (antenna, band and
 * observation times) is constructed, and the completed request is moved to
 * the baselines_ queue.
 *
 * For 32-bit data a sample is masked when it is not finite; for 8-bit data a
 * sample is masked when its value is 255.
 *
 * @param progress Receives progress updates per time step and a finish
 * notification once all requests have been processed.
 * @throws std::runtime_error when the bit count is not 8 or 32, or when the
 * file can't be opened.
 */
void FilterBankSet::PerformReadRequests(ProgressListener& progress) {
  if (bit_count_ != 32 && bit_count_ != 8)
    throw std::runtime_error("This filterbank set uses " +
                             std::to_string(bit_count_) +
                             " bits. "
                             "Only support for 8 or 32-bit filterbank sets has "
                             "been added as of yet.");
  const size_t n_requests = requests_.size();
  while (!requests_.empty()) {
    std::unique_ptr<BaselineData> baseline(std::move(requests_.front()));
    requests_.pop_front();
    const size_t intervalIndex = baseline->Index().Value();

    // Interval i covers the samples [i*N/k, (i+1)*N/k), with N the sample
    // count and k the interval count. These are the same bounds as used by
    // AddWriteFlagsTask().
    const size_t startIndex = (sample_count_ * intervalIndex) / interval_count_;
    const size_t endIndex =
        (sample_count_ * (intervalIndex + 1)) / interval_count_;
    const size_t width = endIndex - startIndex;

    std::ifstream file(location_.c_str());
    if (!file.good())
      throw std::runtime_error(std::string("Error opening filterbank file ") +
                               location_);
    file.seekg(header_end_ + std::streampos(startIndex * (bit_count_ / 8) *
                                            channel_count_ * if_count_));

    std::vector<Image2DPtr> images(if_count_);
    for (Image2DPtr& image : images)
      image = Image2D::CreateUnsetImagePtr(width, channel_count_);
    std::vector<Mask2DPtr> masks(if_count_);
    for (Mask2DPtr& mask : masks)
      mask = Mask2D::CreateUnsetMaskPtr(width, channel_count_);

    // Every time step is stored as if_count_ blocks of channel_count_ samples.
    if (bit_count_ == 32) {
      std::vector<float> buffer(channel_count_ * if_count_);
      for (size_t x = 0; x != width; ++x) {
        progress.OnProgress(x + width * (n_requests - requests_.size() - 1),
                            width * n_requests);
        file.read(reinterpret_cast<char*>(buffer.data()),
                  channel_count_ * if_count_ * sizeof(float));
        const float* buffer_ptr = buffer.data();
        for (size_t p = 0; p != if_count_; ++p) {
          for (size_t y = 0; y != channel_count_; ++y) {
            images[p]->SetValue(x, y, *buffer_ptr);
            masks[p]->SetValue(x, y, !std::isfinite(*buffer_ptr));
            ++buffer_ptr;
          }
        }
      }
    } else {  // 8 bits
      std::vector<unsigned char> buffer(channel_count_ * if_count_);
      for (size_t x = 0; x != width; ++x) {
        progress.OnProgress(x + width * (n_requests - requests_.size() - 1),
                            width * n_requests);
        file.read(reinterpret_cast<char*>(buffer.data()),
                  channel_count_ * if_count_);
        const unsigned char* buffer_ptr = buffer.data();
        for (size_t p = 0; p != if_count_; ++p) {
          for (size_t y = 0; y != channel_count_; ++y) {
            images[p]->SetValue(x, y, static_cast<float>(*buffer_ptr));
            // NOTE(review): AddWriteFlagsTask() stores flagged 8-bit samples
            // as -127 converted to unsigned char (129), not 255 -- confirm
            // which value is intended.
            masks[p]->SetValue(x, y, *buffer_ptr == 255);
            ++buffer_ptr;
          }
        }
      }
    }

    TimeFrequencyData tfData;
    if (if_count_ == 1) {
      tfData = TimeFrequencyData(TimeFrequencyData::AmplitudePart,
                                 aocommon::Polarization::StokesI, images[0]);
      tfData.SetGlobalMask(masks[0]);
    } else {
      tfData = TimeFrequencyData::FromStokes(images[0], images[1], images[2],
                                             images[3]);
      // A single call receives the whole array of masks.
      tfData.SetIndividualPolarizationMasks(masks.data());
    }

    const TimeFrequencyMetaDataPtr metaData(new TimeFrequencyMetaData());
    // The same placeholder antenna is used for both antennas.
    AntennaInfo antenna;
    antenna.diameter = 0;
    antenna.id = 0;
    antenna.mount = "unknown";
    antenna.name = "unknown";
    antenna.position = EarthPosition();
    antenna.station = "unknown";
    metaData->SetAntenna1(antenna);
    metaData->SetAntenna2(antenna);

    BandInfo band;
    band.windowIndex = 0;
    for (size_t ch = 0; ch != channel_count_; ++ch) {
      // The header values are scaled by 1e6 to obtain the values in Hz. The
      // channel spacing keeps its sign for the frequency, but widths are
      // stored as absolute values.
      ChannelInfo channel;
      channel.frequencyHz =
          1e6 * (bank1_centre_frequency + bank_channel_bandwidth_ * ch);
      channel.effectiveBandWidthHz = 1e6 * std::fabs(bank_channel_bandwidth_);
      channel.frequencyIndex = ch;
      channel.channelWidthHz = 1e6 * std::fabs(bank_channel_bandwidth_);
      channel.resolutionHz = 1e6 * std::fabs(bank_channel_bandwidth_);
      band.channels.push_back(channel);
    }
    metaData->SetBand(band);

    std::vector<double> observationTimes(width);
    for (size_t t = startIndex; t != endIndex; ++t)
      observationTimes[t - startIndex] = (start_time_ + time_of_sample_ * t);
    metaData->SetObservationTimes(observationTimes);
    metaData->SetValueDescription("Power");

    baseline->SetData(tfData);
    baseline->SetMetaData(metaData);
    baselines_.emplace(std::move(baseline));
  }
  progress.OnFinish();
}

/**
 * Writing image data is not supported for filterbank sets: the arguments are
 * ignored and this function always throws.
 *
 * @throws std::runtime_error always.
 */
void FilterBankSet::PerformWriteDataTask(
    const ImageSetIndex& index, std::vector<Image2DCPtr> realImages,
    std::vector<Image2DCPtr> imaginaryImages) {
  static const char* const kNotImplemented =
      "Can't write data back to filter bank format: not implemented";
  throw std::runtime_error(kNotImplemented);
}

/**
 * Returns a human-readable label for the interval identified by @p index, of
 * the form "Filterbank set -- interval <number>/<interval count>", in which
 * the number is one-based.
 */
std::string FilterBankSet::Description(const ImageSetIndex& index) const {
  // The index value is zero-based; it is displayed one-based.
  const auto displayedInterval = index.Value() + 1;
  std::ostringstream label;
  label << "Filterbank set -- interval ";
  label << displayedInterval;
  label << '/';
  label << interval_count_;
  return label.str();
}

}  // namespace imagesets
