/*
 * Copyright (C) 2025 Intel Corporation
 *
 * SPDX-License-Identifier: MIT
 *
 */

#include "shared/source/command_stream/preemption.h"
#include "shared/source/command_stream/preemption.inl"
#include "shared/source/xe3p_core/hw_cmds_base.h"

namespace NEO {

// Hardware family targeted by every PreemptionHelper specialization and explicit instantiation in this file.
// NOTE(review): presumably also relied on by the preemption_xe2_and_later.inl included further down - confirm.
using GfxFamily = Xe3pCoreFamily;

template <>
void PreemptionHelper::programInterfaceDescriptorDataPreemption<GfxFamily>(INTERFACE_DESCRIPTOR_DATA<GfxFamily> *idd, PreemptionMode preemptionMode) {
    // The descriptor's thread-preemption bit is set only for mid-thread preemption;
    // any other preemption mode explicitly clears it.
    const bool enableThreadPreemption = (preemptionMode == PreemptionMode::MidThread);
    idd->setThreadPreemption(enableThreadPreemption);
}

template <>
void PreemptionHelper::programInterfaceDescriptorDataPreemption<GfxFamily>(GfxFamily::INTERFACE_DESCRIPTOR_DATA_2 *idd, PreemptionMode preemptionMode) {
    // Same rule as for INTERFACE_DESCRIPTOR_DATA: thread preemption is on only for
    // mid-thread preemption and is explicitly turned off for every other mode.
    const bool enableThreadPreemption = (preemptionMode == PreemptionMode::MidThread);
    idd->setThreadPreemption(enableThreadPreemption);
}

template <>
void PreemptionHelper::programCsrBaseAddressCmd<GfxFamily>(LinearStream &preambleCmdStream, const GraphicsAllocation *preemptionCsr) {
    using STATE_CONTEXT_DATA_BASE_ADDRESS = typename GfxFamily::STATE_CONTEXT_DATA_BASE_ADDRESS;

    // Reserve room for the command in the preamble stream first.
    auto *const cmdSpace = preambleCmdStream.getSpaceForCmd<STATE_CONTEXT_DATA_BASE_ADDRESS>();

    // Start from the family's default-initialized command and point it at the preemption CSR.
    // Callers must pass a non-null preemptionCsr; it is dereferenced unconditionally.
    STATE_CONTEXT_DATA_BASE_ADDRESS cmdToProgram = GfxFamily::cmdInitStateContextDataBaseAddress;
    const auto csrGpuAddress = preemptionCsr->getGpuAddressToPatch();
    cmdToProgram.setContextDataBaseAddress(csrGpuAddress);

    // Copy the fully built command into the reserved stream space.
    *cmdSpace = cmdToProgram;
}

template <>
void PreemptionHelper::programCsrBaseAddress<GfxFamily>(LinearStream &preambleCmdStream, Device &device, const GraphicsAllocation *preemptionCsr) {
    bool debuggingEnabled = device.getDebugger() != nullptr;
    bool isMidThreadPreemption = device.getPreemptionMode() == PreemptionMode::MidThread;
    if (isMidThreadPreemption || debuggingEnabled) {
        programCsrBaseAddressCmd<GfxFamily>(preambleCmdStream, preemptionCsr);
    }
}

template <>
size_t PreemptionHelper::getRequiredPreambleSize<GfxFamily>(const Device &device) {
    using STATE_CONTEXT_DATA_BASE_ADDRESS = typename GfxFamily::STATE_CONTEXT_DATA_BASE_ADDRESS;
    bool debuggingEnabled = device.getDebugger() != nullptr;
    if ((device.getPreemptionMode() == PreemptionMode::MidThread) || debuggingEnabled) {
        return sizeof(STATE_CONTEXT_DATA_BASE_ADDRESS);
    }
    return 0u;
}

#include "shared/source/command_stream/preemption_xe2_and_later.inl"

// Explicit instantiations of the generic PreemptionHelper templates for GfxFamily.
// An explicit instantiation definition requires the template's definition to be visible here,
// so these depend on the preemption .inl files included above.
template void PreemptionHelper::programStateSip<GfxFamily>(LinearStream &preambleCmdStream, Device &device, OsContext *context);
template size_t PreemptionHelper::getRequiredStateSipCmdSize<GfxFamily>(Device &device, bool isRcs);
template void PreemptionHelper::programStateSipCmd<GfxFamily>(LinearStream &preambleCmdStream, GraphicsAllocation *sipAllocation, bool useFullAddress);
} // namespace NEO
