All Articles

カスタムプロバイダーで AMSI スキャン要求を処理する

今回は AMSI の概要と動作についてカスタムアプリケーションから AMSI スキャン要求を発行する に引き続き、AMSI インターフェースを使用したスキャン要求を処理するプロバイダーのサンプルコードを読んでいくことにします。

参考:Windows-classic-samples/Samples/AmsiProvider at main · microsoft/Windows-classic-samples

もくじ

プロバイダーの登録

AmsiProvider のサンプルコードをビルドした AmsiProvider.dll は regsvr32 AmsiProvider.dll コマンドでシステムに登録できます。

このコマンドは、DllRegisterServer 関数を呼び出します。

参考:DllRegisterServer の実装 - Win32 apps | Microsoft Learn

サンプルコードでは DllRegisterServer 関数は以下のように実装されており、AMSI プロバイダーとして DLL を登録するために必要ないくつかのレジストリ操作を行っています。

HRESULT SetKeyStringValue(_In_ HKEY key, _In_opt_ PCWSTR subkey, _In_opt_ PCWSTR valueName, _In_ PCWSTR stringValue)
{
    LONG status = RegSetKeyValue(key, subkey, valueName, REG_SZ, stringValue, (wcslen(stringValue) + 1) * sizeof(wchar_t));
    return HRESULT_FROM_WIN32(status);
}


STDAPI DllRegisterServer()
{
    wchar_t modulePath[MAX_PATH];
    if (GetModuleFileName(g_currentModule, modulePath, ARRAYSIZE(modulePath)) >= ARRAYSIZE(modulePath))
    {
        return E_UNEXPECTED;
    }

    // Create a standard COM registration for our CLSID.
    // The class must be registered as "Both" threading model
    // and support multithreaded access.
    wchar_t clsidString[40];
    if (StringFromGUID2(__uuidof(SampleAmsiProvider), clsidString, ARRAYSIZE(clsidString)) == 0)
    {
        return E_UNEXPECTED;
    }

    wchar_t keyPath[200];
    HRESULT hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Classes\\CLSID\\%ls", clsidString);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, nullptr, L"SampleAmsiProvider");
    if (FAILED(hr)) return hr;

    hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Classes\\CLSID\\%ls\\InProcServer32", clsidString);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, nullptr, modulePath);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, L"ThreadingModel", L"Both");
    if (FAILED(hr)) return hr;

    // Register this CLSID as an anti-malware provider.
    hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Microsoft\\AMSI\\Providers\\%ls", clsidString);
    if (FAILED(hr)) return hr;

    hr = SetKeyStringValue(HKEY_LOCAL_MACHINE, keyPath, nullptr, L"SampleAmsiProvider");
    if (FAILED(hr)) return hr;

    return S_OK;
}

この関数は、以下のレジストリキー操作を行います。

  1. サンプルプロバイダーの UUID である {2E5D8A62-77F9-4F7B-A90C-2744820139B2} を使用し、HKEY_LOCAL_MACHINE\SOFTWARE\Classes\CLSID\{2E5D8A62-77F9-4F7B-A90C-2744820139B2} に「SampleAmsiProvider」というプロバイダー名を登録します。
  2. HKEY_LOCAL_MACHINE\SOFTWARE\Classes\CLSID\{2E5D8A62-77F9-4F7B-A90C-2744820139B2}\InProcServer32 を追加し、プロバイダー DLL のファイルパスを登録します。
  3. さらに、同キーの直下に文字列型の ThreadingModel を追加し、値を「Both」に設定します。

image-20250701163500895

  1. 最後に、HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\AMSI\Providers 直下にサンプルプロバイダーの UUID をキーとして登録し、プロバイダー名「SampleAmsiProvider」を登録します。

image-20250701163932598

こうして登録したプロバイダーのアンインストールを行う場合は、regsvr32 /u AmsiProvider.dll コマンドを実行します。

DllUnregisterServer 関数は、上記で作成したレジストリキーの削除を担当します。

STDAPI DllUnregisterServer()
{
    wchar_t clsidString[40];
    if (StringFromGUID2(__uuidof(SampleAmsiProvider), clsidString, ARRAYSIZE(clsidString)) == 0)
    {
        return E_UNEXPECTED;
    }

    // Unregister this CLSID as an anti-malware provider.
    wchar_t keyPath[200];
    HRESULT hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Microsoft\\AMSI\\Providers\\%ls", clsidString);
    if (FAILED(hr)) return hr;
    LONG status = RegDeleteTree(HKEY_LOCAL_MACHINE, keyPath);
    if (status != NO_ERROR && status != ERROR_PATH_NOT_FOUND) return HRESULT_FROM_WIN32(status);

    // Unregister this CLSID as a COM server.
    hr = StringCchPrintf(keyPath, ARRAYSIZE(keyPath), L"Software\\Classes\\CLSID\\%ls", clsidString);
    if (FAILED(hr)) return hr;
    status = RegDeleteTree(HKEY_LOCAL_MACHINE, keyPath);
    if (status != NO_ERROR && status != ERROR_PATH_NOT_FOUND) return HRESULT_FROM_WIN32(status);

    return S_OK;
}

DllMain 関数

まずは、DLL の読み込みまたはアンロード時に呼び出されるエントリポイントとなる DllMain 関数から見ていきます。

参考:DllMain エントリ ポイント (Process.h) - Win32 apps | Microsoft Learn

BOOL APIENTRY DllMain(HMODULE module, DWORD reason, LPVOID reserved)
{
    switch (reason)
    {
    case DLL_PROCESS_ATTACH:
        g_currentModule = module;
        DisableThreadLibraryCalls(module);
        TraceLoggingRegister(g_traceLoggingProvider);
        TraceLoggingWrite(g_traceLoggingProvider, "Loaded");
        Module<InProc>::GetModule().Create();
        break;

    case DLL_PROCESS_DETACH:
        Module<InProc>::GetModule().Terminate();
        TraceLoggingWrite(g_traceLoggingProvider, "Unloaded");
        TraceLoggingUnregister(g_traceLoggingProvider);
        break;
    }
    return TRUE;
}

この DLL のロード時にはまず、受け取った DLL モジュールへのハンドルを引数として DisableThreadLibraryCalls 関数を呼び出し、DLL_THREAD_ATTACHDLL_THREAD_DETACH の通知を無効化しています。

これにより、プロセス内でスレッドが生成/終了するたびに不要な呼び出しが発生することを防ぐ利点があるようです。

参考:DisableThreadLibraryCalls 関数 (libloaderapi.h) - Win32 apps | Microsoft Learn

続いて、TraceLoggingRegister 関数により TraceLogging プロバイダーの登録を行います。

これは、ETW トレースロギングを目的としており、このとき引数として与えられた g_traceLoggingProvider にてプロバイダー名が SampleAmsiProvider、GUID が 00604c86-2d25-46d6-b814-cd149bfdf0b3 である プロバイダーが登録されます。

// Define a trace logging provider: 00604c86-2d25-46d6-b814-cd149bfdf0b3
TRACELOGGING_DEFINE_PROVIDER(g_traceLoggingProvider, "SampleAmsiProvider",
    (0x00604c86, 0x2d25, 0x46d6, 0xb8, 0x14, 0xcd, 0x14, 0x9b, 0xfd, 0xf0, 0xb3));

参考:TraceLoggingRegister - Win32 apps | Microsoft Learn

その後、TraceLoggingWrite(g_traceLoggingProvider, "Loaded"); によりイベント名が「Loaded」となるイベントの出力を行います。

参考:TraceLoggingWrite マクロ (traceloggingprovider.h) - Win32 apps | Microsoft Learn

最後に、Module<InProc>::GetModule().Create(); が呼び出されます。

これについては詳しい動作はあまり理解できていないのですが、COM DLL 用に WRL が提供しているモジュールを管理するクラスである Microsoft::WRL::Module を使用してインスタンスの準備を行っているようです。

参考:方法: WRL を使用して従来の COM コンポーネントを作成する | Microsoft Learn

参考:Module クラス | Microsoft Learn

SampleAmsiProvider のオーバーライド

AMSI プロバイダーは実際のところホストにロードされた DLL を実体とする COM サーバであり、IUnknown インターフェースを継承する IAntimalwareProvider インターフェースを公開しています。

参考:IAntimalwareProvider (amsi.h) - Win32 apps | Microsoft Learn

class
    DECLSPEC_UUID("2E5D8A62-77F9-4F7B-A90C-2744820139B2")
    SampleAmsiProvider : public RuntimeClass<RuntimeClassFlags<ClassicCom>, IAntimalwareProvider, FtmBase>
{
public:
    IFACEMETHOD(Scan)(_In_ IAmsiStream* stream, _Out_ AMSI_RESULT* result) override;
    IFACEMETHOD_(void, CloseSession)(_In_ ULONGLONG session) override;
    IFACEMETHOD(DisplayName)(_Outptr_ LPWSTR* displayName) override;

private:
    // We assign each Scan request a unique number for logging purposes.
    LONG m_requestNumber = 0;
};

AMSI プロバイダーの開発者は、IAntimalwareProvider インターフェースの Scan メソッドをオーバーライドすることで独自のコンテンツスキャンを実行できるようになります。

DisplayName メソッドのオーバーライド

AntimalwareProvider::DisplayName メソッドは、単に AMSI プロバイダーの名前を返すメソッドです。

参考:IAntimalwareProvider::DisplayName (amsi.h) - Win32 apps | Microsoft Learn

SampleAmsiProvider では以下の通りオーバーライドされており、Sample AMSI Provider という名前を返すように実装されています。

HRESULT SampleAmsiProvider::DisplayName(_Outptr_ LPWSTR *displayName)
{
    *displayName = const_cast<LPWSTR>(L"Sample AMSI Provider");
    return S_OK;
}

CloseSession メソッドのオーバーライド

IAntimalwareProvider::CloseSession メソッドは AMSI セッションを閉じるメソッドとして定義されていますが、今回のサンプルの場合は単に ETW トレースセッションに Close session を出力する処理のみが定義されています。

参考:IAntimalwareProvider::CloseSession (amsi.h) - Win32 apps | Microsoft Learn

恐らく AMSI スキャン時などに何らかのリソース操作を行う場合などはここでクリーンアップが必要になるのだと思います。

void SampleAmsiProvider::CloseSession(_In_ ULONGLONG session)
{
    TraceLoggingWrite(g_traceLoggingProvider, "Close session",
        TraceLoggingValue(session));
}

Scan メソッドのオーバーライド

コンテンツストリームのスキャンを行う IAntimalwareProvider::Scan メソッドは、今回のサンプルでは以下のようにオーバーライドされています。

参考:IAntimalwareProvider::Scan (amsi.h) - Win32 apps | Microsoft Learn

Scan メソッドといいつつ、実際のところこのサンプルでは受け取ったデータの全てのバイトを XOR 演算した結果を ETW トレースイベントとして出力するだけのメソッドとして実装されており、すべてのスキャン要求に対して AMSI_RESULT_NOT_DETECTED を返します。

HRESULT SampleAmsiProvider::Scan(_In_ IAmsiStream* stream, _Out_ AMSI_RESULT* result)
{
    LONG requestNumber = InterlockedIncrement(&m_requestNumber);
    TraceLoggingWrite(g_traceLoggingProvider, "Scan Start", TraceLoggingValue(requestNumber));

    auto appName = GetStringAttribute(stream, AMSI_ATTRIBUTE_APP_NAME);
    auto contentName = GetStringAttribute(stream, AMSI_ATTRIBUTE_CONTENT_NAME);
    auto contentSize = GetFixedSizeAttribute<ULONGLONG>(stream, AMSI_ATTRIBUTE_CONTENT_SIZE);
    auto session = GetFixedSizeAttribute<ULONGLONG>(stream, AMSI_ATTRIBUTE_SESSION);
    auto contentAddress = GetFixedSizeAttribute<PBYTE>(stream, AMSI_ATTRIBUTE_CONTENT_ADDRESS);

    TraceLoggingWrite(g_traceLoggingProvider, "Attributes",
        TraceLoggingValue(requestNumber),
        TraceLoggingWideString(appName.Get(), "App Name"),
        TraceLoggingWideString(contentName.Get(), "Content Name"),
        TraceLoggingUInt64(contentSize, "Content Size"),
        TraceLoggingUInt64(session, "Session"),
        TraceLoggingPointer(contentAddress, "Content Address"));

    if (contentAddress)
    {
        // The data to scan is provided in the form of a memory buffer.
        auto result = CalculateBufferXor(contentAddress, contentSize);
        TraceLoggingWrite(g_traceLoggingProvider, "Memory xor",
            TraceLoggingValue(requestNumber),
            TraceLoggingValue(result));
    }
    else
    {
        // Provided as a stream. Read it stream a chunk at a time.
        BYTE cumulativeXor = 0;
        BYTE chunk[1024];
        ULONG readSize;
        for (ULONGLONG position = 0; position < contentSize; position += readSize)
        {
            HRESULT hr = stream->Read(position, sizeof(chunk), chunk, &readSize);
            if (SUCCEEDED(hr))
            {
                cumulativeXor ^= CalculateBufferXor(chunk, readSize);
                TraceLoggingWrite(g_traceLoggingProvider, "Read chunk",
                    TraceLoggingValue(requestNumber),
                    TraceLoggingValue(position),
                    TraceLoggingValue(readSize),
                    TraceLoggingValue(cumulativeXor));
            }
            else
            {
                TraceLoggingWrite(g_traceLoggingProvider, "Read failed",
                    TraceLoggingValue(requestNumber),
                    TraceLoggingValue(position),
                    TraceLoggingValue(hr));
                break;
            }
        }
    }

    TraceLoggingWrite(g_traceLoggingProvider, "Scan End", TraceLoggingValue(requestNumber));

    // AMSI_RESULT_NOT_DETECTED means "We did not detect a problem but let other providers scan it, too."
    *result = AMSI_RESULT_NOT_DETECTED;
    return S_OK;
}

このメソッドでは初めに InterlockedIncrement 関数を使用して m_requestNumber を安全にインクリメントした後に Scan Start というログを出力します。

m_requestNumber はロギング目的でスキャンリクエストを識別するためにサンプルプロバイダーで独自に追加されている変数です。

参考:InterlockedIncrement 関数 (winnt.h) - Win32 apps | Microsoft Learn

LONG requestNumber = InterlockedIncrement(&m_requestNumber);
TraceLoggingWrite(g_traceLoggingProvider, "Scan Start", TraceLoggingValue(requestNumber));

続く以下のコードでは、GetStringAttribute 関数と GetFixedSizeAttribute 関数を用いて、受け取ったストリームから AMSI_ATTRIBUTE_APP_NAME などの各種属性情報を取得し、ログに出力する操作を行います。

auto appName = GetStringAttribute(stream, AMSI_ATTRIBUTE_APP_NAME);
auto contentName = GetStringAttribute(stream, AMSI_ATTRIBUTE_CONTENT_NAME);
auto contentSize = GetFixedSizeAttribute<ULONGLONG>(stream, AMSI_ATTRIBUTE_CONTENT_SIZE);
auto session = GetFixedSizeAttribute<ULONGLONG>(stream, AMSI_ATTRIBUTE_SESSION);
auto contentAddress = GetFixedSizeAttribute<PBYTE>(stream, AMSI_ATTRIBUTE_CONTENT_ADDRESS);

TraceLoggingWrite(g_traceLoggingProvider, "Attributes",
    TraceLoggingValue(requestNumber),
    TraceLoggingWideString(appName.Get(), "App Name"),
    TraceLoggingWideString(contentName.Get(), "Content Name"),
    TraceLoggingUInt64(contentSize, "Content Size"),
    TraceLoggingUInt64(session, "Session"),
    TraceLoggingPointer(contentAddress, "Content Address"));

GetStringAttribute 関数と GetFixedSizeAttribute 関数は、いずれも IAmsiStream インターフェースの GetAttribute メソッドを使用して AMSI スキャン要求のメタデータを取得するヘルパー関数のようです。

参考:IAmsiStream::GetAttribute (amsi.h) - Win32 apps | Microsoft Learn

IAmsiStream インターフェースの GetAttribute メソッドは、第 1 引数の AMSI_ATTRIBUTE で指定された属性情報を返す必要があります。

まず、GetStringAttribute 関数では以下の通り実装されており、AMSI_ATTRIBUTE_APP_NAMEAMSI_ATTRIBUTE_CONTENT_NAME の情報を取得します。

GetAttribute 関数は受け取った出力バッファのサイズが十分な大きさでない場合に E_NOT_SUFFICIENT_BUFFER を返し、allocSize に必要なバイトサイズが返されるため、これを利用して再度 GetAttribute 関数を発行することで必要な属性を取得します。

HeapMemPtr<wchar_t> GetStringAttribute(_In_ IAmsiStream* stream, _In_ AMSI_ATTRIBUTE attribute)
{
    HeapMemPtr<wchar_t> result;

    ULONG allocSize;
    ULONG actualSize;
    if (stream->GetAttribute(attribute, 0, nullptr, &allocSize) == E_NOT_SUFFICIENT_BUFFER &&
        SUCCEEDED(result.Alloc(allocSize)) &&
        SUCCEEDED(stream->GetAttribute(attribute, allocSize, reinterpret_cast<PBYTE>(result.Get()), &actualSize)) &&
        actualSize <= allocSize)
    {
        return result;
    }
    return HeapMemPtr<wchar_t>();
}

前回記事 で作成したサンプルクライアントプログラムを使用してメモリスキャン要求を発行してみると、確かに AppName としてサンプルプログラムにハードコードされた Contoso Script Engine v3.4.9999.0 が返されており、Hello, World というスキャン文字列のアドレスを取得していることを確認できます。

image-20250705223137319

一方で、GetFixedSizeAttribute 関数は以下の通り実装されています。

この関数では、呼び出し時に指定された型に合わせて AMSI_ATTRIBUTE_CONTENT_SIZEAMSI_ATTRIBUTE_SESSIONAMSI_ATTRIBUTE_CONTENT_ADDRESS を取得します。

template<typename T>
T GetFixedSizeAttribute(_In_ IAmsiStream* stream, _In_ AMSI_ATTRIBUTE attribute)
{
    T result;

    ULONG actualSize;
    if (SUCCEEDED(stream->GetAttribute(attribute, sizeof(T), reinterpret_cast<PBYTE>(&result), &actualSize)) &&
        actualSize == sizeof(T))
    {
        return result;
    }
    return T();
}

コンテンツがメモリに完全に読み込まれている場合のメモリアドレスであるAMSI_ATTRIBUTE_CONTENT_ADDRESS の属性を取得した場合には、CalculateBufferXor 関数を使用してメモリコンテンツのすべてのバイトの XOR を計算し、ログ出力を行っています。

BYTE CalculateBufferXor(_In_ LPCBYTE buffer, _In_ ULONGLONG size)
{
    BYTE value = 0;
    for (ULONGLONG i = 0; i < size; i++)
    {
        value ^= buffer[i];
    }
    return value;
}
if (contentAddress)
{
    // The data to scan is provided in the form of a memory buffer.
    auto result = CalculateBufferXor(contentAddress, contentSize);
    TraceLoggingWrite(g_traceLoggingProvider, "Memory xor",
        TraceLoggingValue(requestNumber),
        TraceLoggingValue(result));
}

スキャン対象がファイルの場合には、Read メソッドを使用して 1024 バイトずつ取得したコンテンツに対して、同じく CalculateBufferXor 関数による計算を行っています。

参考:IAmsiStream::Read (amsi.h) - Win32 apps | Microsoft Learn

else
{
    // Provided as a stream. Read it stream a chunk at a time.
    BYTE cumulativeXor = 0;
    BYTE chunk[1024];
    ULONG readSize;
    for (ULONGLONG position = 0; position < contentSize; position += readSize)
    {
        HRESULT hr = stream->Read(position, sizeof(chunk), chunk, &readSize);
        if (SUCCEEDED(hr))
        {
            cumulativeXor ^= CalculateBufferXor(chunk, readSize);
            TraceLoggingWrite(g_traceLoggingProvider, "Read chunk",
                TraceLoggingValue(requestNumber),
                TraceLoggingValue(position),
                TraceLoggingValue(readSize),
                TraceLoggingValue(cumulativeXor));
        }
        else
        {
            TraceLoggingWrite(g_traceLoggingProvider, "Read failed",
                TraceLoggingValue(requestNumber),
                TraceLoggingValue(position),
                TraceLoggingValue(hr));
            break;
        }
    }
}

このように、このサンプルプロバイダーにおける Scan メソッドは非常にシンプルな実装でした。

ETW トレースログの取得と解析

このサンプルプロバイダーでは、GUID {00604c86-2d25-46d6-b814-cd149bfdf0b3} にて ETW トレースログの出力を行っています。

ETW トレースログの取得は、logman や traceview.exe など様々なツールで行うことができます。

Windows WDK に含まれる traceview.exe を使用すると、GUI で簡単に ETW トレースログの参照や ETL ファイルとしての保存を行うことができるので便利です。

WDK をインストール済みの場合、C:\Program Files (x86)\Windows Kits\10\bin\<バージョン>\x64\traceview.exe などのパスから traceview.exe を開くことができます。

image-20250705232855416

また、logman や traceview.exe などのツールを使用して ETW トレースログを ETL ファイルとして保存済みの場合、netsh を使用して以下のコマンドでテキストファイルに変換して参照することもできます。

netsh trace convert input=logging.etl output=decode.txt

traceview.exe はイベントデータのフィルタリングや検索には不向きなので、特定のイベントを詳しく解析したい場合にはテキストファイルにデコードすると良いと思います。

まとめ

今回は AMSI プロバイダーのサンプルコードを読んでみました。

もっと詳しい動作のデバッグなどはまた今後の記事で書こうと思います。