diff --git a/api/src/controllers/seed/insert/entries.ts b/api/src/controllers/seed/insert/entries.ts index c53590c4..fb5c9e60 100644 --- a/api/src/controllers/seed/insert/entries.ts +++ b/api/src/controllers/seed/insert/entries.ts @@ -6,7 +6,7 @@ import { entryVideoJoin, videos, } from "~/db/schema"; -import { conflictUpdateAllExcept, values } from "~/db/utils"; +import { conflictUpdateAllExcept, unnestValues, values } from "~/db/utils"; import type { SeedEntry as SEntry, SeedExtra as SExtra } from "~/models/entry"; import { enqueueOptImage, flushImageQueue, type ImageTask } from "../images"; import { guessNextRefresh } from "../refresh"; @@ -75,7 +75,7 @@ export const insertEntries = async ( }); const ret = await tx .insert(entries) - .values(vals) + .select(unnestValues(vals, entries)) .onConflictDoUpdate({ target: entries.slug, set: conflictUpdateAllExcept(entries, [ @@ -120,7 +120,7 @@ export const insertEntries = async ( await flushImageQueue(tx, imgQueue, 0); await tx .insert(entryTranslations) - .values(trans) + .select(unnestValues(trans, entryTranslations)) .onConflictDoUpdate({ target: [entryTranslations.pk, entryTranslations.language], set: conflictUpdateAllExcept(entryTranslations, ["pk", "language"]), diff --git a/api/src/db/index.ts b/api/src/db/index.ts index cc4a8901..c7132116 100644 --- a/api/src/db/index.ts +++ b/api/src/db/index.ts @@ -1,13 +1,13 @@ import os from "node:os"; import path from "node:path"; import tls, { type ConnectionOptions } from "node:tls"; +import { record } from "@elysiajs/opentelemetry"; import { instrumentDrizzleClient } from "@kubiks/otel-drizzle"; import { sql } from "drizzle-orm"; import { drizzle } from "drizzle-orm/node-postgres"; import { migrate as migrateDb } from "drizzle-orm/node-postgres/migrator"; import type { PoolConfig } from "pg"; import * as schema from "./schema"; -import { record } from "@elysiajs/opentelemetry"; const config: PoolConfig = { connectionString: process.env.POSTGRES_URL, diff --git a/api/src/db/utils.ts b/api/src/db/utils.ts index bab0eedf..e76423ed 100644 --- a/api/src/db/utils.ts +++ b/api/src/db/utils.ts @@ -8,12 +8,17 @@ import { type Subquery, sql, Table, + type TableConfig, View, ViewBaseConfig, } from "drizzle-orm"; import type { CasingCache } from "drizzle-orm/casing"; import type { AnyMySqlSelect } from "drizzle-orm/mysql-core"; -import type { AnyPgSelect, SelectedFieldsFlat } from "drizzle-orm/pg-core"; +import type { + AnyPgSelect, + PgTableWithColumns, + SelectedFieldsFlat, +} from "drizzle-orm/pg-core"; import type { AnySQLiteSelect } from "drizzle-orm/sqlite-core"; import type { WithSubquery } from "drizzle-orm/subquery"; import { db } from "./index"; @@ -70,7 +75,15 @@ export function conflictUpdateAllExcept< // drizzle is bugged and doesn't allow js arrays to be used in raw sql. export function sqlarr(array: unknown[]) { - return `{${array.map((item) => `"${item}"`).join(",")}}`; + return `{${array + .map((item) => + !item || item === "null" + ? "null" + : typeof item === "object" + ? `"${JSON.stringify(item).replaceAll('"', '\\"')}"` + : `"${item}"`, + ) + .join(", ")}}`; } // See https://github.com/drizzle-team/drizzle-orm/issues/4044 @@ -103,6 +116,75 @@ export function values( }; } +/* goal: + * unnestValues([{a: 1, b: 2}, {a: 3, b: 4}], tbl) + * + * ```sql + * select a, b, now() as updated_at from unnest($1::integer[], $2::integer[]); + * ``` + * params: + * $1: [1, 2] + * $2: [3, 4] + * + * select + */ +export const unnestValues = < + T extends Record, + C extends TableConfig = never, +>( + values: T[], + typeInfo: PgTableWithColumns, +) => { + if (values[0] === undefined) + throw new Error("Invalid values, expecting at least one items"); + const columns = getTableColumns(typeInfo); + const keys = Object.keys(values[0]).filter((x) => x in columns); + // @ts-expect-error: drizzle internal + const casing = db.dialect.casing as CasingCache; + const dbNames = Object.fromEntries( + Object.entries(columns).map(([k, v]) => [k, casing.getColumnCasing(v)]), + ); + const vals = values.reduce( + (acc, cur, i) => { + for (const k of keys) { + if (k in cur) acc[k].push(cur[k]); + else acc[k].push(null); + } + for (const k of Object.keys(cur)) { + if (k in acc) continue; + if (!(k in columns)) continue; + keys.push(k); + acc[k] = new Array(i).fill(null); + acc[k].push(cur[k]) + } + return acc; + }, + Object.fromEntries(keys.map((x) => [x, [] as unknown[]])), + ); + const computed = Object.entries(columns) + .filter(([k, v]) => (v.defaultFn || v.onUpdateFn) && !keys.includes(k)) + .map(([k]) => k); + return db + .select( + Object.fromEntries([ + ...keys.map((x) => [x, sql.raw(`"${dbNames[x]}"`)]), + ...computed.map((x) => [ + x, + (columns[x].defaultFn?.() ?? columns[x].onUpdateFn!()).as(dbNames[x]), + ]), + ]), + ) + .from( + sql`unnest(${sql.join( + keys.map( + (k) => + sql`${sqlarr(vals[k])}${sql.raw(`::${columns[k].getSQLType()}[]`)}`, + ), + sql.raw(", "), + )}) as v(${sql.raw(keys.map((x) => `"${dbNames[x]}"`).join(", "))})`, + ); +}; + export const coalesce = (val: SQL | SQLWrapper, def: SQL | Column) => { return sql`coalesce(${val}, ${def})`; }; diff --git a/api/tests/misc/images.test.ts b/api/tests/misc/images.test.ts index db5a3212..acaf6ff6 100644 --- a/api/tests/misc/images.test.ts +++ b/api/tests/misc/images.test.ts @@ -21,6 +21,7 @@ describe("images", () => { const release = await processImages(); // remove notifications to prevent other images to be downloaded (do not curl 20000 images for nothing) release(); + await db.delete(mqueue); const ret = await db.query.shows.findFirst({ where: eq(shows.slug, madeInAbyss.slug), @@ -45,6 +46,7 @@ describe("images", () => { const release = await processImages(); // remove notifications to prevent other images to be downloaded (do not curl 20000 images for nothing) release(); + await db.delete(mqueue); const failed = await db.query.mqueue.findFirst({ where: and( diff --git a/api/tests/setup.ts b/api/tests/setup.ts index 75382825..873c8192 100644 --- a/api/tests/setup.ts +++ b/api/tests/setup.ts @@ -3,6 +3,7 @@ import { beforeAll } from "bun:test"; process.env.PGDATABASE = "kyoo_test"; process.env.JWT_SECRET = "this is a secret"; process.env.JWT_ISSUER = "https://kyoo.zoriya.dev"; +process.env.IMAGES_PATH = "./images"; beforeAll(async () => { // lazy load this so env set before actually applies