Mirror of https://github.com/roostorg/coop
github.com/roostorg/coop
1import { SpanStatusCode } from '@opentelemetry/api';
2import _ from 'lodash';
3import stringify from 'safe-stable-stringify';
4import { type ReadonlyDeep } from 'type-fest';
5
6import { inject } from '../../iocContainer/utils.js';
7import { type PolicyActionPenalties } from '../../models/OrgModel.js';
8import { type MatchingValues } from '../../models/rules/matchingValues.js';
9import { type LocationArea } from '../../models/types/locationArea.js';
10import { jsonStringify } from '../../utils/encoding.js';
11import { CoopError, ErrorType } from '../../utils/errors.js';
12import type SafeTracer from '../../utils/SafeTracer.js';
13import {
14 isNonEmptyArray,
15 type NonEmptyArray,
16} from '../../utils/typescript-types.js';
17import {
18 type SignalId,
19 type SignalInput,
20 type SignalInputType,
21 type SignalOutputType,
22 type SignalResult,
23 type SignalsService,
24 type SignalType,
25 type SignalTypesToRunInputTypes,
26} from '../signalsService/index.js';
27import { type HashBank } from '../hmaService/index.js';
28
// Local alias for lodash's memoize, used to build the per-submission caches.
const memoize = _.memoize;
30
/**
 * Input accepted by the exposed run-signal function.
 *
 * `matchingValues` are given as they're specified _in the condition_ -- that
 * is, with reference to a text/media bank id -- and the id is resolved to the
 * actual matching values that should be passed to the signal before the
 * signal is run. For that reason the signal-level `matchingValues` and
 * `actionPenalties` fields are dropped from the underlying input type.
 *
 * NB: when we support custom signals, pg id of the signal should be
 * an argument here too.
 */
type RunSignalInput<T extends SignalType = SignalType> = {
  matchingValues?: ReadonlyDeep<MatchingValues>;
  signal: SignalId;
} & Omit<SignalTypesToRunInputTypes[T], 'matchingValues' | 'actionPenalties'>;
45
/**
 * Signature of the transient, caching signal runner. The returned promise
 * resolves either to the signal's result or, if running the signal failed, to
 * an `ERROR` marker whose `score` carries whatever value was thrown.
 */
export type TransientRunSignalWithCache = (
  input: RunSignalInput,
) => Promise<
  { type: 'ERROR'; score: unknown } | SignalResult<SignalOutputType>
>;
51
/**
 * Loads the locations held in a location bank, addressed by the bank's id.
 * May resolve to `undefined` (presumably when the bank can't be found --
 * confirm against the injected loader).
 */
type LocationsLoader = (
  locationBankId: string,
) => Promise<undefined | ReadonlyDeep<LocationArea[]>>;
55
export default inject(
  [
    'getLocationBankLocationsEventuallyConsistent',
    'getTextBankStringsEventuallyConsistent',
    'getPolicyActionPenaltiesEventuallyConsistent',
    'getImageBankEventuallyConsistent',
    'SignalsService',
    'Tracer',
  ],
  (
    locationsLoader: LocationsLoader,
    textBankStringsLoader: (input: {
      orgId: string;
      bankId: string;
    }) => Promise<readonly string[] | undefined>,
    getPolicyActionPenalties: (
      orgId: string,
    ) => Promise<ReadonlyDeep<PolicyActionPenalties[]>>,
    getImageBank: (input: {
      orgId: string;
      bankId: string;
    }) => Promise<HashBank | null>,
    signalsService: SignalsService,
    tracer: SafeTracer,
  ) =>
    /**
     * Builds a function that runs signals. The built function is responsible
     * for caching (the same signal may be run several times against one
     * content submission, typically on behalf of different rules), for loading
     * the inputs signals need (e.g. text bank matching values), and will
     * eventually be the place where retries are handled.
     *
     * The built function is intended to be short-lived: use it for a single
     * ruleSet execution or a single content submission. Holding on to it for
     * longer risks feeding signals stale cached data, such as text/location
     * bank contents or action penalties.
     *
     * Signals don't retry on their own (they simply throw on error), so a
     * signal.run() call can never itself produce a ConditionFailureOutcome.
     * This function can, though -- when a signal's retry budget is exhausted
     * or the error is one that can't be retried -- hence the return type.
     */
    function getTransientRunSignalWithCache(): TransientRunSignalWithCache {
      // Action penalties are cached per org id (memoize keys on the first
      // argument) for the lifetime of this transient runner.
      const cachedActionPenalties = memoize(getPolicyActionPenalties);

      // Loads every requested text bank and merges their strings into one
      // flat list; banks that didn't resolve to an array are skipped.
      const loadTextBankStrings = async (
        orgId: string,
        textBankIds: readonly string[],
      ) => {
        const perBankStrings = await Promise.all(
          textBankIds.map(async (bankId) =>
            textBankStringsLoader({ orgId, bankId }),
          ),
        );

        return perBankStrings
          .filter((it): it is readonly string[] => Array.isArray(it))
          .flat() as readonly string[];
      };

      // Loads every requested image bank, skipping those that resolved to
      // null.
      const loadImageBanks = async (
        orgId: string,
        bankIds: readonly string[],
      ) => {
        const banks = await Promise.all(
          bankIds.map(async (bankId) => getImageBank({ orgId, bankId })),
        );

        return banks.filter((it): it is HashBank => it !== null);
      };

      // NB: this way of deriving the cache key may need to change once custom
      // signals are supported, since they'll all share the same SignalType.
      const toCacheKey = ({ signal, ...restOfInput }: RunSignalInput) =>
        stringify([signal, restOfInput]);

      // Caching is done with a plain memoized function and deliberately
      // _not_ with dataloader: signals can't actually be run in a batch, so
      // dataloader's batching would only slow things down by blocking every
      // result on the slowest one in the batch.
      return memoize(
        async (input: RunSignalInput) =>
          runSignal(
            signalsService,
            locationsLoader,
            loadTextBankStrings,
            loadImageBanks,
            cachedActionPenalties,
            tracer,
            input,
          ),
        toCacheKey,
      );
    },
);
146
/**
 * Runs one signal inside a `runSignal` tracing span.
 *
 * The work happens in three steps:
 *  1. The signal is looked up with `signalsService.getSignalOrThrow`.
 *  2. In parallel, the condition's matching values are resolved (inline
 *     strings/locations plus the contents of any referenced bank) if the
 *     signal needs them, and the org's action penalties are loaded if the
 *     signal needs those.
 *  3. `signalsService.runSignal` is called with the fully-resolved input.
 *
 * Anything thrown during those steps is not rethrown: it is recorded on the
 * span (when it is an `Error`), the span is marked as errored, and the thrown
 * value is returned as `{ type: 'ERROR', score: <thrown value> }`.
 *
 * @param signalsService Used to look up and run the signal.
 * @param locationsLoader Loads the locations of one location bank by id.
 * @param textBanksStringsLoader Loads the merged strings of several text banks.
 * @param imageBanksLoader Loads several image (hash) banks.
 * @param actionPenaltiesLoader Loads the action penalties for an org.
 * @param tracer Tracer used to open the span.
 * @param signalInput The signal to run and its input, with matching values
 *   still expressed as in the condition (i.e. possibly referencing banks).
 * @returns Whatever `signalsService.runSignal` resolves to, or the `ERROR`
 *   marker described above.
 */
async function runSignal(
  signalsService: SignalsService,
  locationsLoader: LocationsLoader,
  textBanksStringsLoader: (
    orgId: string,
    bankIds: readonly string[],
  ) => Promise<readonly string[]>,
  imageBanksLoader: (
    orgId: string,
    bankIds: readonly string[],
  ) => Promise<HashBank[]>,
  actionPenaltiesLoader: (
    orgId: string,
  ) => Promise<ReadonlyDeep<PolicyActionPenalties[]>>,
  tracer: SafeTracer,
  signalInput: RunSignalInput,
) {
  return tracer.addActiveSpan(
    {
      operation: 'runSignal',
      resource: signalInput.signal.type,
      attributes: {
        signal: jsonStringify(signalInput.signal),
        orgId: signalInput.orgId,
        contentId: signalInput.contentId ?? '',
      },
    },
    async (span) => {
      try {
        const { orgId, signal: signalId, matchingValues } = signalInput;
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        const signalRef = { orgId, signalId } as any;
        const signal = await signalsService.getSignalOrThrow(signalRef);

        const [finalMatchingValues, finalActionPenalties] = await Promise.all([
          (async () => {
            if (!signal.needsMatchingValues) {
              return undefined;
            }

            const { locationBankIds, textBankIds, imageBankIds } =
              matchingValues ?? {};

            // A condition can have both strings and text banks as matching
            // values simultaneously (and same with locations and location
            // banks). So we first fetch all the "scalar" values (plain strings
            // & locations), then extract the "scalar" values stored in text &
            // location banks.
            //
            // An empty `strings` array is treated the same as a missing one,
            // mirroring the `?.length` checks used for the bank ids below.
            // Otherwise `strings: []` would shadow a populated `locations`
            // array, because `??` only falls through on null/undefined.
            const inlineStrings = matchingValues?.strings;
            const scalarMatchingValues = inlineStrings?.length
              ? inlineStrings
              : (matchingValues?.locations ?? []);

            // Only one kind of bank is consulted per condition: the first of
            // text, location, image that has any ids.
            let matchingValuesFromBanks: readonly (
              | string
              | ReadonlyDeep<LocationArea>
              | HashBank
            )[] = [];
            if (textBankIds?.length) {
              matchingValuesFromBanks = await textBanksStringsLoader(
                orgId,
                textBankIds,
              );
            } else if (locationBankIds?.length) {
              const allBankLocations = await Promise.all(
                locationBankIds.map(async (id) => locationsLoader(id)),
              );
              // Banks that didn't resolve to an array are skipped.
              matchingValuesFromBanks = allBankLocations
                .filter((it): it is ReadonlyDeep<LocationArea>[] =>
                  Array.isArray(it),
                )
                .flat();
            } else if (imageBankIds?.length) {
              matchingValuesFromBanks = await imageBanksLoader(
                orgId,
                imageBankIds,
              );
            }

            const loadedMatchingValues = [
              ...scalarMatchingValues,
              ...matchingValuesFromBanks,
            ];

            if (!isNonEmptyArray(loadedMatchingValues)) {
              throw new CoopError({
                status: 400,
                name: 'CoopError',
                type: [ErrorType.InvalidMatchingValues],
                shouldErrorSpan: true,
                title:
                  'Matching values were required, but none were found, or bank was empty.',
              });
            }

            return loadedMatchingValues satisfies NonEmptyArray<
              string | ReadonlyDeep<LocationArea> | HashBank
            > as
              | NonEmptyArray<string>
              | NonEmptyArray<ReadonlyDeep<LocationArea>>
              | NonEmptyArray<HashBank>;
          })(),
          signal.needsActionPenalties
            ? actionPenaltiesLoader(signalInput.orgId)
            : undefined,
        ]);

        // The resolved values replace the condition-level ones from the
        // original input.
        const fullInput = {
          ...signalInput,
          actionPenalties: finalActionPenalties,
          matchingValues: finalMatchingValues,
        } as unknown as SignalInput<SignalInputType>;

        return await signalsService.runSignal({
          signal: signalRef,
          input: fullInput,
        });
      } catch (e) {
        // Failures are reported through the return value rather than by
        // rejecting, so the caller always gets a result for the signal.
        if (e instanceof Error) {
          span.recordException(e);
        }
        span.setStatus({ code: SpanStatusCode.ERROR });
        return { type: 'ERROR' as const, score: e };
      }
    },
  );
}