AKARI Tech Blog

燈株式会社のエンジニア・開発メンバーによる技術ブログです

DeepSeek-V3(FP8量子化)を自社サーバーのH200 8枚で動かしてみた

こんにちは、DX Solution 事業本部 VPoEの丸尾です。

先日、CTO三澤もツイートしておりましたが、燈の遠隔データセンターにてH200を8枚搭載した強力なサーバーを導入しました 🎉🎉

圧倒的なGPUメモリ量ですね! せっかくなので何か最新モデルの推論を動かしたいと思い、DeepSeek-V3(FP8量子化)を試してみました!

DeepSeek-V3 (FP8量子化)を動かしてみる

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config

# transformersのバージョンとモデルの互換性がないためのモンキーパッチ
from transformers.models.llama.modeling_llama import DynamicCache
DynamicCache.get_max_length = lambda self: None

model_name = "deepseek-ai/DeepSeek-V3"

# FP8 設定を有効化
quant_config = FineGrainedFP8Config()

device = "cuda" if torch.cuda.is_available() else "cpu"

# モデルをロード
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

# トークナイザーのロード
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 推論用プロンプト
prompt = "むかしむかし、あるところに"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# 推論の実行
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)

# 結果の表示
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Prompt:", prompt)
print("Completion:", generated_text)

こちらを実行すると、まず約600GBのモデルファイルのダウンロードが始まりました。 すぐにモデルを動かすのは難しいので、深夜のうちにダウンロードしておきましょう。

夜が明けると

Prompt: むかしむかし、あるところに
>>> print("Completion:", generated_text)
Completion: むかしむかし、あるところに、おじいさんとおばあさんが住んでいました。

おじいさんは山へ柴刈りに、おばあさんは川へ洗濯に行きました。

翌朝確認すると、無事に出力されていました 🎉🎉

nvidia-smi でH200のGPUが並列で動作している様子を見るのはさすがに圧巻ですね。

さらに、ngrokでサーバーを公開し、Slack Botとして使えるようにしましょう。

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config
from transformers.models.llama.modeling_llama import DynamicCache
DynamicCache.get_max_length = lambda self: None

import json
import urllib.request
import urllib.parse
from http.server import BaseHTTPRequestHandler, HTTPServer
import threading

# モデルの読み込み設定
MODEL_DIR = "/models/deepseek-ai/DeepSeek-V3"
if not os.path.exists(MODEL_DIR):
    print("モデルが見つかりません。ダウンロードを開始します...")
    model_name = "deepseek-ai/DeepSeek-V3"
else:
    print(f"既存のモデルを使用します: {MODEL_DIR}")
    model_name = MODEL_DIR

quant_config = FineGrainedFP8Config()
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

# バックグラウンド処理でモデル生成を実行し、結果を response_url に送信する関数
def background_generate(prompt, response_url):
    print("バックグラウンドで生成処理開始:", prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # response_url を用いて生成結果を送信
    payload = {
        "response_type": "in_channel",
        "text": generated_text
    }
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(url=response_url, data=data, headers={"Content-Type": "application/json"}, method="POST")
    try:
        with urllib.request.urlopen(req) as response:
            result = response.read().decode("utf-8")
            print("response_urlへの送信結果:", result)
    except Exception as e:
        print("response_urlへの送信エラー:", e)

# Slack の Slash Command 用 HTTP ハンドラー
class SlackCommandHandler(BaseHTTPRequestHandler):
    def _set_response(self, status=200):
        self.send_response(status)
        self.send_header("Content-type", "application/json")
        self.end_headers()

    def do_POST(self):
        # POSTデータの長さを取得して読み込む
        content_length = int(self.headers.get("Content-Length", 0))
        post_data = self.rfile.read(content_length)

        # Slash Command は application/x-www-form-urlencoded 形式で送信されるため解析する
        parsed_data = urllib.parse.parse_qs(post_data.decode("utf-8"))
        command_text = parsed_data.get("text", [""])[0]
        response_url = parsed_data.get("response_url", [""])[0]

        user_id = parsed_data.get("user_id", [""])[0]
        user_name = parsed_data.get("user_name", [""])[0]

        print("受信したコマンドテキスト:", command_text)
        print("response_url:", response_url)
        print("ユーザーID:", user_id)
        print("ユーザー名:", user_name)

        # Slack にはすぐに受付完了の応答を返す
        ack_payload = {
            "response_type": "in_channel", 
            "text": f"Command from: {user_name} (ID: {user_id})\nprompt: {command_text}",
        }
        self._set_response(200)
        self.wfile.write(json.dumps(ack_payload).encode("utf-8"))

        # バックグラウンドスレッドで生成処理を実行
        prompt = command_text if command_text else "むかしむかし、あるところに"
        thread = threading.Thread(target=background_generate, args=(prompt, response_url))
        thread.start()

def run(server_class=HTTPServer, handler_class=SlackCommandHandler, port=8080):
    server_address = ("", port)
    httpd = server_class(server_address, handler_class)
    print(f"ポート {port} でサーバー開始...")
    httpd.serve_forever()

if __name__ == "__main__":
    run()

みんなで遊んでみた

ストレスのない推論速度で快適に動作しています。 短い出力で意図通りに回答させたい場合、「ズバリ」と入力すると結論から答えてくれる、というハックが社内で生まれていました。

We're Hiring!

燈では、自社のGPUサーバーを活用してLLMのモデルを動かしています。 大規模な学習に興味のある方、ぜひカジュアル面談でお話ししましょう!

akariinc.co.jp