#include "spi.h"
#include "uart1.h"

// Transfer back end: 1 = DMA-driven (DMA1 moves received bytes, DMA2 moves
// bytes to transmit), 0 = polled byte-at-a-time loops.
// Can be overridden from the build, e.g. -DUSE_DMA=0.
#ifndef USE_DMA
#define USE_DMA 1
#endif

// Chip-select output latch written by spi_set_cs().
#define SPI_CS_BIT LATAbits.LATA6

// Scratch bytes shared by the transfer routines:
//  [0] - holds the 0xFF dummy byte clocked out by spi_receive(), and is the
//        sink for discarded received bytes in spi_send()/spi_send_repeat().
//  [1] - holds the repeated byte for the DMA path of spi_send_repeat(); kept
//        apart from [0] so the received bytes written to [0] cannot
//        overwrite it during the transfer.
uint8_t spi1_buf[2];

void spi_init(void)
{
    // SPI

    // F_SCK = F_SPI / (2 * (SPI1BAUD + 1))
    // During initialization: F_SCK = 64MHz / (2 * (255 + 1)) = 125kHz
    SPI1BAUD = 255;

    // Bit 1 - MST: 1 SPI module operates as the bus host
    // Bit 0 - BMODE: 1
    SPI1CON0 = 0x03;

    // Bit 6 - CKE: 1 Output data changes on transition from Active to Idle clock state
    // Bit 5 - CKP: 0 Idle state for SCK is low level
    // Bit 2 - SSP: 1 SS is active-low
    // 0100 0100
    SPI1CON1 = 0x44;

    // Bit 1 - TXR: 1 TxFIFO data is required for a transfer
    // Bit 0 - RXR: 1 Data transfers are suspended when RxFIFO is full
    SPI1CON2 = 0x03;

    // SPI1TX, SPI1RX default priority is high
    SPI1CON0bits.EN = 1;

    // System arbitration priority. DMA doesn't work with the default lowest priority.
    // Higher than SPI2
    PRLOCK = 0x55; // This sequence
    PRLOCK = 0xAA; // is mandatory
    PRLOCKbits.PRLOCKED = 0; // for DMA operation
    DMA1PR = 0x02; // Change the priority only if needed
    DMA2PR = 0x02; // Change the priority only if needed
    PRLOCK = 0x55; // This sequence
    PRLOCK = 0xAA; // is mandatory
    PRLOCKbits.PRLOCKED = 1; // for DMA operation
}

/*
 * Full-duplex SPI1 transfer: clock out `length` bytes from tx_buf while
 * storing the `length` bytes clocked in into rx_buf.
 *
 * DMA mode (USE_DMA): DMA1 copies SPI1RXB into rx_buf and DMA2 copies tx_buf
 * into SPI1TXB, each triggered by the SPI1 RX/TX interrupt requests. The
 * function busy-waits until DMA1 has stored all `length` bytes, then disables
 * both channels and clears their trigger sources.
 * Polled mode: one byte at a time, waiting on SPI1RXIF for each reply.
 *
 * tx_buf: `length` bytes to send.
 * rx_buf: receives `length` bytes.
 * length: number of bytes; 0 returns without touching the bus.
 *
 * Chip select is not touched; the caller is responsible for it.
 * NOTE(review): confirm `length` fits the DMA size registers on this device.
 */
void spi_transfer(uint8_t* tx_buf, uint8_t* rx_buf, uint16_t length)
{
#if USE_DMA
    if(length == 0) return;

    // Wait until the previous transfer is done and the TX FIFO has room.
    while(SPI1CON2bits.BUSY || !PIR3bits.SPI1TXIF) {}

    // Completion is detected by polling this flag below; clear stale state.
    PIR2bits.DMA1DCNTIF = 0;

    // RX DMA (DMA1): SPI1RXB -> rx_buf. Armed before the TX channel so it is
    // ready when the first byte comes back.
    DMASELECT = 0;
    DMAnCON1 = 0x60; // DMODE = 1 (incremented), DSTP = 1
    DMAnSSA = (volatile unsigned short)&SPI1RXB;
    DMAnDSA = (volatile unsigned short)rx_buf;
    DMAnSSZ = 1;
    DMAnDSZ = length;
    DMAnSIRQ = 0x18; // SPI1RX (Serial Peripheral Interface)
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // TX DMA (DMA2): tx_buf -> SPI1TXB
    DMASELECT = 1;
    DMAnCON1 = 0x03; // SMODE = 1 (incremented), SSTP = 1
    DMAnSSA = (volatile unsigned short)tx_buf;
    DMAnDSA = (volatile unsigned short)&SPI1TXB;
    DMAnSSZ = length;
    DMAnDSZ = 1;
    DMAnSIRQ = 0x19; // SPI1TX
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // The transfer is complete once every received byte has been stored.
    while(!PIR2bits.DMA1DCNTIF) {}

    // Disable both channels and detach them from the SPI1 triggers so no
    // stale trigger source is left behind.
    DMASELECT = 0;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
    DMASELECT = 1;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
#else
    uint16_t i;
    for(i = 0; i < length; i++)
    {
        SPI1TXB = tx_buf[i];
        while(!PIR3bits.SPI1RXIF) {}
        rx_buf[i] = SPI1RXB;
    }
#endif
}

/*
 * Receive `length` bytes from SPI1 into rx_buf, clocking out 0xFF for each
 * byte received (the 0xFF is staged in spi1_buf[0]).
 *
 * DMA mode (USE_DMA): DMA1 copies SPI1RXB into rx_buf (destination address
 * incremented); DMA2 writes the single byte at spi1_buf[0] to SPI1TXB
 * `length` times (source address unchanged). The function busy-waits until
 * DMA1 has stored all `length` bytes, then disables both channels and clears
 * their trigger sources. A length of 0 returns immediately.
 * Polled mode: one byte at a time, waiting on SPI1RXIF for each reply.
 *
 * Chip select is not touched; the caller is responsible for it.
 * NOTE(review): confirm `length` fits the DMA size registers on this device.
 */
void spi_receive(uint8_t* rx_buf, uint16_t length)
{
#if USE_DMA
    if(length == 0) return;

    // Wait for any previous transfer on the bus to finish.
    while(SPI1CON2bits.BUSY) {}

    // Completion is detected by polling this flag below; clear stale state.
    PIR2bits.DMA1DCNTIF = 0;

    // RX DMA (DMA1): SPI1RXB -> rx_buf, armed before the TX channel
    DMASELECT = 0;
    DMAnCON1 = 0x60; // DMODE = 1 (incremented), DSTP = 1
    DMAnSSA = (volatile unsigned short)&SPI1RXB;
    DMAnDSA = (volatile unsigned short)rx_buf;
    DMAnSSZ = 1;
    DMAnDSZ = length;
    DMAnSIRQ = 0x18; // SPI1RX (Serial Peripheral Interface)
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // TX DMA (DMA2): the same 0xFF byte -> SPI1TXB, `length` times
    DMASELECT = 1;
    DMAnCON1 = 0x01; // SMODE = 0 (address unchanged), SSTP = 1
    spi1_buf[0] = 0xff;
    DMAnSSA = (volatile unsigned short)&spi1_buf[0];
    DMAnDSA = (volatile unsigned short)&SPI1TXB;
    DMAnSSZ = length;
    DMAnDSZ = 1;
    DMAnSIRQ = 0x19; // SPI1TX
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // Done once every received byte has been stored.
    while(!PIR2bits.DMA1DCNTIF) {}

    // Disable both channels and detach them from the SPI1 triggers.
    DMASELECT = 0;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
    DMASELECT = 1;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
#else
    uint16_t i;
    spi1_buf[0] = 0xff;
    for(i = 0; i < length; i++)
    {
        SPI1TXB = spi1_buf[0];
        while(!PIR3bits.SPI1RXIF) {}
        rx_buf[i] = SPI1RXB;
    }
#endif
}

/*
 * Send `length` bytes from tx_buf on SPI1, discarding the bytes clocked in.
 * Each received byte is written to spi1_buf[0] and overwritten by the next.
 *
 * DMA mode (USE_DMA): DMA2 copies tx_buf into SPI1TXB (source address
 * incremented); DMA1 drains SPI1RXB into spi1_buf[0] (destination address
 * unchanged). The function busy-waits until DMA1 has drained all `length`
 * bytes, so the last byte has been fully clocked when it returns. Both
 * channels are then disabled and their trigger sources cleared. A length of
 * 0 returns immediately.
 * Polled mode: one byte at a time, waiting on SPI1RXIF for each reply.
 *
 * Chip select is not touched; the caller is responsible for it.
 * NOTE(review): confirm `length` fits the DMA size registers on this device.
 */
void spi_send(uint8_t* tx_buf, uint16_t length)
{
#if USE_DMA
    if(length == 0) return;

    // Wait for any previous transfer on the bus to finish.
    while(SPI1CON2bits.BUSY) {}

    // Completion is detected by polling this flag below; clear stale state.
    PIR2bits.DMA1DCNTIF = 0;

    // RX DMA (DMA1): SPI1RXB -> spi1_buf[0] (discarded), armed before TX
    DMASELECT = 0;
    DMAnCON1 = 0x20; // DMODE = 0 (address unchanged), DSTP = 1
    DMAnSSA = (volatile unsigned short)&SPI1RXB;
    DMAnDSA = (volatile unsigned short)&spi1_buf[0];
    DMAnSSZ = 1;
    DMAnDSZ = length;
    DMAnSIRQ = 0x18; // SPI1RX (Serial Peripheral Interface)
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // TX DMA (DMA2): tx_buf -> SPI1TXB
    DMASELECT = 1;
    DMAnCON1 = 0x03; // SMODE = 1 (incremented), SSTP = 1
    DMAnSSA = (volatile unsigned short)tx_buf;
    DMAnDSA = (volatile unsigned short)&SPI1TXB;
    DMAnSSZ = length;
    DMAnDSZ = 1;
    DMAnSIRQ = 0x19; // SPI1TX
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // Done once every received byte has been drained.
    while(!PIR2bits.DMA1DCNTIF) {}

    // Disable both channels and detach them from the SPI1 triggers.
    DMASELECT = 0;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
    DMASELECT = 1;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
#else
    uint16_t i;
    for(i = 0; i < length; i++)
    {
        SPI1TXB = tx_buf[i];
        while(!PIR3bits.SPI1RXIF) {}
        spi1_buf[0] = SPI1RXB;
    }
#endif
}

/*
 * Send `byte` on SPI1 `count` times, discarding the bytes clocked in.
 *
 * DMA mode (USE_DMA): `byte` is staged in spi1_buf[1] and DMA2 writes it to
 * SPI1TXB `count` times (source address unchanged). DMA1 drains SPI1RXB into
 * spi1_buf[0] (destination address unchanged), a different slot, so received
 * data cannot overwrite the byte being sent. The function busy-waits until
 * DMA1 has drained all `count` bytes, then disables both channels and clears
 * their trigger sources. A count of 0 returns immediately.
 * Polled mode: one byte at a time, waiting on SPI1RXIF for each reply.
 *
 * Chip select is not touched; the caller is responsible for it.
 * NOTE(review): confirm `count` fits the DMA size registers on this device.
 */
void spi_send_repeat(uint8_t byte, uint16_t count)
{
#if USE_DMA
    if(count == 0) return;

    // Wait for any previous transfer on the bus to finish.
    while(SPI1CON2bits.BUSY) {}

    // Completion is detected by polling this flag below; clear stale state.
    PIR2bits.DMA1DCNTIF = 0;

    // RX DMA (DMA1): SPI1RXB -> spi1_buf[0] (discarded), armed before TX
    DMASELECT = 0;
    DMAnCON1 = 0x20; // DMODE = 0 (address unchanged), DSTP = 1
    DMAnSSA = (volatile unsigned short)&SPI1RXB;
    DMAnDSA = (volatile unsigned short)&spi1_buf[0];
    DMAnSSZ = 1;
    DMAnDSZ = count;
    DMAnSIRQ = 0x18; // SPI1RX (Serial Peripheral Interface)
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // TX DMA (DMA2): spi1_buf[1] -> SPI1TXB, `count` times
    DMASELECT = 1;
    DMAnCON1 = 0x01; // SMODE = 0 (address unchanged), SSTP = 1
    spi1_buf[1] = byte;
    DMAnSSA = (volatile unsigned short)&spi1_buf[1];
    DMAnDSA = (volatile unsigned short)&SPI1TXB;
    DMAnSSZ = count;
    DMAnDSZ = 1;
    DMAnSIRQ = 0x19; // SPI1TX
    DMAnCON0 = 0xc0; // EN = 1, SIRQEN = 1

    // Done once every received byte has been drained.
    while(!PIR2bits.DMA1DCNTIF) {}

    // Disable both channels and detach them from the SPI1 triggers.
    DMASELECT = 0;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
    DMASELECT = 1;
    DMAnCON0 = 0;
    DMAnSIRQ = 0;
#else
    uint16_t i;
    for(i = 0; i < count; i++)
    {
        SPI1TXB = byte;
        while(!PIR3bits.SPI1RXIF) {}
        spi1_buf[0] = SPI1RXB;
    }
#endif
}

// baud_rate: kHz
void spi_switch_baud_rate(const uint8_t high_speed)
{
    if(high_speed)
    {
        // F_SCK = F_SPI / (2 * (SPI1BAUD + 1))
        // After initialization: F_SCK = 64MHz / (2 * (1 + 1)) = 16MHz
        SPI1BAUD = 3;
    }
    else
    {
        // F_SCK = F_SPI / (2 * (SPI1BAUD + 1))
        // During initialization: F_SCK = 64MHz / (2 * (255 + 1)) = 125kHz
        SPI1BAUD = 255;
    }
}

/*
 * Exchange a single byte on SPI1 using polled I/O.
 *
 * Waits for the bus to go idle, loads `byte` into the TX buffer, spins until
 * a byte is available in the RX FIFO, and returns that byte.
 * Chip select is not touched; the caller is responsible for it.
 */
uint8_t spi_transfer_byte(const uint8_t byte)
{
    uint8_t received;

    do {
        // Spin while a previous transfer is still in progress.
    } while(SPI1CON2bits.BUSY);

    SPI1TXB = byte;

    do {
        // Full duplex: the reply is clocked in while `byte` is shifted out.
    } while(PIR3bits.SPI1RXIF == 0);

    received = SPI1RXB;
    return received;
}

/*
 * Drive the chip-select output latch (SPI_CS_BIT) to the requested level.
 *
 * cs: 0 drives the pin low; any nonzero value drives it high.
 *
 * The value is normalized to 0/1 before the write because SPI_CS_BIT is a
 * one-bit latch field. Assigning e.g. 2 directly would keep only bit 0 and
 * drive the pin low.
 * NOTE(review): which level selects the device depends on the attached
 * hardware (presumably active-low); confirm at the call sites.
 */
void spi_set_cs(const uint8_t cs)
{
    SPI_CS_BIT = (cs != 0u) ? 1u : 0u;
}