
Implementing a Cross-Platform High-Performance Operator Test Framework with CMake

Previously I used pybind11 as the bridge layer and compiled kernels into .so files, and ran into a problem: the entry functions in the .cu files had to accept arguments such as torch::Tensor or NumPy arrays, and the build was slow. For CUDA code, Python (PyTorch) provides the CUDAExtension interface to handle this conversion, but for CPU code, or for domestic accelerators such as the MLU, there is no obvious replacement. So here we use CMake to compile the kernels for the different platforms instead. The project layout is shown below:
(Figure: project directory layout)

The src folder holds the source files for each operator, including the kernels for each platform; the test folder holds the corresponding Python-side test scripts for each operator; CMakeLists.txt defines the compile options for the different platforms; and run.sh builds for the CPU platform by default.
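Because every entry point is a plain extern "C" function taking raw float pointers and integers (see attention_nv_f32 and softmax_cpu_f32 below), the Python test scripts can load the compiled shared library with ctypes instead of going through pybind11. The following is a minimal sketch of such a test; the library name build/liboperators.so and the file name test/test_softmax.py are placeholders, not necessarily what the project actually uses:

# test/test_softmax.py -- minimal sketch; library name and path are assumptions
import ctypes
import numpy as np

lib = ctypes.CDLL("./build/liboperators.so")
lib.softmax_cpu_f32.restype = None
lib.softmax_cpu_f32.argtypes = [
    ctypes.POINTER(ctypes.c_float),  # input
    ctypes.POINTER(ctypes.c_float),  # output
    ctypes.c_int,                    # size
    ctypes.c_int,                    # dimsize
    ctypes.c_int,                    # stride
]

x = np.random.randn(4, 128).astype(np.float32)
y = np.empty_like(x)
# softmax over the last axis: dimsize = x.shape[-1], stride = 1
lib.softmax_cpu_f32(
    x.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
    y.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
    x.size, x.shape[-1], 1,
)

ref = np.exp(x - x.max(-1, keepdims=True))
ref /= ref.sum(-1, keepdims=True)
assert np.allclose(y, ref, atol=1e-5)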

src

attention

attention.cu

#include <cuda_runtime.h>
#include <stdio.h> // needed for fprintf in gpuAssert
#include <math.h>
#define cudaCheckError(ans)                   \
    {                                         \
        gpuAssert((ans), __FILE__, __LINE__); \
    }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
        if (abort)
            exit(code);
    }
}

const int Rq = 8;
const int Rv = 8; // must be a multiple of 4
const int Br = 16;
const int Bc = 16;
const int Bk = 8; // must be a multiple of 4
const int Bd = 8;
const int numQ = Rq * Br;
const int numK = Bk * Bc;
const int numV = Rv * Bc;

__device__ void matmulRQK(const float *__restrict inputQ,
                          const float *__restrict inputK, float *shareQK,
                          float *shareVK, int N, int d, int width, int indQ,
                          int indK, float *val)
{
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    float4 a[1];
    float4 b[1];
    float com_a[Rq];
    float com_b[Bk];

    int smem_a_m = tid / 2;
    int smem_a_k = tid % 2;
    int smem_b_n = tid / 128;
    int smem_b_k = tid % 128;
    // float tmp[64];
    // memset(tmp, 0.0f, sizeof(tmp));
    int ph = 0;
    (float4 &)a[0] =
        (float4 &)inputQ[(indQ + smem_a_m) * d + ph * Bd + 4 * smem_a_k];

    (float4 &)b[0] =
        (float4 &)inputK[(indK + smem_b_k) * d + Bd * ph + 4 * smem_b_n];
    shareQK[(4 * smem_a_k) * numQ + smem_a_m] = a[0].x;
    shareQK[(4 * smem_a_k + 1) * numQ + smem_a_m] = a[0].y;
    shareQK[(4 * smem_a_k + 2) * numQ + smem_a_m] = a[0].z;
    shareQK[(4 * smem_a_k + 3) * numQ + smem_a_m] = a[0].w;

    shareVK[(4 * smem_b_n) * numK + smem_b_k] = b[0].x;
    shareVK[(4 * smem_b_n + 1) * numK + smem_b_k] = b[0].y;
    shareVK[(4 * smem_b_n + 2) * numK + smem_b_k] = b[0].z;
    shareVK[(4 * smem_b_n + 3) * numK + smem_b_k] = b[0].w;
    __syncthreads();
    for (int ph = 1; ph < width; ph++)
    {
        (float4 &)a[0] =
            (float4 &)inputQ[(indQ + smem_a_m) * d + ph * Bd + 4 * smem_a_k];

        (float4 &)b[0] =
            (float4 &)inputK[(indK + smem_b_k) * d + Bd * ph + 4 * smem_b_n];

        for (int index = 0; index < Bd; index++)
        {
            (float4 &)com_a[0] =
                (float4 &)shareQK[index * numQ + threadIdx.y * Rq +
                                  (ph - 1) % 2 * numQ * Bd];
            (float4 &)com_a[4] =
                (float4 &)shareQK[index * numQ + threadIdx.y * Rq + 4 +
                                  (ph - 1) % 2 * numQ * Bd];
            (float4 &)com_b[0] =
                (float4 &)shareVK[index * numK + threadIdx.x * Bk +
                                  (ph - 1) % 2 * numK * Bd];
            (float4 &)com_b[4] =
                (float4 &)shareVK[index * numK + threadIdx.x * Bk + 4 +
                                  (ph - 1) % 2 * numK * Bd];

            for (int index_q = 0; index_q < Rq; index_q++)
            {
                for (int index_k = 0; index_k < Bk; index_k++)
                {

                    val[index_q * Bk + index_k] += com_a[index_q] * com_b[index_k];
                }
            }
        }

        shareQK[(4 * smem_a_k) * numQ + smem_a_m + (ph % 2) * numQ * Bd] =
            a[0].x;
        shareQK[(4 * smem_a_k + 1) * numQ + smem_a_m + (ph % 2) * numQ * Bd] =
            a[0].y;
        shareQK[(4 * smem_a_k + 2) * numQ + smem_a_m + (ph % 2) * numQ * Bd] =
            a[0].z;
        shareQK[(4 * smem_a_k + 3) * numQ + smem_a_m + (ph % 2) * numQ * Bd] =
            a[0].w;

        shareVK[(4 * smem_b_n) * numK + smem_b_k + (ph % 2) * numK * Bd] =
            b[0].x;
        shareVK[(4 * smem_b_n + 1) * numK + smem_b_k + (ph % 2) * numK * Bd] =
            b[0].y;
        shareVK[(4 * smem_b_n + 2) * numK + smem_b_k + (ph % 2) * numK * Bd] =
            b[0].z;
        shareVK[(4 * smem_b_n + 3) * numK + smem_b_k + (ph % 2) * numK * Bd] =
            b[0].w;

        __syncthreads();
    }
    ph = width;
    for (int index = 0; index < Bd; index++)
    {
        (float4 &)com_a[0] =
            (float4 &)shareQK[index * numQ + threadIdx.y * Rq +
                              (ph - 1) % 2 * numQ * Bd];
        (float4 &)com_a[4] =
            (float4 &)shareQK[index * numQ + threadIdx.y * Rq + 4 +
                              (ph - 1) % 2 * numQ * Bd];
        (float4 &)com_b[0] =
            (float4 &)shareVK[index * numK + threadIdx.x * Bk +
                              (ph - 1) % 2 * numK * Bd];
        (float4 &)com_b[4] =
            (float4 &)shareVK[index * numK + threadIdx.x * Bk + 4 +
                              (ph - 1) % 2 * numK * Bd];

        for (int index_q = 0; index_q < Rq; index_q++)
        {
            for (int index_k = 0; index_k < Bk; index_k++)
            {

                val[index_q * Bk + index_k] += com_a[index_q] * com_b[index_k];
            }
        }
    }
}

__device__ void matmulSV(float *shareQK, const float *__restrict inputV,
                         float *shareVK, int N, int d, int j, int indQ,
                         int indK, int indV, float *val, float *newMax,
                         float *sumSV)
{
    for (int index_k = 0; index_k < Bk; index_k++)
    {
        for (int id = 0; id < Rv; id += 4)
        {
            (float4 &)shareVK[threadIdx.y * numV + threadIdx.x * Rv + id] =
                (float4 &)inputV[(indK + threadIdx.y * Bk + index_k) * d +
                                 indV + threadIdx.x * Rv + id];
        }
        for (int index_v = 0; index_v < Rv; index_v++)
        {
            if (indK + threadIdx.y * Bk + index_k >= N ||
                indV + threadIdx.x * Rv + index_v >= d)
            {
                shareVK[threadIdx.y * numV + threadIdx.x * Rv + index_v] = 0.0f;
            }
        }
        for (int index_q = 0; index_q < Rq; index_q++)
        {
            if (indQ + threadIdx.y * Rq + index_q < N &&
                indK + Bk * threadIdx.x + index_k < N)
            {
                shareQK[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] =
                    __expf(val[index_q * Bk + index_k] - newMax[index_q]);
            }
            else
            {

                shareQK[(threadIdx.y * Rq + index_q) * Bc + threadIdx.x] = 0.0f;
            }
        }
        __syncthreads();
        for (int phc = 0; phc < Bc; phc++)
        {
            for (int index_q = 0; index_q < Rq; index_q++)
            {
                for (int index_v = 0; index_v < Rv; index_v++)
                {
                    sumSV[index_q * Rv + index_v] +=
                        shareQK[(threadIdx.y * Rq + index_q) * Bc + phc] *
                        shareVK[phc * numV + threadIdx.x * Rv + index_v];
                }
            }
        }
        __syncthreads();
    }
}
template <typename T>
struct SumOp
{
    __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a + b;
    }
};

template <typename T>
struct MaxOp
{
    __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return max(a, b);
    }
};
template <template <typename> class ReductionOp, typename T,
          int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val)
{
    for (int mask = thread_group_width / 2; mask > 0; mask >>= 1)
    {
        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
    }

    return val;
}

template <int Br, int Bc, int Rq, int Rv>
__global__ void _attentionKernel(const float *__restrict inputQ,
                                 const float *__restrict inputK,
                                 const float *__restrict inputV, int N, int d,
                                 float *__restrict output)
{

    __shared__ float shareQK[numQ * Bc];
    __shared__ float shareVK[Bc * numV];
    __shared__ float block_max[numQ];
    __shared__ float block_sum[numQ];
    float sumSV[Rq * Rv] = {0.0f};
    float newMax[Rq];
    float oldMax[Rq];
    float newSum[Rq] = {0.0f};

    float val[Rq * Bk];

    int indV = Rv * blockIdx.x * blockDim.x;
    int indQ = Rq * blockIdx.y * blockDim.y;

    for (int index_q = 0; index_q < Rq; index_q++)
    {
        newMax[index_q] = -__FLT_MAX__;
        oldMax[index_q] = -__FLT_MAX__;
    }

    int Tc = (N + numK - 1) / (numK);

    int width = (d + Bd - 1) / Bd;
    for (int j = 0; j < Tc; j++)
    {

        int indK = j * numK;
        for (int index_q = 0; index_q < Rq; index_q++)
        {
            for (int index_k = 0; index_k < Bk; index_k++)
            {

                val[index_q * Bk + index_k] = 0.0f;
            }
        }
        matmulRQK(inputQ, inputK, shareQK, shareVK, N, d, width, indQ, indK,
                  val);
        for (int index_q = 0; index_q < Rq; index_q++)
        {
            float tmpReduceMax = -__FLT_MAX__;
            for (int index_k = 0; index_k < Bk; index_k++)
            {
                if (indQ + threadIdx.y * Rq + index_q < N &&
                    indK + Bk * threadIdx.x + index_k < N)
                {

                    tmpReduceMax =
                        max(tmpReduceMax, val[index_q * Bk + index_k]);
                }
            }
            __syncthreads();
            tmpReduceMax = WarpAllReduce<MaxOp, float, Bc>(tmpReduceMax);
            if (threadIdx.x == 0)
            {
                block_max[threadIdx.y * Rq + index_q] = tmpReduceMax;
            }
            __syncthreads();
            float tmpReduceSum = 0.0f;
            for (int index_k = 0; index_k < Bk; index_k++)
            {
                if (indQ + threadIdx.y * Rq + index_q < N &&
                    indK + Bk * threadIdx.x + index_k < N)
                {
                    tmpReduceSum +=
                        __expf(val[index_q * Bk + index_k] -
                               block_max[threadIdx.y * Rq + index_q]);
                }
            }
            __syncthreads();
            tmpReduceSum = WarpAllReduce<SumOp, float, Bc>(tmpReduceSum);
            if (threadIdx.x == 0)
            {
                block_sum[threadIdx.y * Rq + index_q] = tmpReduceSum;
            }
            __syncthreads();
            if (newMax[index_q] > block_max[threadIdx.y * Rq + index_q])
            {
                newSum[index_q] =
                    std::fma(block_sum[threadIdx.y * Rq + index_q],
                             __expf(block_max[threadIdx.y * Rq + index_q] -
                                    newMax[index_q]),
                             newSum[index_q]);
            }
            else
            {
                newSum[index_q] =
                    std::fma(newSum[index_q],
                             __expf(newMax[index_q] -
                                    block_max[threadIdx.y * Rq + index_q]),
                             block_sum[threadIdx.y * Rq + index_q]);

                newMax[index_q] = block_max[threadIdx.y * Rq + index_q];
            }
            // PV
            for (int index_v = 0; index_v < Rv; index_v++)
            {
                sumSV[index_q * Rv + index_v] *=
                    __expf(oldMax[index_q] - newMax[index_q]);
            }
        }

        matmulSV(shareQK, inputV, shareVK, N, d, j, indQ, indK, indV, val,
                 newMax, sumSV);

        for (int index_q = 0; index_q < Rq; index_q++)
        {
            oldMax[index_q] = newMax[index_q];
        }

        __syncthreads();
    }
    for (int index_q = 0; index_q < Rq; index_q++)
    {
        float inv = __fdividef(1.0F, newSum[index_q]);
        for (int index_v = 0; index_v < Rv; index_v++)
        {
            sumSV[index_q * Rv + index_v] = sumSV[index_q * Rv + index_v] * inv;
        }
    }
    for (int index_q = 0; index_q < Rq; index_q++)
    {

        for (int id = 0; id < Rv; id += 4)
        {
            if (indQ + threadIdx.y * Rq + index_q < N &&
                indV + threadIdx.x * Rv + id < d)
            {
                (float4 &)output[(indQ + threadIdx.y * Rq + index_q) * d +
                                 indV + threadIdx.x * Rv + id] =
                    (float4 &)sumSV[index_q * Rv + id];
            }
        }
    }
}

extern "C" void attention_nv_f32(float *inputQ, float *inputK, float *inputV, int N, int d, float *output)
{

    int num_block_x = (d + Rv * Bc - 1) / (Rv * Bc);
    int num_block_y = (N + Rq * Br - 1) / (Rq * Br);
    dim3 grid_dim(num_block_x, num_block_y, 1);
    dim3 block_dim(Bc, Br, 1);

    _attentionKernel<Br, Bc, Rq, Rv>
        <<<grid_dim, block_dim>>>(inputQ, inputK, inputV, N, d, output);
    cudaCheckError(cudaPeekAtLastError());
    cudaCheckError(cudaDeviceSynchronize());
}
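Reading the kernel above: attention_nv_f32 computes output = softmax(Q Kᵀ) V row by row over the N key positions, using an online (streaming) max/sum update; note that no 1/sqrt(d) scaling is applied inside the kernel, and that the pointers it receives must already be device pointers, since the kernel is launched on them directly. A NumPy reference that a test script could compare against might look like the sketch below (attention_ref is a name of my own, not part of the project):

import numpy as np

def attention_ref(Q, K, V):
    # Q, K, V: (N, d) float32 arrays; mirrors softmax(Q @ K.T) @ V, no 1/sqrt(d) scaling
    S = Q @ K.T                             # (N, N) attention scores
    S = S - S.max(axis=-1, keepdims=True)   # subtract the row max for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)   # row-wise softmax
    return P @ V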

softmax

softmax.cpp

#include <stdio.h>
#include <math.h>
extern "C" void softmax_cpu_f32(float *input, float *output, int size, int dimsize, int stride)
{
    int othersize = size / dimsize;
    for (int ind = 0; ind < othersize; ind++)
    {                                                            // ind = i(KS) + k(S) + s
        int tid = ind % stride + (ind - ind % stride) * dimsize; // now, tid = i(JKS) + k(S) + s;
        float localM = -__FLT_MAX__;
        for (int j = 0; j < dimsize; j++)
        {
            int index = tid + j * stride;
            localM = fmax(localM, input[index]);
        }
        float localS = 0.0f;
        for (int j = 0; j < dimsize; j++)
        {
            int index = tid + j * stride;
            localS += exp(input[index] - localM);
        }
        for (int j = 0; j < dimsize; j++)
        {
            int index = tid + j * stride;
            output[index] = exp(input[index] - localM) / localS;
        }
    }
}
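To make the index arithmetic in softmax_cpu_f32 concrete: suppose the input has shape [I, J, K, S] and softmax is taken along axis 1 (the J axis). Then size = I*J*K*S, dimsize = J, stride = K*S, and othersize = size / dimsize = I*K*S. The loop variable ind enumerates the (i, k, s) triples as ind = i*(K*S) + k*S + s, and

tid = ind % stride + (ind - ind % stride) * dimsize = i*(J*K*S) + k*S + s,

which is exactly the flattened offset of element (i, 0, k, s); the j-th element along the softmax axis then sits at tid + j*stride. The CUDA kernels below reuse the same mapping, with blockIdx.x playing the role of ind.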

softmax.cu

#include <cub/block/block_reduce.cuh>

struct __align__(8) DataMaxSum
{                  // update the global max and sum, store the
                   // output at max_tmp and sum_tmp
    float max_tmp; // store max
    float sum_tmp; // store sum
};
__device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a,
                                                    DataMaxSum b)
{
    bool a_bigger = (a.max_tmp > b.max_tmp);
    DataMaxSum bigger = a_bigger ? a : b;
    DataMaxSum smaller = a_bigger ? b : a;
    bigger.sum_tmp = bigger.sum_tmp +
                     smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp);

    return bigger;
}
template <typename T, int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel(
    T *__restrict input, T *__restrict output, int dimsize,
    int stride)
{ // if set axis = 1, inputShape=[I,J,K,S]
  // tid = i(JKS) + j(KS) + k(S) + s

    // gridDim.x = othersize = size/dimsize = IKS
    // blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s

    int tid =
        blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
                                  dimsize; // now, tid = i(JKS) + k(S) + s;

    DataMaxSum dms_partial;
    dms_partial.max_tmp = -__FLT_MAX__;
    dms_partial.sum_tmp = 0.0f;
    DataMaxSum dms_input;
    int remain = dimsize % BLOCK_DIM;
    int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread

    if (threadIdx.x < remain)
    {
        for (int ind = 0; ind < step; ind++)
        {
            dms_input.max_tmp =
                input[tid + (threadIdx.x * step + ind) * stride];

            dms_input.sum_tmp = 1.0f;
            dms_partial =
                reduce_dms_op(dms_partial,
                              dms_input); // reduce the data to one block
        }
    }
    else
    {
        for (int ind = 0; ind < step - 1; ind++)
        {
            dms_input.max_tmp =
                input[tid + (remain * step +
                             (threadIdx.x - remain) * (step - 1) + ind) *
                                stride];

            dms_input.sum_tmp = 1.0f;
            dms_partial =
                reduce_dms_op(dms_partial,
                              dms_input); // reduce the data to one block
        }
    }

    typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ DataMaxSum dms_total;
    DataMaxSum dms_block =
        BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
    if (threadIdx.x == 0)
    { // only thread 0 writes the block-wide result
        dms_total = dms_block;
    }
    __syncthreads();
    //-----------------
    if (threadIdx.x < remain)
    {
        for (int ind = 0; ind < step; ind++)
        {

            output[tid + (threadIdx.x * step + ind) * stride] =
                __expf(static_cast<float>(
                           input[tid + (threadIdx.x * step + ind) * stride]) -
                       dms_total.max_tmp) *
                __fdividef(1.0F, dms_total.sum_tmp);
        }
    }
    else
    {
        for (int ind = 0; ind < step - 1; ind++)
        {

            output[tid +
                   (remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
                       stride] =
                __expf(static_cast<float>(
                           input[tid +
                                 (remain * step +
                                  (threadIdx.x - remain) * (step - 1) + ind) *
                                     stride]) -
                       dms_total.max_tmp) *
                __fdividef(1.0F, dms_total.sum_tmp);
        }
    }
}

template <typename T, int BLOCK_DIM, int numPerThread>
__global__ void
_blockSoftmaxKernel(T *__restrict input, T *__restrict output,
                    int dimsize,
                    int stride)
{ // if set axis = 1, inputShape=[I,J,K,S]
  // tid = i(JKS) + j(KS) + k(S) + s

    // gridDim.x = othersize = size/dimsize = IKS
    // blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s

    int tid =
        blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) *
                                  dimsize; // now, tid = i(JKS) + k(S) + s;
    int remain = dimsize % BLOCK_DIM;
    int step = (dimsize - remain) / BLOCK_DIM + 1; // step <= numPerThread
    float dataPerThread[numPerThread];

    DataMaxSum dms_partial;
    dms_partial.max_tmp = -__FLT_MAX__;
    dms_partial.sum_tmp = 0.0f;
    DataMaxSum dms_input;
    if (threadIdx.x < remain)
    {
        for (int ind = 0; ind < step; ind++)
        {
            dataPerThread[ind] =
                input[tid + (threadIdx.x * step + ind) * stride];
            dms_input.max_tmp = dataPerThread[ind];
            dms_input.sum_tmp = 1.0f;
            dms_partial =
                reduce_dms_op(dms_partial,
                              dms_input); // reduce the data to one block
        }
    }
    else
    {
        for (int ind = 0; ind < step - 1; ind++)
        {
            dataPerThread[ind] =
                input[tid + (remain * step +
                             (threadIdx.x - remain) * (step - 1) + ind) *
                                stride];
            dms_input.max_tmp = dataPerThread[ind];
            dms_input.sum_tmp = 1.0f;
            dms_partial =
                reduce_dms_op(dms_partial,
                              dms_input); // reduce the data to one block
        }
    }

    typedef cub::BlockReduce<DataMaxSum, BLOCK_DIM> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ DataMaxSum dms_total;
    DataMaxSum dms_block =
        BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op);
    if (threadIdx.x == 0)
    { // only thread 0 writes the block-wide result
        dms_total = dms_block;
    }
    __syncthreads();
    //-----------------
    if (threadIdx.x < remain)
    {
        for (int ind = 0; ind < step; ind++)
        {
            output[tid + (threadIdx.x * step + ind) * stride] =
                __expf(dataPerThread[ind] - dms_total.max_tmp) *
                __fdividef(1.0F, dms_total.sum_tmp);
        }
    }
    else
    {
        for (int ind = 0; ind < step - 1; ind++)
        {
            output[tid +
                   (remain * step + (threadIdx.x - remain) * (step - 1) + ind) *
                       stride] =
                __expf(dataPerThread[ind] - dms_total.max_tmp) *
                __fdividef(1.0F, dms_total.sum_tmp);
        }
    }
}

template <typename T>
struct SumOp
{
    __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a + b;
    }
};

template <typename T>
struct MaxOp
{
    __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return max(a, b);
    }
};
template <template <typename> class ReductionOp, typename T,
          int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val)
{
    for (int mask = thread_group_width / 2; mask > 0; mask /= 2)
    {
        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

template <typename T, int BLOCK_DIM_x, int BLOCK_DIM_y, int numPerThreadx>
__global__ void _warpSoftmaxKernel(T *__restrict input, T *__restrict output,
                                   int othersize, int dimsize, int stride)
{
    int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;

    int tid = otherIdx % stride + (otherIdx - otherIdx % stride) * dimsize;
    float dataPerThreadx[numPerThreadx];
    if (otherIdx < othersize)
    {

        __shared__ float max_total[BLOCK_DIM_y];
        __shared__ float sum_total[BLOCK_DIM_y];
        float max_data = -__FLT_MAX__;

        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++)
        {
            dataPerThreadx[ph] =
                input[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride];
            max_data = max(max_data, dataPerThreadx[ph]);
        }

        max_data = WarpAllReduce<MaxOp, float, BLOCK_DIM_x>(max_data);

        if (threadIdx.x == 0)
            max_total[threadIdx.y] = max_data;

        //--------------------------------------------
        float sum_data = 0.0f;

        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++)
        {
            dataPerThreadx[ph] =
                __expf(dataPerThreadx[ph] - max_total[threadIdx.y]);
            sum_data += dataPerThreadx[ph];
        }

        sum_data = WarpAllReduce<SumOp, float, BLOCK_DIM_x>(sum_data);

        if (threadIdx.x == 0)
            sum_total[threadIdx.y] = sum_data;

        //--------------------------------------------

        for (int ph = 0; threadIdx.x + ph * BLOCK_DIM_x < dimsize; ph++)
        {
            output[tid + (threadIdx.x + ph * BLOCK_DIM_x) * stride] =
                dataPerThreadx[ph] * __fdividef(1.0F, sum_total[threadIdx.y]);
        }
    }
}
//-----------------
template <typename T>
void softmaxLaunch(T const *input, T *output, int size, int dimsize, int stride)
{

    int num_blocks = size / dimsize;

    if (dimsize > 1024 * 128)
    {

        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<T, 1024>
            <<<num_blocks, BLOCK_DIM>>>((T *)input, (T *)output, dimsize, stride);
    }
    else if (dimsize > 1024 * 64)
    {

        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<T, 1024, 128>
            <<<num_blocks, BLOCK_DIM>>>((T *)input, (T *)output, dimsize, stride);
    }
    else if (dimsize > 1024 * 32)
    {

        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<T, 1024, 64>
            <<<num_blocks, BLOCK_DIM>>>((T *)input, (T *)output, dimsize, stride);
    }
    else if (dimsize > 1024 * 16)
    {

        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<T, 1024, 32>
            <<<num_blocks, BLOCK_DIM>>>((T *)input, (T *)output, dimsize, stride);
    }
    else if (dimsize > 1024 * 4)
    {

        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<T, 1024, 16>
            <<<num_blocks, BLOCK_DIM>>>((T *)input, (T *)output, dimsize, stride);
    }
    else if (dimsize > 1024)
    {

        int BLOCK_DIM = 1024;
        _blockSoftmaxKernel<T, 1024, 4>
            <<<num_blocks, BLOCK_DIM>>>((T *)input, (T *)output, dimsize, stride);
    }
    else if (dimsize > 31)
    {
        int BLOCK_DIM_x = 32;
        int BLOCK_DIM_y = 32;
        int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<T, 32, 32, 32>
            <<<grid_dim, block_dim>>>((T *)input, (T *)output, num_blocks, dimsize, stride);
    }
    else if (dimsize > 15)
    {
        int BLOCK_DIM_x = 16;
        int BLOCK_DIM_y = 64;
        int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<T, 16, 64, 2>
            <<<grid_dim, block_dim>>>((T *)input, (T *)output, num_blocks, dimsize, stride);
    }
    else if (dimsize > 7)
    {
        int BLOCK_DIM_x = 8;
        int BLOCK_DIM_y = 128;
        int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<T, 8, 128, 2>
            <<<grid_dim, block_dim>>>((T *)input, (T *)output, num_blocks, dimsize, stride);
    }
    else
    {
        int BLOCK_DIM_x = 4;
        int BLOCK_DIM_y = 256;
        int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y;
        dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);

        _warpSoftmaxKernel<T, 4, 256, 2>
            <<<grid_dim, block_dim>>>((T *)input, (T *)output, num_blocks, dimsize, stride);
    }
    cudaDeviceSynchronize();
}
extern "C" void softmax_nv_f32(float *input, float *output, int size, int dimsize, int stride)
{
    softmaxLaunch<float>(input, output, size, dimsize, stride);
}

softmax.mlu

#include "bang.h"
#include "cnrt.h"
const int NRAM_MAX_SIZE = 1024 * 256;
__nram__ char nram_buffer[NRAM_MAX_SIZE];

template<typename T>
__mlu_global__ void softmaxKernelAxis_e(T *destination, T const *source, int othersize, int dimsize, int dimS) {// axis = -1
  
  const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 8;
  const int wSize = 128 / sizeof(T);
  
  const int maxNum = SRC_MAX_SIZE/sizeof(T);
  __nram__ T srcMax[2];
  if(dimsize >= maxNum){
    T *src = (T *)nram_buffer;
    T *destSum = src + 3 * maxNum;
    T *destSumFinal = destSum + maxNum;
    T destOldMax;
    T destNewMax;

    int remain = dimsize % maxNum;
    int repeat = (dimsize - remain)/maxNum;

    int otherRemain = othersize % taskDim;
    int stepEasy = (othersize - otherRemain) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < otherRemain ? stepHard : stepEasy);
    int startHard = taskId * stepHard;
    int startEasy = otherRemain * stepHard + (taskId - otherRemain) * stepEasy;
    int indStart = (taskId < otherRemain ? startHard : startEasy);
    source = source + indStart * dimsize;
    destination = destination + indStart * dimsize;
    
    for(int s = 0; s < step; s++){
      
      destOldMax = -INFINITY;
      destNewMax = -INFINITY;
      __bang_write_zero(destSum, maxNum);
      for(int i = 0; i < repeat + 1; i++){
        if(i < repeat){
          __memcpy_async(src + i % 2 * maxNum, source + s * dimsize + i * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
        }
        if(i > 0){
          __bang_argmax(srcMax, src + (i - 1) % 2 * maxNum, maxNum);
          if(destNewMax < srcMax[0]){
            destNewMax = srcMax[0];
          }
          __bang_sub_scalar(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, destNewMax, maxNum);
          __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);
          if(i > 1){
            __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
          }
          __bang_add(destSum, destSum, src + (i - 1) % 2 * maxNum, maxNum);
          destOldMax = destNewMax;
        }
        __sync_all_ipu();
      }
      //------------
      if(remain){
        __bang_write_value(src, maxNum, -INFINITY);
        __memcpy(src, source + s * dimsize + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM);
        
        __bang_argmax(srcMax, src, maxNum);
        if(destNewMax < srcMax[0]){
          destNewMax = srcMax[0];
        }
        
        __bang_sub_scalar(src, src, destNewMax, maxNum);
        __bang_active_exp_less_0(src, src, maxNum);
        if(repeat > 0){
          __bang_mul_scalar(destSum, destSum, exp(destOldMax - destNewMax), maxNum);
        }
        __bang_add(destSum, destSum, src, maxNum);
        destOldMax = destNewMax;
      }
      //--------------
      //--------------------------------
      
      int segNum = maxNum / wSize;
      for(int strip = segNum/2; strip > 0; strip = strip / 2){
        for(int i = 0; i < strip ; i++){
          __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize);
        } 
      }
      __bang_reduce_sum(destSumFinal, destSum, wSize);
      
      //-----------
      T globalSumInv = 1.0/destSumFinal[0];
      for(int i = 0; i < repeat + 2; i++){
        if(i < repeat){
          __memcpy_async(src + i % 3 * maxNum, source + s * dimsize + i * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
        }
        if(i > 0 && i < repeat + 1){
          __bang_sub_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, destNewMax, maxNum); 
          __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);
          __bang_mul_scalar(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, globalSumInv, maxNum);
        }
        if(i > 1){
          __memcpy_async(destination + s * dimsize + (i - 2) * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM);
        }
        __sync_all_ipu();
        
      }
      if(remain){
        __bang_write_value(src, maxNum, destNewMax);
        __memcpy(src, source + s * dimsize + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM);
        __bang_sub_scalar(src, src, destNewMax, maxNum);
        __bang_active_exp_less_0(src, src, maxNum);
        __bang_mul_scalar(src, src, globalSumInv, maxNum);
        __memcpy(destination + s * dimsize + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM);
      }
    }
    
  }
  else{
    int multiple = maxNum / dimsize;// one src buffer can hold `multiple` rows (otherIdx values)
    int size = taskDim * multiple;// all cores together process `size` rows per round
    int remain = othersize % size;// remain < taskDim * multiple
    int repeat = (othersize - remain) / size;

    int remainT = remain % taskDim;
    int stepEasy = (remain - remainT) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remainT ? stepHard : stepEasy);
    int startHard = taskId * stepHard * dimsize;// the first remainT taskIds each take stepHard rows of dimsize elements
    int startEasy = remainT * stepHard * dimsize + (taskId - remainT) * stepEasy * dimsize;
    int indStart = (taskId < remainT ? startHard : startEasy);
    
    //-----------------------------------------allocate memory
    T* src = (T *)nram_buffer;//src[maxNum]
    T* tmp = src + 3 * maxNum;//tmp[dimS]
    T* destSum = tmp + dimS;// destSum[dimS]; dimS >= max(dimsize, wSize), dimS = 2^K with 2^(K-1) < dimsize
    T* destSumFinal = destSum + wSize;
    //-----------------------------------------
    //printf("taskId:%d, repeat:%d, step:%d, repeatDim:%d, indstart:%d, %d\n", taskId, repeat, step, repeatDim, indStart, indStart * dimsize);
    int tid;
    __bang_write_value(tmp, dimS, -INFINITY);
    __bang_write_zero(destSum, dimS);
    if(repeat >= 2){
        int s = 0;
        tid = s * size * dimsize + taskId * multiple * dimsize;
        __memcpy(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM);
        s = 1;
        tid = s * size * dimsize + taskId * multiple * dimsize;
        __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM);

        // compute ------------------------
        for(int j = 0; j < multiple; j++){
            
            __memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM);
            __bang_argmax(srcMax, tmp, dimS);
            __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
            __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM);
        }
        __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum);
        for(int j = 0; j < multiple; j++){
            
            __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM);
            __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM);
            int segNum = dimS / wSize;//Starting numerical summation
            for(int strip = segNum/2; strip > 0; strip = strip / 2){
                for(int i = 0; i < strip ; i++){
                    __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize);
                } 
            }
            __bang_reduce_sum(destSumFinal, destSum, wSize);
            T globalSumInv = 1.0/destSumFinal[0];
            __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);

            __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM);
        }
        // compute ------------------------

        for(int s = 2; s < repeat; s++){
            tid = (s - 2) * size * dimsize + taskId * multiple * dimsize;
            __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM);

            tid = s * size * dimsize + taskId * multiple * dimsize;
            __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM);
            
            // compute ------------------------
            
            __bang_argmax(srcMax, src + (s - 1) %3 * maxNum, maxNum);// special case: take the max over the whole buffer at once
            __bang_sub_scalar(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, srcMax[0], maxNum);
            __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum);
            
            for(int j = 0; j < multiple; j++){
                __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM);
                __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM);
                int segNum = dimS / wSize;//Starting numerical summation
                for(int strip = segNum/2; strip > 0; strip = strip / 2){
                    for(int i = 0; i < strip ; i++){
                        __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize);
                    } 
                }
                __bang_reduce_sum(destSumFinal, destSum, wSize);
                T globalSumInv = 1.0/destSumFinal[0];
                __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);

                __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM);
            }
            // compute ------------------------
        }
        s = repeat;
        tid = (s - 2) * size * dimsize + taskId * multiple * dimsize;
        __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM);
        // compute ------------------------
        for(int j = 0; j < multiple; j++){
            
            __memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM);
            __bang_argmax(srcMax, tmp, dimS);
            __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
            __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM);
        }
        __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum);
        for(int j = 0; j < multiple; j++){
            
            __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM);
            __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM);
            int segNum = dimS / wSize;//Starting numerical summation
            for(int strip = segNum/2; strip > 0; strip = strip / 2){
                for(int i = 0; i < strip ; i++){
                    __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize);
                } 
            }
            __bang_reduce_sum(destSumFinal, destSum, wSize);
            T globalSumInv = 1.0/destSumFinal[0];
            __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);

            __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM);
        }
        // compute ------------------------
        s = repeat + 1;
        tid = (s - 2) * size * dimsize + taskId * multiple * dimsize;
        __memcpy(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM);
    }
    else{
        for(int s = 0; s < repeat + 2; s++){
            if(s < repeat){
                tid = s * size * dimsize + taskId * multiple * dimsize;
                __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * dimsize * sizeof(T), GDRAM2NRAM);
            }
            if(s > 0 && s < repeat + 1){
                // compute ------------------------
            
                for(int j = 0; j < multiple; j++){
                    __memcpy(tmp, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM);
                    __bang_argmax(srcMax, tmp, dimS);
                    __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
                    __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM);
                }
                __bang_active_exp_less_0(src + (s - 1) %3 * maxNum, src + (s - 1) %3 * maxNum, maxNum);
                
                for(int j = 0; j < multiple; j++){
                    __memcpy(destSum, src + (s - 1) %3 * maxNum + j * dimsize, dimsize * sizeof(T), NRAM2NRAM);
                    __memcpy(tmp, destSum, dimsize * sizeof(T), NRAM2NRAM);
                    int segNum = dimS / wSize;//Starting numerical summation
                    for(int strip = segNum/2; strip > 0; strip = strip / 2){
                        for(int i = 0; i < strip ; i++){
                            __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize);
                        } 
                    }
                    __bang_reduce_sum(destSumFinal, destSum, wSize);
                    T globalSumInv = 1.0/destSumFinal[0];
                    __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);

                    __memcpy(src + (s - 1) %3 * maxNum + j * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM);
                }
                // compute ------------------------
            }
            if(s > 1){
                tid = (s - 2) * size * dimsize + taskId * multiple * dimsize;
                __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * dimsize * sizeof(T), NRAM2GDRAM);
            }
            __sync_all_ipu();// if maxNum is small, memory access time exceeds compute time and cannot be hidden
        }
    }
    if(step){
      tid = repeat * size * dimsize + indStart;
      __memcpy(src, source + tid, step * dimsize * sizeof(T), GDRAM2NRAM);
      for(int s = 0; s < step; s++){//Step targets parts of othersize that cannot be divided by multiple * dimsize
        __bang_write_zero(destSum, dimS);
        
        __bang_write_value(tmp, dimS, -INFINITY);
        __memcpy(tmp, src + s * dimsize, dimsize * sizeof(T), NRAM2NRAM);
        
        __bang_argmax(srcMax, tmp, dimS);
        
        __bang_sub_scalar(tmp, tmp, srcMax[0], dimS);
        
        __bang_active_exp_less_0(tmp, tmp, dimS);
        __memcpy(destSum, tmp, dimsize * sizeof(T), NRAM2NRAM);
        
        int segNum = dimS / wSize;
        for(int strip = segNum/2; strip > 0; strip = strip / 2){
          for(int i = 0; i < strip ; i++){
            __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize);
          }
        }
        __bang_reduce_sum(destSumFinal, destSum, wSize);
        
        T globalSumInv = 1.0/destSumFinal[0];
        __bang_mul_scalar(tmp, tmp, globalSumInv, dimS);
        __memcpy(src + s * dimsize, tmp, dimsize * sizeof(T), NRAM2NRAM); 
      } 
      __memcpy(destination + tid, src, step * dimsize * sizeof(T), NRAM2GDRAM);
    }
    
  }
}
template<typename T>
__mlu_global__ void softmaxKernelAxis_s(T *destination, T const *source, T *tmpGdram, int othersize, int dimsize, int stride) {// axis = 0
  
  const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 8;
  

  const int maxNum = SRC_MAX_SIZE/sizeof(T);
  if(othersize > taskDim * maxNum){
    //-----------------------------------------allocate memory
    T* src = (T *)nram_buffer;// src[3 * maxNum]
    T* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
    T* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
    T* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
    //-----------------------------------------
    int remain = othersize % taskDim;
    int stepEasy = (othersize - remain)/taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remain ? stepHard : stepEasy);//The first part of taskId handles an additional element
    int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
    int remainNram = step%maxNum;
    int repeat = (step - remainNram)/maxNum;
    
    for(int j = 0; j < repeat; j++){
      __bang_write_value(tmpNewMax, maxNum, -INFINITY);
      __bang_write_zero(tmpSum, maxNum);
      for(int i = 0; i < dimsize + 1; i++){
        if(i < dimsize){
          __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
        }
        if(i > 0){
          __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);//Continuously updating the maximum value
          __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
          if(i > 1){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM
        }
        __sync_all_ipu();
      } 
      __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
      
      for(int i = 0; i < dimsize + 2; i++){
        if(i < dimsize){
          __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
        }
        if(i > 0 && i < dimsize + 1){
          __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
          __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
        }
        if(i > 1){
          __memcpy_async(destination + (i - 2) * stride + indStart + j * maxNum, src + (i - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM);
        }
        __sync_all_ipu();
      } 
    }
    if(remainNram){
      __bang_write_value(tmpNewMax, maxNum, -INFINITY);
      __bang_write_zero(tmpSum, maxNum);
      __bang_write_zero(src, 3 * maxNum);
      
      for(int i = 0; i < dimsize + 1; i++){
        if(i < dimsize){
          __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(T), GDRAM2NRAM);
        }
        if(i > 0){
          __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
          __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
          if(i > 1){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);      //sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM
        }
        __sync_all_ipu();
      } 
      
      __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
      //Start exponential transformation and write back to GDRAM
      
      for(int i = 0; i < dimsize + 2; i++){
        if(i < dimsize){
          __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart + repeat * maxNum, remainNram * sizeof(T), GDRAM2NRAM);
        }
        if(i > 0 && i < dimsize + 1){
          __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
          __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
        }
        if(i > 1){
          __memcpy_async(destination + (i - 2) * stride + indStart + repeat * maxNum, src + (i - 2) % 3 * maxNum, remainNram * sizeof(T), NRAM2GDRAM);
        }
        __sync_all_ipu();
      }     
    }
  }
  else if (othersize > maxNum && othersize <= taskDim * maxNum){
    T* src = (T *)nram_buffer;// src[3 * maxNum]
    T* tmpSum = src + 3 * maxNum;//tmpSum[maxNum]
    T* tmpNewMax = src + 4 * maxNum;//tmpNewMax[maxNum]
    T* tmpOldMax = src + 5 * maxNum;//tmpOldMax[maxNum]
    //-----------------------------------------
    int remain = othersize % taskDim;
    int stepEasy = (othersize - remain)/taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remain ? stepHard : stepEasy);//The first part of taskId handles an additional element
    int indStart = (taskId < remain ? taskId * stepHard : remain * stepHard + (taskId - remain) * stepEasy);
    
    __bang_write_value(tmpNewMax, maxNum, -INFINITY);
    __bang_write_zero(tmpSum, maxNum);
    __bang_write_zero(src, 3 * maxNum);
    
    for(int i = 0; i < dimsize + 1; i++){
      if(i < dimsize){
        __memcpy_async(src + i % 2 * maxNum, source + i * stride + indStart, step * sizeof(T), GDRAM2NRAM);
      }
      if(i > 0){
        __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum, maxNum);
        __bang_sub(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, maxNum);//exp(x - M)
        if(i > 1){
          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
          __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);      //sum = sum * exp(oldM - newM)
        }
        __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum, maxNum);//sum += exp(x - M)
        __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM
      }
      __sync_all_ipu();
    } 
    
    __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
    //Start exponential transformation and write back to GDRAM
    
    for(int i = 0; i < dimsize + 2; i++){
      if(i < dimsize){
        __memcpy_async(src + i % 3 * maxNum, source + i * stride + indStart, step * sizeof(T), GDRAM2NRAM);
      }
      if(i > 0 && i < dimsize + 1){
        __bang_sub(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpNewMax, maxNum);//x - M
        __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, maxNum);//exp(x - M)
        __bang_mul(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, tmpSum, maxNum);
      }
      if(i > 1){
        __memcpy_async(destination + (i - 2) * stride + indStart, src + (i - 2) % 3 * maxNum, step * sizeof(T), NRAM2GDRAM);
      }
      __sync_all_ipu();
    }               
  }
  else{
    
    int multiple = maxNum / othersize;
    int size = taskDim * multiple;
    int remain = dimsize % size;
    int repeat = (dimsize - remain) / size;

    int remainT = remain % taskDim;
    int stepEasy = (remain - remainT) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remainT ? stepHard : stepEasy);
    int indStart = (taskId < remainT ? taskId * stepHard : remainT * stepHard + (taskId - remainT) * stepEasy);
    
    T* src = (T *)nram_buffer;// src[3 * maxNum]
    T* tmpSum = src + 3 * maxNum;//tmpSum[othersize]
    T* tmpNewMax = tmpSum + othersize;//tmpNewMax[othersize]
    T* tmpOldMax = tmpNewMax + othersize;//tmpOldMax[othersize]
    T* tmpGlobal = tmpOldMax + othersize;
    __bang_write_value(tmpNewMax, othersize, -INFINITY);
    
    __bang_write_zero(tmpSum, othersize);
    __bang_write_zero(src, 3 * maxNum);
    
    for(int i = 0; i < repeat + 1; i++){
      if (i < repeat){
        __memcpy_async(src + (i % 2) * maxNum, source + (i * size + taskId * multiple) * stride, multiple * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize
      }
      if(i > 0){
        for(int m = 0; m < multiple; m++){
          __bang_maxequal(tmpNewMax, tmpNewMax, src + (i - 1) % 2 * maxNum + m * othersize, othersize);
        }
        for(int m = 0; m < multiple; m++){
          __bang_sub(src + (i - 1) % 2 * maxNum + m * othersize, src + (i - 1) % 2 * maxNum + m * othersize, tmpNewMax, othersize);//x - M
        }
        __bang_active_exp_less_0(src + (i - 1) % 2 * maxNum, src + (i - 1) % 2 * maxNum, multiple * othersize);//exp(x - M)
        if(i > 1){
          __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);//oldM = oldM - newM
          __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);//exp(oldM - newM)
          __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);      //sum = sum * exp(oldM - newM)
        }
        for(int m = 0; m < multiple; m++){
          __bang_add(tmpSum, tmpSum, src + (i - 1) % 2 * maxNum + m * othersize, othersize);
        }
        __memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(T), NRAM2NRAM);
      }
      __sync_all_ipu();
    }
    
    if(step) {
      __memcpy(src, source + repeat * size * stride + indStart * stride, step * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize
      
      for(int m = 0; m < step; m++){
        __bang_maxequal(tmpNewMax, tmpNewMax, src + m * othersize, othersize);
      }
      for(int m = 0; m < step; m++){
        __bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize);//x - M
      }
      __bang_active_exp_less_0(src, src, step * othersize);//exp(x - M)
      if(repeat > 0){
        __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);//oldM = oldM - newM
        __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);//exp(oldM - newM)
        __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);      //sum = sum * exp(oldM - newM)
      }
      for(int m = 0; m < step; m++){
        __bang_add(tmpSum, tmpSum, src + m * othersize, othersize);
      }
      __memcpy(tmpOldMax, tmpNewMax, othersize * sizeof(T), NRAM2NRAM);
    }
    //----------------
    if(repeat > 0 || dimsize >= taskDim){
      __memcpy(tmpGdram + taskId * othersize, tmpNewMax, othersize * sizeof(T), NRAM2GDRAM);
      __sync_all();
      __bang_write_value(tmpNewMax, othersize, -INFINITY);
      for(int id = 0; id < taskDim; id++){
        __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM);
        __bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize);
      }
      __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);
      __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);
      __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);
      __memcpy(tmpGdram + taskId * othersize, tmpSum, othersize * sizeof(T), NRAM2GDRAM);
      __sync_all();
      __bang_write_zero(tmpSum, othersize);
      for(int id = 0; id < taskDim; id++){
        __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM);
        __bang_add(tmpSum, tmpSum, tmpGlobal, othersize);
      }
      __bang_active_recip_greater_1(tmpSum, tmpSum, othersize);
    }
    else{
      __memcpy(tmpGdram + taskId * othersize, tmpNewMax, othersize * sizeof(T), NRAM2GDRAM);
      __sync_all();
      __bang_write_value(tmpNewMax, othersize, -INFINITY);
      for(int id = 0; id < dimsize; id++){
        __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM);
        __bang_maxequal(tmpNewMax, tmpNewMax, tmpGlobal, othersize);
      }
      __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, othersize);
      __bang_active_exp_less_0(tmpOldMax, tmpOldMax, othersize);
      __bang_mul(tmpSum, tmpSum, tmpOldMax, othersize);
      __memcpy(tmpGdram + taskId * othersize, tmpSum, othersize * sizeof(T), NRAM2GDRAM);
      __sync_all();
      __bang_write_zero(tmpSum, othersize);
      for(int id = 0; id < dimsize; id++){
        __memcpy(tmpGlobal, tmpGdram + id * othersize, othersize * sizeof(T), GDRAM2NRAM);
        __bang_add(tmpSum, tmpSum, tmpGlobal, othersize);
      }
      __bang_active_recip_greater_1(tmpSum, tmpSum, othersize);
    }
    
    //-------------------
    for(int i = 0; i < repeat + 2; i++){
      if(i < repeat){
        __memcpy_async(src + (i % 3) * maxNum, source + (i * size + taskId * multiple) * stride, multiple * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize
      }
      if(i > 0){
        for(int m = 0; m < multiple; m++){
          __bang_sub(src + (i - 1) % 3 * maxNum + m * othersize, src + (i - 1) % 3 * maxNum + m * othersize, tmpNewMax, othersize);
        }
        __bang_active_exp_less_0(src + (i - 1) % 3 * maxNum, src + (i - 1) % 3 * maxNum, multiple * othersize);
        for(int m = 0; m < multiple; m++){
          __bang_mul(src + (i - 1) % 3 * maxNum + m * othersize, src + (i - 1) % 3 * maxNum + m * othersize, tmpSum, othersize);
        }
      }
      if (i > 1){
        __memcpy_async(destination + ((i - 2) * size + taskId * multiple) * stride, src + (i - 2) % 3 * maxNum, multiple * othersize * sizeof(T), NRAM2GDRAM);
      }
      __sync_all_ipu();
    }
    if(step) {
      __memcpy(src, source + repeat * size * stride + indStart * stride, step * othersize * sizeof(T), GDRAM2NRAM);//stride=othersize
      for(int m = 0; m < step; m++){
        __bang_sub(src + m * othersize, src + m * othersize, tmpNewMax, othersize);
      }
      __bang_active_exp_less_0(src, src, step * othersize);
      for(int m = 0; m < step; m++){
        __bang_mul(src + m * othersize, src + m * othersize, tmpSum, othersize);
      }
      __memcpy(destination + repeat * size * stride + indStart * stride, src, step * othersize * sizeof(T), NRAM2GDRAM);
    }
  }
}
template<typename T>
__mlu_global__ void softmaxKernelAxis_m(T *destination, T const *source, int frontsize, int dimsize, int stride, int strideS) {
  // 0<axis<dim -1 
  
  const int SRC_MAX_SIZE = NRAM_MAX_SIZE / 8;
  
  const int maxNum = SRC_MAX_SIZE/sizeof(T);
  if(stride >= maxNum){
    //-----------------------------------------allocate memory
    T *src = (T *)nram_buffer;
    T *tmpSum = src + 3 * maxNum;
    T *tmpNewMax = tmpSum + maxNum;
    T *tmpOldMax = tmpNewMax + maxNum;
    //-----------------------------------------
    int remain = stride % maxNum;
    int repeat = (stride - remain) / maxNum;
    
    for(int ind = taskId; ind < frontsize; ind += taskDim){
      int frontIdx = ind * dimsize * stride;
      for(int j = 0; j < repeat; j++){
        __bang_write_value(tmpNewMax, maxNum, -INFINITY);
        __bang_write_zero(tmpSum, maxNum);
        //__bang_write_zero(src, maxNum);
        for(int i = 0; i < dimsize; i++){
          __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
          __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);//Continuously updating the maximum value
          __bang_sub(src, src, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
          if(i > 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);//sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM
        }
        __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
        //Start exponential transformation and write back to GDRAM
        __bang_mul(src, src, tmpSum, maxNum);//reuse the data still held in src from the last iteration (i = dimsize - 1)
        __memcpy(destination + (dimsize - 1) * stride + frontIdx + j * maxNum, src, maxNum * sizeof(T), NRAM2GDRAM);
        for(int i = 0; i < dimsize - 1; i++){
          __memcpy(src, source + frontIdx + i * stride + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM);
          __bang_sub(src, src, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
          __bang_mul(src, src, tmpSum, maxNum);
          __memcpy(destination + frontIdx + i * stride + j * maxNum, src, maxNum * sizeof(T), NRAM2GDRAM);
        } 
      }
      if(remain){
        
        __bang_write_value(tmpNewMax, maxNum, -INFINITY);
        __bang_write_zero(tmpSum, maxNum);
        __bang_write_value(src, maxNum, -INFINITY);
        for(int i = 0; i < dimsize; i++){
          __memcpy(src, source + frontIdx + i * stride + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM);
          __bang_maxequal(tmpNewMax, tmpNewMax, src, maxNum);
          __bang_sub(src, src, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
          if(i > 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, maxNum);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, maxNum);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, maxNum);      //sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, src, maxNum);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, maxNum * sizeof(T), NRAM2NRAM);//oldM = newM
        }
        //-------------------
        __bang_active_reciphp(tmpSum, tmpSum, maxNum);//compute 1/sum
        //Start exponential transformation and write back to GDRAM
        __bang_mul(src, src, tmpSum, maxNum);//reuse the data still held in src from the last iteration (i = dimsize - 1)
        __memcpy(destination + (dimsize - 1) * stride + frontIdx + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM);
        for(int i = 0; i < dimsize - 1; i++){
          __memcpy(src, source + i * stride + frontIdx + repeat * maxNum, remain * sizeof(T), GDRAM2NRAM);
          __bang_sub(src, src, tmpNewMax, maxNum);//x - M
          __bang_active_exp_less_0(src, src, maxNum);//exp(x - M)
          __bang_mul(src, src, tmpSum, maxNum);
          __memcpy(destination + i * stride + frontIdx + repeat * maxNum, src, remain * sizeof(T), NRAM2GDRAM);
        } 
        //---------------------
      }
    }
  }
  else if(stride < maxNum && dimsize * stride >= maxNum){
   
    //-----------------------------------------allocate memory
    T* src = (T *)nram_buffer;
    T* tmp = src + 3 * maxNum;
    T* tmpOldMax = tmp + strideS;
    T* tmpNewMax = tmpOldMax + strideS;
    T* tmpSum = tmpNewMax + strideS;
    //-----------------------------------------
    int multiple = maxNum / stride;
    int size = multiple * stride;//The maximum amount of data that can be stored in an SRC
    int remain = dimsize % multiple;//If it cannot be divisible, this part of the data needs special processing
    int repeat = (dimsize - remain) / multiple;//The total number of loops required to load the entire dimsize

    int taskRemain = frontsize % taskDim;
    int stepEasy = (frontsize - taskRemain) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < taskRemain ? stepHard : stepEasy);//The number of frontsize processed per taskId
    int indStart = (taskId < taskRemain ? taskId * stepHard : taskRemain * stepHard + (taskId - taskRemain) * stepEasy);
    source = source + indStart * dimsize * stride;
    destination = destination + indStart * dimsize * stride;
    //printf("maxNum:%d, dimsize * stride:%d, multiple:%d, size:%d, repeat:%d,remain:%d\n",maxNum, dimsize * stride, multiple, size, repeat,remain);
    for(int ind = 0; ind < step; ind++){
      int frontIdx = ind * dimsize * stride;
      
      __bang_write_value(tmpNewMax, strideS, -INFINITY);//Must be initialized to negative infinity
      __bang_write_value(tmp, strideS, -INFINITY);//Must be initialized to negative infinity
      __bang_write_zero(tmpSum, strideS);//Must be initialized to zero
      
      for(int j = 0; j < repeat + 1; j++){
        if(j < repeat){
          __memcpy_async(src + j % 2 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(T), GDRAM2NRAM);
        }
        if(j > 0){
          for(int m = 0; m < multiple; m++){
            __memcpy(tmp, src + (j - 1) % 2 * maxNum + m * stride, stride * sizeof(T), NRAM2NRAM);
            
            __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);//only the first `stride` elements are meaningful; the [stride, strideS) tail is never written back to GDRAM, so it does not affect the result
            
            __bang_sub(tmp, tmp, tmpNewMax, strideS);//again, only the first `stride` elements of tmp matter
            __bang_active_exp_less_0(tmp, tmp, strideS);
            if(j != 1 || m != 0){
              __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
              __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
              __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
            }
            __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
            
            __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM
          }
        }
        __sync_all_ipu();
      }
      
      if(remain){
        __memcpy(src, source + frontIdx + repeat * multiple * stride, remain * stride * sizeof(T), GDRAM2NRAM);
        for(int m = 0; m < remain; m++){
          __memcpy(tmp, src + m * stride, stride * sizeof(T), NRAM2NRAM);
          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//only the first `stride` elements of tmp are meaningful
          __bang_active_exp_less_0(tmp, tmp, strideS);
          if(repeat != 0 || m != 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);//sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM
        }
      }
      
      //At this point tmpNewMax holds, for this frontIdx, the element-wise maximum over dimsize, and tmpSum holds the corresponding sum of exponentials
      
      __bang_active_reciphp(tmpSum, tmpSum, strideS);
      
      if(remain){
        for(int m = 0; m < remain; m++){
          __memcpy(tmp, src + m * stride, stride * sizeof(T), NRAM2NRAM);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);
          __bang_active_exp_less_0(tmp, tmp, strideS);
          __bang_mul(tmp, tmp, tmpSum, strideS);
          __memcpy(destination + frontIdx + repeat * multiple * stride + m * stride, tmp, stride * sizeof(T), NRAM2GDRAM);
        }
        
      }
      for(int j = 0 ; j < repeat + 2; j++){
        if(j < repeat){
          __memcpy_async(src + j % 3 * maxNum, source + frontIdx + j * multiple * stride, size * sizeof(T), GDRAM2NRAM);
        }
        if(j > 0 && j < repeat + 1){
          for(int m = 0; m < multiple; m++){
            __memcpy(tmp, src + (j - 1) % 3 * maxNum + m * stride, stride * sizeof(T), NRAM2NRAM);
            
            __bang_sub(tmp, tmp, tmpNewMax, strideS);
            __bang_active_exp_less_0(tmp, tmp, strideS);
            __bang_mul(tmp, tmp, tmpSum, strideS);
            __memcpy(src + (j - 1) % 3 * maxNum + m * stride, tmp, stride * sizeof(T), NRAM2NRAM);
          }
        }
        if(j > 1){
          __memcpy_async(destination + frontIdx + (j - 2) * multiple * stride, src + (j - 2) % 3 * maxNum, size * sizeof(T), NRAM2GDRAM);
        }
        __sync_all_ipu();
      }
    }
  }
  else if(dimsize * stride < maxNum){
    //-----------------------------------------allocate memory
    T* src = (T *)nram_buffer;
    T* tmp = src + 3 * maxNum;
    T* tmpOldMax = tmp + strideS;
    T* tmpNewMax = tmpOldMax + strideS;
    T* tmpSum = tmpNewMax + strideS;
    //-----------------------------------------
    int behindsize = dimsize * stride;
    int multiple = maxNum / behindsize;//Represents the amount that a maxNum can share in frontsize
    
    int remainF = frontsize % (taskDim * multiple);
    int remainT = remainF % taskDim;
    int stepEasy = (remainF - remainT) / taskDim;
    int stepHard = stepEasy + 1;
    int step = (taskId < remainT ? stepHard : stepEasy);
    int taskRepeat = (frontsize - remainF) / (taskDim * multiple);
    //At this point, corresponding to frontsize, the amount of data processed by each taskId is taskRepeat * multiple+step
    int startHard = taskId * (taskRepeat * multiple + stepHard);
    int startEasy = remainT * (taskRepeat * multiple + stepHard) + (taskId - remainT) * (taskRepeat * multiple + stepEasy);
    int indStart = (taskId < remainT ? startHard: startEasy);
    source = source + indStart * behindsize;//indStart * behindsize Indicates the offset corresponding to different taskIds
    destination = destination + indStart * behindsize;
    int tid;
    for(int s = 0; s < taskRepeat + 2; s++){
      if(s < taskRepeat){
        tid = s * multiple * behindsize;
        __memcpy_async(src + s % 3 * maxNum, source + tid, multiple * behindsize * sizeof(T), GDRAM2NRAM);
      }
      if(s > 0 && s < taskRepeat + 1){
        for(int m = 0; m < multiple; m++){
          __bang_write_zero(tmpSum, strideS);
          __bang_write_value(tmp, strideS, -INFINITY);
          __bang_write_value(tmpNewMax, strideS, -INFINITY);
          for(int i = 0; i < dimsize; i++){
            __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM);
            __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
            __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
            __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
            if(i > 0){
              __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
              __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
              __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);      //sum = sum * exp(oldM - newM)
            }
            __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
            __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM
          }
          __bang_active_reciphp(tmpSum, tmpSum, strideS);
          __bang_mul(tmp, tmp, tmpSum, strideS);//reuse the data still held in tmp from the last iteration (i = dimsize - 1)
          
          __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(T), NRAM2NRAM);
          for(int i = 0; i < dimsize - 1; i++){
            __memcpy(tmp, src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM);
            __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
            __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
            __bang_mul(tmp, tmp, tmpSum, strideS);
            
            __memcpy(src + (s - 1) % 3 * maxNum + m * behindsize + i * stride, tmp, stride * sizeof(T), NRAM2NRAM);
          }
        }
      }
      if(s > 1){
        tid = (s - 2) * multiple * behindsize;
        __memcpy_async(destination + tid, src + (s - 2) % 3 * maxNum, multiple * behindsize * sizeof(T), NRAM2GDRAM);
      }
      __sync_all_ipu();
    }
    //__bang_printf("taskId:%d, multiple:%d, taskRepeat:%d, step:%d, indStart:%d\n",taskId, multiple, taskRepeat, step, indStart * behindsize);
    if(step){
      tid = taskRepeat * multiple * behindsize; 
      __memcpy(src, source + tid, step * behindsize * sizeof(T), GDRAM2NRAM);
      for(int m = 0; m < step; m++){
        __bang_write_zero(tmpSum, strideS);
        __bang_write_value(tmp, strideS, -INFINITY);
        __bang_write_value(tmpNewMax, strideS, -INFINITY);
        for(int i = 0; i < dimsize; i++){
          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM);
          __bang_maxequal(tmpNewMax, tmpNewMax, tmp, strideS);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
          if(i > 0){
            __bang_sub(tmpOldMax, tmpOldMax, tmpNewMax, strideS);//oldM = oldM - newM
            __bang_active_exp_less_0(tmpOldMax, tmpOldMax, strideS);//exp(oldM - newM)
            __bang_mul(tmpSum, tmpSum, tmpOldMax, strideS);      //sum = sum * exp(oldM - newM)
          }
          __bang_add(tmpSum, tmpSum, tmp, strideS);//sum += exp(x - M)
          __memcpy(tmpOldMax, tmpNewMax, stride * sizeof(T), NRAM2NRAM);//oldM = newM
        }
        //__bang_printf("max:%.2f,%.2f, sum:%.2f,sum:%.2f\n", tmpNewMax[0], tmpNewMax[1], tmpSum[0], tmpSum[0]);
        __bang_active_reciphp(tmpSum, tmpSum, strideS);
        __bang_mul(tmp, tmp, tmpSum, strideS);//reuse the data still held in tmp from the last iteration (i = dimsize - 1)
        //__memcpy(destination + tid + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(T), NRAM2GDRAM);
        __memcpy(src + m * behindsize + (dimsize - 1) * stride, tmp, stride * sizeof(T), NRAM2NRAM);
        for(int i = 0; i < dimsize - 1; i++){
          __memcpy(tmp, src + m * behindsize + i * stride, stride * sizeof(T), NRAM2NRAM);
          __bang_sub(tmp, tmp, tmpNewMax, strideS);//x - M
          __bang_active_exp_less_0(tmp, tmp, strideS);//exp(x - M)
          __bang_mul(tmp, tmp, tmpSum, strideS);
          //__memcpy(destination + tid + m * behindsize + i * stride, tmp, stride * sizeof(T), NRAM2GDRAM);
          __memcpy(src + m * behindsize + i * stride, tmp, stride * sizeof(T), NRAM2NRAM);
        }
      }
      __memcpy(destination + tid, src, step * behindsize * sizeof(T), NRAM2GDRAM);
    }
  }
    
}

template<typename T>
void softmaxUnion1(cnrtQueue_t queue, T const *input, T *output, int othersize, int dimsize, int frontsize, int stride, int axis, int ndim) {
    
    const int wSize = 128 / sizeof(T);

    cnrtDim3_t k_dim;
    cnrtFunctionType_t k_type;

    k_dim.x = 16;
    k_dim.y = 1;
    k_dim.z = 1;
    k_type = CNRT_FUNC_TYPE_UNION1;
    
    int taskNum = k_dim.x * k_dim.y * k_dim.z;

    if(axis == ndim - 1){
        int dimS;
        float mi = log2(dimsize);
        if (floor(mi) == mi)
        {
            dimS = dimsize;
        }
        else
        {
            dimS = static_cast<int>(pow(2, floor(mi) + 1));
        }
        if (dimS < wSize)
        {
            dimS = wSize;
        }
        softmaxKernelAxis_e<T><<<k_dim, k_type, queue>>>(output, input, othersize, dimsize, dimS);
    }
    else if(axis == 0){
        T *tmpGdram;
        CNRT_CHECK(cnrtMalloc((void **)&tmpGdram, taskNum * othersize * sizeof(T)));
        softmaxKernelAxis_s<T><<<k_dim, k_type, queue>>>(output, input, tmpGdram, othersize, dimsize, stride);
        cnrtFree(tmpGdram);
    }
    else{
        float mi = log2(stride);
        int strideS;
        if(floor(mi) == mi){
            strideS = stride;
        }
        else{
            strideS = static_cast<int>(pow(2,floor(mi) + 1));
        }
        softmaxKernelAxis_m<T><<<k_dim, k_type, queue>>>(output, input, frontsize, dimsize, stride, strideS);
    }
        
    cnrtQueueSync(queue);
}



extern "C" void softmax_bang_f32(float const *input, float *output, int othersize, int dimsize, int frontsize, int stride, int axis, int ndim) {
    cnrtQueue_t queue;
    CNRT_CHECK(cnrtSetDevice(0));
    CNRT_CHECK(cnrtQueueCreate(&queue));
    softmaxUnion1<float>(queue, input, output, othersize, dimsize, frontsize, stride, axis, ndim);
    CNRT_CHECK(cnrtQueueDestroy(queue));
}
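
The kernels above all rely on the same online-softmax recurrence that their comments describe: keep a running maximum M and a running sum S of exponentials, and whenever a new maximum appears, rescale the old sum by exp(oldM - newM). The NumPy sketch below is for reference only (it is not part of the compiled library) and shows the recurrence on a single 1-D row:

import numpy as np

def online_softmax(x):
    # process x element by element, the way the kernels process it block by block
    M = -np.inf   # running maximum
    S = 0.0       # running sum of exp(x_i - M)
    for v in x:
        newM = max(M, v)
        S = S * np.exp(M - newM) + np.exp(v - newM)  # rescale old sum, add new term
        M = newM
    return np.exp(x - M) / S  # second pass: normalize with the final M and S

# sanity check against a direct two-pass softmax
x = np.random.randn(16).astype(np.float32)
ref = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
assert np.allclose(online_softmax(x), ref, atol=1e-5)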

test

For kernels on different platforms, the Python test scripts also take a --device argument that selects which platform's kernel to test, for example: python test_softmax.py --device mlu.

performance.py

import torch
import time
import logging
def CudaProfile(*function_with_args):
    times = 20
    for _ in range(times):
        for func, args in function_with_args:
            func(*args)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(times):
        for func, args in function_with_args:
            func(*args)
    end_event.record()
    # wait for the recorded events to complete
    torch.cuda.synchronize()
    elapsed_time = start_event.elapsed_time(end_event)  # in milliseconds
    return elapsed_time/times
def CpuProfile(*function_with_args):
    times = 20
    for _ in range(times):
        for func, args in function_with_args:
            func(*args)
    start = time.time()
    for _ in range(times):
        for func, args in function_with_args:
            func(*args)
    
    elapsed_time = time.time() - start  # elapsed time in seconds
    return 1000 * elapsed_time/times  # average per call, in milliseconds
def BangProfile(*function_with_args):
    times = 20
    for _ in range(times):
        for func, args in function_with_args:
            func(*args)
    start = time.time()
    for _ in range(times):
        for func, args in function_with_args:
            func(*args)
    
    elapsed_time = time.time() - start  # elapsed time in seconds
    return 1000 * elapsed_time/times  # average per call, in milliseconds
def logBenchmark(baseline, time):
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    msg = "Pytorch: " + str(baseline) + " ms, kernel: " + str(time) + " ms "
    percentage = "{:.2f}%".format(abs(baseline - time)/baseline * 100)
    if baseline >= time:
        logging.info(msg + "\033[32m" + "[-" + percentage + "]" +"\033[0m")
    else:
        logging.info(msg + "\033[31m" + "[+" + percentage + "]" +"\033[0m")

test_attention.py

import torch
import ctypes
import numpy as np
import torch.nn.functional as F
import argparse

import performance
# add the parent directory to the module search path
import sys
import os

# PyTorch reference implementation of attention
def funAttention(Q, K, V): 
    return torch.softmax(Q@K.t(), dim = 1)@V

lib_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.././build/lib/libmy_library.so')
lib = ctypes.CDLL(lib_path)


def test(test_shape, test_dtype, device):
    print(
        f"Testing Attention on {device} with x_shape:{test_shape} , dtype:{test_dtype}"
    )
    N, d = test_shape[0], test_shape[1]
    Q = torch.randn(test_shape, device=device, dtype=torch.float32, requires_grad=False) 
    K = torch.randn(test_shape, device=device, dtype=torch.float32, requires_grad=False)
    V = torch.randn(test_shape, device=device, dtype=torch.float32, requires_grad=False)
    # allocate the output tensor
    attHPC = torch.zeros(test_shape, device = device, dtype = torch.float32)

    Q_ptr = ctypes.cast(Q.data_ptr(), ctypes.POINTER(ctypes.c_float))
    K_ptr = ctypes.cast(K.data_ptr(), ctypes.POINTER(ctypes.c_float))
    V_ptr = ctypes.cast(V.data_ptr(), ctypes.POINTER(ctypes.c_float))
    attHPC_ptr = ctypes.cast(attHPC.data_ptr(), ctypes.POINTER(ctypes.c_float))
    if device == "cuda":
        lib.attention_nv_f32.argtypes = [
        ctypes.POINTER(ctypes.c_float),
        ctypes.POINTER(ctypes.c_float),
        ctypes.POINTER(ctypes.c_float),
        ctypes.c_int,
        ctypes.c_int,
        ctypes.POINTER(ctypes.c_float)
        ]

        torch_flash_time = performance.CudaProfile((funAttention, (Q, K, V)))
        # call the C kernel
        custom_attention_time = performance.CudaProfile((
            lib.attention_nv_f32,
            (Q_ptr, K_ptr, V_ptr, N, d, attHPC_ptr)
        ))
        performance.logBenchmark(torch_flash_time, custom_attention_time)

    # move the results back to the CPU for comparison
    tmpa = funAttention(Q, K, V).to('cpu').numpy().reshape(-1,1).flatten()
    tmpb = attHPC.to('cpu').numpy().reshape(-1,1).flatten()
    atol = max(abs(tmpa - tmpb))

    rtol = atol / max(abs(tmpb) + 1e-8)


    print("absolute error:%.4e"%(atol))
    print("relative error:%.4e"%(rtol))

# parse command-line arguments
parser = argparse.ArgumentParser(description="Test attention on different devices.")
parser.add_argument('--device', choices=['cpu', 'cuda'], required=True, help="Device to run the tests on.")
args = parser.parse_args()    

test_cases = [
        # x_shape, dtype, device
        ((128, 128), torch.float32, 'cuda'),
        ((256, 128), torch.float32, 'cuda'), 
        ((1024, 128), torch.float32, 'cuda'), 
        ((1024, 1024), torch.float32, 'cuda'), 
]
filtered_test_cases = [
    (test_shape, test_dtype, device)
    for test_shape, test_dtype, device in test_cases
    if device == args.device
]

for test_shape,test_dtype, device in filtered_test_cases:
    test(test_shape, test_dtype, device)

test_softmax.py

import torch
import ctypes
import numpy as np
import torch.nn.functional as F
import argparse

import performance
# add the parent directory to the module search path
import sys
import os

lib_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.././build/lib/libmy_library.so')
lib = ctypes.CDLL(lib_path)

def dataPrew(test_shape, test_axis):
    ndim = len(test_shape)
    dimsize = test_shape[test_axis]
    size = 1
    stride = 1
    for i in range(ndim - 1, -1, -1):
        size *= test_shape[i]
    for i in range(ndim - 1, -1, -1):
        
        if(test_axis == i):
            break
        stride *= test_shape[i]
    return size, stride, dimsize
def test(test_shape, test_axis, test_dtype, device):
    print(
        f"Testing Softmax on {device} with x_shape:{test_shape} , axis:{test_axis} ,dtype:{test_dtype}"
    )
    size, stride, dimsize = dataPrew(test_shape, test_axis)
    Q = torch.randn(test_shape, device=device, dtype=test_dtype, requires_grad=False)
    Q_output = torch.zeros(test_shape, device=device, dtype=torch.float32) 

    input_ptr = ctypes.cast(Q.data_ptr(), ctypes.POINTER(ctypes.c_float))
    output_ptr = ctypes.cast(Q_output.data_ptr(), ctypes.POINTER(ctypes.c_float))
    if device == "cuda":
        torch_softmax_time = performance.CudaProfile((torch.softmax, (Q, test_axis)))  # in milliseconds
        lib.softmax_nv_f32.argtypes = [
            ctypes.POINTER(ctypes.c_float),
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_int,
            ctypes.c_int,
            ctypes.c_int
        ]
        custom_softmax_time = performance.CudaProfile((lib.softmax_nv_f32, (input_ptr, output_ptr, size, dimsize, stride)))  # in milliseconds
    if device == "cpu":
        torch_softmax_time = performance.CpuProfile((torch.softmax, (Q, test_axis)))  # in milliseconds
        lib.softmax_cpu_f32.argtypes = [
            ctypes.POINTER(ctypes.c_float),
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_int,
            ctypes.c_int,
            ctypes.c_int
        ]
        custom_softmax_time = performance.CpuProfile((lib.softmax_cpu_f32, (input_ptr, output_ptr, size, dimsize, stride)))  # in milliseconds
    if device == "mlu":
        torch_softmax_time = performance.BangProfile((torch.softmax, (Q, test_axis)))  # in milliseconds
        ndim = len(test_shape)
        frontsize = 1
        othersize = 1
        for s in range(ndim - 1, -1, -1):
            if (s < test_axis): 
                frontsize *= test_shape[s]
            if (s != test_axis):
                othersize *= test_shape[s]
        
        lib.softmax_bang_f32.argtypes = [
            ctypes.POINTER(ctypes.c_float),
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_int,
            ctypes.c_int,
            ctypes.c_int,
            ctypes.c_int,
            ctypes.c_int,
            ctypes.c_int
        ]
        custom_softmax_time = performance.BangProfile((lib.softmax_bang_f32, (input_ptr, output_ptr, othersize, dimsize, frontsize, stride, test_axis, ndim)))  # in milliseconds
    performance.logBenchmark(torch_softmax_time, custom_softmax_time)
    # move the results back to the CPU for comparison
    tmpa = torch.softmax(Q, test_axis).to('cpu').reshape(-1,1).numpy().flatten()
    tmpb = Q_output.to('cpu').reshape(-1,1).numpy().flatten()

    atol = max(abs(tmpa - tmpb))

    rtol = atol / max(abs(tmpb) + 1e-8)


    print("absolute error:%.4e"%(atol))
    print("relative error:%.4e"%(rtol))

# parse command-line arguments
parser = argparse.ArgumentParser(description="Test softmax on different devices.")
parser.add_argument('--device', choices=['cpu', 'cuda', 'mlu'], required=True, help="Device to run the tests on.")
args = parser.parse_args()    

test_cases = [
        # x_shape, axis, dtype, device
        ((700, 1200, 24), 0, torch.float32, 'cuda'),
        ((700, 1200, 24), 1, torch.float32, 'cuda'), 
        ((700, 1200, 24), 2, torch.float32, 'cuda'), 

        ((700, 1200, 24), 0, torch.float32, 'mlu'),
        ((700, 1200, 24), 1, torch.float32, 'mlu'), 
        ((700, 1200, 24), 2, torch.float32, 'mlu'), 

        ((70, 12, 24), 0, torch.float32, 'cpu'),
        ((70, 12, 24), 1, torch.float32, 'cpu'), 
        ((70, 12, 24), 2, torch.float32, 'cpu'), 
         
]
filtered_test_cases = [
    (test_shape, test_axis, test_dtype, device)
    for test_shape, test_axis, test_dtype, device in test_cases
    if device == args.device
]
if args.device == 'mlu':
    import torch_mlu
# run the filtered test cases
for test_shape, test_axis, test_dtype, device in filtered_test_cases:
    test(test_shape, test_axis, test_dtype, device)
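
To make the index bookkeeping concrete, the sketch below (an illustration only, not part of the test scripts) reproduces the size/stride/dimsize computation from dataPrew together with the frontsize/othersize loop used in the MLU branch, evaluated for one of the listed cases, shape (700, 1200, 24) with axis = 1:

def layout(shape, axis):
    # total number of elements
    size = 1
    for s in shape:
        size *= s
    # product of the dimensions after the softmax axis (contiguous layout)
    stride = 1
    for i in range(len(shape) - 1, axis, -1):
        stride *= shape[i]
    dimsize = shape[axis]
    # product of the dimensions before the axis, and of all non-axis dimensions
    frontsize = 1
    othersize = 1
    for s in range(len(shape) - 1, -1, -1):
        if s < axis:
            frontsize *= shape[s]
        if s != axis:
            othersize *= shape[s]
    return size, stride, dimsize, frontsize, othersize

# (700, 1200, 24), axis = 1  ->  size = 20160000, stride = 24,
# dimsize = 1200, frontsize = 700, othersize = 700 * 24 = 16800
print(layout((700, 1200, 24), 1))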

CMakeLists.txt

This section focuses on writing CMakeLists.txt. Machines differ: some have only a CPU, some have both CUDA and a CPU, while domestic chips usually pair a CPU with their own accelerator; a Cambricon machine, for example, has only a CPU and an MLU. CMakeLists.txt therefore exposes option() switches so the user can choose which platform to build for. In this test framework, enabling the corresponding option on a Cambricon machine builds both the Cambricon kernels (the .mlu files) and the CPU kernels (the .cpp files).
(screenshot omitted)
Beyond that, to cope with the 30-plus operators that may eventually be added, file(GLOB) patterns automatically pick up the kernels under the src directory, so new operators do not require editing the build script. The corresponding conditional logic in CMakeLists.txt is shown below:
(screenshot omitted)

cmake_minimum_required(VERSION 3.16)

project(MyCUDAProject)

# find the Python development component (headers and libraries)
find_package(Python3 REQUIRED COMPONENTS Development)
include_directories(${Python3_INCLUDE_DIRS})

# CUDA compile options
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")

# options selecting which backend to build
option(USE_CUDA "Enable CUDA compilation" OFF)
option(USE_BANG "Enable BANG compilation" OFF)
option(USE_CPU "Enable CPU-only compilation" OFF)

# collect the source files
file(GLOB CPP_SOURCE_FILES "src/**/*.cpp")
file(GLOB CUDA_SOURCE_FILES "src/**/*.cu")
file(GLOB BANG_SOURCE_FILES "src/**/*.mlu")

# decide which sources to compile based on the options
if(USE_CUDA)
    message(STATUS "CUDA build enabled.")
    enable_language(CXX)
    enable_language(CUDA)
    list(APPEND ALL_SOURCE_FILES ${CUDA_SOURCE_FILES} ${CPP_SOURCE_FILES})
    add_library(my_library SHARED ${ALL_SOURCE_FILES}) # build the shared library
elseif(USE_BANG)
    message(STATUS "BANG build enabled.")
    
    set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}/lib")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -fPIC -std=c++11 -pthread -pipe")
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} ${CMAKE_CXX_FLAGS} -O3")
    set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS_RELEASE} -Wl,--gc-sections -fPIC")

    # check `NEUWARE_HOME` env
    if(NOT DEFINED ENV{NEUWARE_HOME})  
        set(NEUWARE_HOME "/usr/local/neuware" CACHE PATH "Path to NEUWARE installation")  
    else()  
        set(NEUWARE_HOME $ENV{NEUWARE_HOME} CACHE PATH "Path to NEUWARE installation" FORCE)  
    endif()
    message(STATUS "NEUWARE_HOME: ${NEUWARE_HOME}")
    if(EXISTS ${NEUWARE_HOME})
        include_directories("${NEUWARE_HOME}/include")
        link_directories("${NEUWARE_HOME}/lib64")
        link_directories("${NEUWARE_HOME}/lib")
        set(NEUWARE_ROOT_DIR "${NEUWARE_HOME}")
    else()
        message(FATAL_ERROR "NEUWARE directory cannot be found, refer README.md to prepare NEUWARE_HOME environment.")
    endif()

    # setup cmake search path
    set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
    "${CMAKE_SOURCE_DIR}/cmake"
    "${NEUWARE_HOME}/cmake"
    "${NEUWARE_HOME}/cmake/modules"
    )

    # include FindBANG.cmake and check cncc
    find_package(BANG)
    if(NOT BANG_FOUND)
        message(FATAL_ERROR "BANG cannot be found.")
    elseif (NOT BANG_CNCC_EXECUTABLE)
        message(FATAL_ERROR "cncc not found, please ensure cncc is in your PATH env or set variable BANG_CNCC_EXECUTABLE from cmake. Otherwise you should check path used by find_program(BANG_CNCC_EXECUTABLE) in FindBANG.cmake")
    endif()

    # setup cncc flags
    set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -fPIC -Wall -Werror -std=c++11 -pthread")
    set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -O3")
    set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS}" "--bang-mlu-arch=mtp_592")

    list(APPEND ALL_SOURCE_FILES ${BANG_SOURCE_FILES} ${CPP_SOURCE_FILES})
    bang_add_library(my_library SHARED ${ALL_SOURCE_FILES}) # build the shared library with cncc
    target_link_libraries(my_library PRIVATE cnnl cnnl_extra cnrt cndrv)
elseif(USE_CPU)
    message(STATUS "CPU-only build enabled.")
    enable_language(CXX)
    list(APPEND ALL_SOURCE_FILES ${CPP_SOURCE_FILES})
    add_library(my_library SHARED ${ALL_SOURCE_FILES}) # build the shared library
else()
    message(FATAL_ERROR "No valid compilation mode specified. Please enable USE_CUDA, USE_BANG, or USE_CPU.")
endif()




# require C++11
target_compile_features(my_library PUBLIC cxx_std_11)

# link the Python libraries
target_link_libraries(my_library PRIVATE ${Python3_LIBRARIES})

# output directory for the shared library
set_target_properties(my_library PROPERTIES
    LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/lib)

run.sh

mkdir -p build
cd build
cmake ../ -DUSE_CPU=ON
make

run.sh passes -DUSE_CPU=ON, so the CPU backend is built by default; change the option to -DUSE_CUDA=ON or -DUSE_BANG=ON to build the CUDA or MLU (Cambricon) kernels instead.
The test results on a Cambricon chip are shown below:
(screenshot of the test output omitted)


Original article: https://blog.csdn.net/forrestguang/article/details/143746091
