# Build configuration for the ggml OpenCL backend library.
find_package(OpenCL REQUIRED)

set(TARGET_NAME ggml-opencl)

ggml_add_backend_library(${TARGET_NAME}
                         ggml-opencl.cpp
                         ../../include/ggml-opencl.h)
target_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES})
target_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS})

# All preprocessor definitions below are attached to the backend target only
# (instead of directory-scoped add_compile_definitions) so they cannot leak
# into other targets defined in this directory or its subdirectories.
if (GGML_OPENCL_PROFILING)
    message(STATUS "OpenCL profiling enabled (increases CPU overhead)")
    target_compile_definitions(${TARGET_NAME} PRIVATE GGML_OPENCL_PROFILING)
endif ()

# GGML_OPENCL_TARGET_VERSION is not set in this file; it must come from the
# including project (NOTE(review): presumably a cache variable defined by the
# parent ggml CMakeLists — confirm). An empty value would otherwise expand to
# "-DGGML_OPENCL_TARGET_VERSION=" and fail later with an obscure compile error.
if ("${GGML_OPENCL_TARGET_VERSION}" STREQUAL "")
    message(FATAL_ERROR "ggml-opencl: GGML_OPENCL_TARGET_VERSION must be set")
endif ()

target_compile_definitions(${TARGET_NAME} PRIVATE
    GGML_OPENCL_SOA_Q
    GGML_OPENCL_TARGET_VERSION=${GGML_OPENCL_TARGET_VERSION})

if (GGML_OPENCL_USE_ADRENO_KERNELS)
    message(STATUS "OpenCL will use matmul kernels optimized for Adreno")
    target_compile_definitions(${TARGET_NAME} PRIVATE GGML_OPENCL_USE_ADRENO_KERNELS)
endif ()

if (GGML_OPENCL_EMBED_KERNELS)
    # Python is only used to run the kernel-embedding script, so it is only
    # required when kernels are embedded into the binary.
    find_package(Python3 REQUIRED)

    target_compile_definitions(${TARGET_NAME} PRIVATE GGML_OPENCL_EMBED_KERNELS)

    set(EMBED_KERNEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py")
    file(MAKE_DIRECTORY     "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")

    # Generated <kernel>.cl.h headers land here and are included by the backend.
    target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
endif ()

# ggml_opencl_add_kernel(KNAME)
#
# Registers the OpenCL kernel source kernels/<KNAME>.cl with the backend
# target named by TARGET_NAME (read from the calling scope).
#
#  - GGML_OPENCL_EMBED_KERNELS enabled: adds a build rule that runs
#    EMBED_KERNEL_SCRIPT with Python to generate
#    <current binary dir>/autogenerated/<KNAME>.cl.h, and adds that header to
#    the target's sources so the rule runs before the target is compiled.
#  - otherwise: copies the .cl file at configure time into
#    CMAKE_RUNTIME_OUTPUT_DIRECTORY, or into the current binary directory when
#    that variable is unset.
#
# Example: ggml_opencl_add_kernel(ggml-opencl_mm)
function(ggml_opencl_add_kernel KNAME)
    set(KERN_HDR "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h")
    set(KERN_SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl")

    if (GGML_OPENCL_EMBED_KERNELS)
        message(STATUS "opencl: embedding kernel ${KNAME}")

        # Python3_EXECUTABLE is provided by find_package(Python3) in this file.
        # VERBATIM makes argument quoting/escaping identical on all platforms.
        add_custom_command(
            OUTPUT "${KERN_HDR}"
            COMMAND "${Python3_EXECUTABLE}" "${EMBED_KERNEL_SCRIPT}" "${KERN_SRC}" "${KERN_HDR}"
            DEPENDS "${KERN_SRC}" "${EMBED_KERNEL_SCRIPT}"
            COMMENT "Generate ${KERN_HDR}"
            VERBATIM
        )

        target_sources(${TARGET_NAME} PRIVATE "${KERN_HDR}")
    else ()
        # With CMAKE_RUNTIME_OUTPUT_DIRECTORY unset, the destination would be
        # "/<KNAME>.cl", i.e. a write into the filesystem root.
        if (NOT "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}" STREQUAL "")
            set(KERN_OUT_DIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}")
        else ()
            set(KERN_OUT_DIR "${CMAKE_CURRENT_BINARY_DIR}")
        endif ()

        message(STATUS "opencl: adding kernel ${KNAME}")
        configure_file("${KERN_SRC}" "${KERN_OUT_DIR}/${KNAME}.cl" COPYONLY)
    endif ()
endfunction()

# Names of the OpenCL kernel sources under kernels/ (without the .cl suffix)
# that are registered with the backend, in registration order.
set(GGML_OPENCL_KERNELS
    ggml-opencl
    ggml-opencl_mm
    ggml-opencl_cvt
    ggml-opencl_gemv_noshuffle
    ggml-opencl_gemv_noshuffle_general
    ggml-opencl_mul_mat_Ab_Bi_8x4
    ggml-opencl_transpose_16
    ggml-opencl_transpose_32
    ggml-opencl_transpose_32_16
    ggml-opencl_im2col
)

# Hand every listed kernel to ggml_opencl_add_kernel, one call per name.
foreach (kernel_name IN LISTS GGML_OPENCL_KERNELS)
    ggml_opencl_add_kernel(${kernel_name})
endforeach ()
