Feat: Add image generation streaming support

This commit is contained in:
2025-11-15 18:54:53 -08:00
parent 71b38dda0c
commit 37a53e17de

View File

@@ -5,7 +5,7 @@ const TTL_MS = 20 * 60 * 1000;
const BATCH_MS = 800; const BATCH_MS = 800;
const BATCH_BYTES = 3400; const BATCH_BYTES = 3400;
const HB_INTERVAL_MS = 3000; const HB_INTERVAL_MS = 3000;
const MAX_RUN_MS = 10 * 60 * 1000; const MAX_RUN_MS = 7 * 60 * 1000;
const CORS_HEADERS = { const CORS_HEADERS = {
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
@@ -86,7 +86,7 @@ export class MyDurableObject {
getConversationText() { getConversationText() {
const prompt = (this.messages || []).map(m => `## ${m.role}\n\n${this.extractTextFromMessage(m)}`).join('\n\n---\n\n'); const prompt = (this.messages || []).map(m => `## ${m.role}\n\n${this.extractTextFromMessage(m)}`).join('\n\n---\n\n');
const response = this.buffer.map(it => it.text || '').join(''); const response = this.buffer.map(it => it.text).join('');
if (!prompt && !response) return ''; if (!prompt && !response) return '';
return `${prompt}\n\n---\n\n## assistant\n\n${response}`; return `${prompt}\n\n---\n\n## assistant\n\n${response}`;
} }
@@ -151,12 +151,10 @@ export class MyDurableObject {
flush(force = false) { flush(force = false) {
if (this.flushTimer) { clearTimeout(this.flushTimer); this.flushTimer = null; } if (this.flushTimer) { clearTimeout(this.flushTimer); this.flushTimer = null; }
if (this.pending || this.pendingImages.length > 0) { if (this.pending || this.pendingImages.length) {
const payload = { type: 'delta', seq: ++this.seq }; const imgs = this.pendingImages.length ? [...this.pendingImages] : undefined;
if (this.pending) payload.text = this.pending; this.buffer.push({ seq: ++this.seq, text: this.pending, images: imgs });
if (this.pendingImages.length > 0) payload.images = this.pendingImages; this.bcast({ type: 'delta', seq: this.seq, text: this.pending, images: imgs });
this.buffer.push(payload);
this.bcast(payload);
this.pending = ''; this.pending = '';
this.pendingImages = []; this.pendingImages = [];
this.lastFlushedAt = Date.now(); this.lastFlushedAt = Date.now();
@@ -164,19 +162,13 @@ export class MyDurableObject {
if (force) this.saveSnapshot(); if (force) this.saveSnapshot();
} }
queueDelta(text) { queueDelta(text, images) {
if (!text) return; if (text) this.pending += text;
this.pending += text; if (images && images.length) this.pendingImages.push(...images);
if (this.pending.length >= BATCH_BYTES) this.flush(false); if (this.pending.length >= BATCH_BYTES || this.pendingImages.length) this.flush(false);
else if (!this.flushTimer) this.flushTimer = setTimeout(() => this.flush(false), BATCH_MS); else if (!this.flushTimer) this.flushTimer = setTimeout(() => this.flush(false), BATCH_MS);
} }
queueImages(images) {
if (!Array.isArray(images) || images.length === 0) return;
this.pendingImages.push(...images);
if (!this.flushTimer) this.flushTimer = setTimeout(() => this.flush(false), BATCH_MS);
}
async fetch(req) { async fetch(req) {
if (req.method === 'OPTIONS') return new Response(null, { status: 204, headers: CORS_HEADERS }); if (req.method === 'OPTIONS') return new Response(null, { status: 204, headers: CORS_HEADERS });
@@ -191,12 +183,10 @@ export class MyDurableObject {
if (req.method === 'GET') { if (req.method === 'GET') {
await this.autopsy(); await this.autopsy();
const text = this.buffer.map(it => it.text || '').join('') + this.pending; const text = this.buffer.map(it => it.text).join('') + this.pending;
const images = this.buffer.flatMap(it => it.images || []);
if (this.pendingImages.length > 0) images.push(...this.pendingImages);
const isTerminal = ['done', 'error', 'evicted'].includes(this.phase); const isTerminal = ['done', 'error', 'evicted'].includes(this.phase);
const isError = ['error', 'evicted'].includes(this.phase); const isError = ['error', 'evicted'].includes(this.phase);
const payload = { rid: this.rid, seq: this.seq, phase: this.phase, done: isTerminal, error: isError ? (this.error || 'The run was terminated unexpectedly.') : null, text, images }; const payload = { rid: this.rid, seq: this.seq, phase: this.phase, done: isTerminal, error: isError ? (this.error || 'The run was terminated unexpectedly.') : null, text };
return this.corsJSON(payload); return this.corsJSON(payload);
} }
return this.corsJSON({ error: 'not allowed' }, 405); return this.corsJSON({ error: 'not allowed' }, 405);
@@ -356,8 +346,9 @@ export class MyDurableObject {
this.queueDelta(delta.content); this.queueDelta(delta.content);
hasContent = true; hasContent = true;
} }
if (delta?.images) { const msg = chunk?.choices?.[0]?.message;
this.queueImages(delta.images); if (msg?.images && msg.images.length) {
this.queueDelta('', msg.images.map(img => ({ type: 'image_url', image_url: { url: img.image_url?.url || img.image_url } })));
} }
} }
} }
@@ -370,7 +361,8 @@ export class MyDurableObject {
try { this.controller?.abort(); } catch {} try { this.controller?.abort(); } catch {}
try { this.oaStream?.controller?.abort(); } catch {} try { this.oaStream?.controller?.abort(); } catch {}
this.saveSnapshot(); this.saveSnapshot();
this.bcast({ type: 'done' }); const finalImages = this.buffer.flatMap(b => b.images || []);
this.bcast({ type: 'done', images: finalImages.length ? finalImages : undefined });
this.state.waitUntil(this.stopHeartbeat()); this.state.waitUntil(this.stopHeartbeat());
} }
@@ -424,7 +416,7 @@ export class MyDurableObject {
mapContentPartToResponses(part) { mapContentPartToResponses(part) {
const type = part?.type || 'text'; const type = part?.type || 'text';
if (['image_url', 'input_image'].includes(type)) return (part?.image_url?.url || part?.image_url) ? { type: 'input_image', image_url: String(part?.image_url?.url || part?.image_url) } : null; if (['image_url', 'input_image'].includes(type)) return (part?.image_url?.url || part?.image_url) ? { type: 'input_image', image_url: String(part?.image_url?.url || part?.image_url) } : null;
if (['text', 'input_text'].includes(type)) return { type: 'input_text', text: String(type === 'text' ? (p.text ?? p.content ?? '') : (p.text ?? '')) }; if (['text', 'input_text'].includes(type)) return { type: 'input_text', text: String(type === 'text' ? (part.text ?? part.content ?? '') : (part.text ?? '')) };
return { type: 'input_text', text: `[${type}:${part?.file?.filename || 'file'}]` }; return { type: 'input_text', text: `[${type}:${part?.file?.filename || 'file'}]` };
} }