001/*
002 * SPDX-FileCopyrightText: none
003 * SPDX-License-Identifier: CC0-1.0
004 */
005
006package gov.nist.secauto.oscal.lib.profile.resolver.alter;
007
008import gov.nist.secauto.metaschema.core.datatype.markup.MarkupLine;
009import gov.nist.secauto.metaschema.core.util.CollectionUtil;
010import gov.nist.secauto.metaschema.core.util.CustomCollectors;
011import gov.nist.secauto.metaschema.core.util.ObjectUtils;
012import gov.nist.secauto.oscal.lib.model.Catalog;
013import gov.nist.secauto.oscal.lib.model.CatalogGroup;
014import gov.nist.secauto.oscal.lib.model.Control;
015import gov.nist.secauto.oscal.lib.model.ControlPart;
016import gov.nist.secauto.oscal.lib.model.Link;
017import gov.nist.secauto.oscal.lib.model.Parameter;
018import gov.nist.secauto.oscal.lib.model.Property;
019import gov.nist.secauto.oscal.lib.model.control.catalog.ICatalogVisitor;
020import gov.nist.secauto.oscal.lib.profile.resolver.ProfileResolutionEvaluationException;
021
022import java.util.Collections;
023import java.util.EnumMap;
024import java.util.EnumSet;
025import java.util.LinkedList;
026import java.util.List;
027import java.util.ListIterator;
028import java.util.Locale;
029import java.util.Map;
030import java.util.Set;
031import java.util.concurrent.ConcurrentHashMap;
032import java.util.function.Consumer;
033import java.util.function.Function;
034import java.util.function.Supplier;
035
036import edu.umd.cs.findbugs.annotations.NonNull;
037import edu.umd.cs.findbugs.annotations.Nullable;
038
039public class AddVisitor implements ICatalogVisitor<Boolean, AddVisitor.Context> {
040  public enum TargetType {
041    CONTROL("control", Control.class),
042    PARAM("param", Parameter.class),
043    PART("part", ControlPart.class);
044
045    @NonNull
046    private static final Map<Class<?>, TargetType> CLASS_TO_TYPE;
047    @NonNull
048    private static final Map<String, TargetType> NAME_TO_TYPE;
049    @NonNull
050    private final String fieldName;
051    @NonNull
052    private final Class<?> clazz;
053
054    static {
055      {
056        Map<Class<?>, TargetType> map = new ConcurrentHashMap<>();
057        for (TargetType type : values()) {
058          map.put(type.getClazz(), type);
059        }
060        CLASS_TO_TYPE = CollectionUtil.unmodifiableMap(map);
061      }
062
063      {
064        Map<String, TargetType> map = new ConcurrentHashMap<>();
065        for (TargetType type : values()) {
066          map.put(type.fieldName(), type);
067        }
068        NAME_TO_TYPE = CollectionUtil.unmodifiableMap(map);
069      }
070    }
071
072    /**
073     * Get the target type associated with the provided {@code clazz}.
074     *
075     * @param clazz
076     *          the class to identify the target type for
077     * @return the associated target type or {@code null} if the class is not
078     *         associated with a target type
079     */
080    @Nullable
081    public static TargetType forClass(@NonNull Class<?> clazz) {
082      Class<?> target = clazz;
083      TargetType retval;
084      // recurse over parent classes to find a match
085      do {
086        retval = CLASS_TO_TYPE.get(target);
087      } while (retval == null && (target = target.getSuperclass()) != null);
088      return retval;
089    }
090
091    /**
092     * Get the target type associated with the provided field {@code name}.
093     *
094     * @param name
095     *          the field name to identify the target type for
096     * @return the associated target type or {@code null} if the name is not
097     *         associated with a target type
098     */
099    @Nullable
100    public static TargetType forFieldName(@Nullable String name) {
101      return name == null ? null : NAME_TO_TYPE.get(name);
102    }
103
104    TargetType(@NonNull String fieldName, @NonNull Class<?> clazz) {
105      this.fieldName = fieldName;
106      this.clazz = clazz;
107    }
108
109    /**
110     * Get the field name associated with the target type.
111     *
112     * @return the name
113     */
114    public String fieldName() {
115      return fieldName;
116    }
117
118    /**
119     * Get the bound class associated with the target type.
120     *
121     * @return the class
122     */
123    public Class<?> getClazz() {
124      return clazz;
125    }
126  }
127
128  public enum Position {
129    BEFORE,
130    AFTER,
131    STARTING,
132    ENDING;
133
134    @NonNull
135    private static final Map<String, Position> NAME_TO_POSITION;
136
137    static {
138      Map<String, Position> map = new ConcurrentHashMap<>();
139      for (Position position : values()) {
140        map.put(position.name().toLowerCase(Locale.ROOT), position);
141      }
142      NAME_TO_POSITION = CollectionUtil.unmodifiableMap(map);
143    }
144
145    /**
146     * Get the position associated with the provided {@code name}.
147     *
148     * @param name
149     *          the name to identify the position for
150     * @return the associated position or {@code null} if the name is not associated
151     *         with a position
152     */
153    @Nullable
154    public static Position forName(@Nullable String name) {
155      return name == null ? null : NAME_TO_POSITION.get(name);
156    }
157  }
158
159  @NonNull
160  private static final AddVisitor INSTANCE = new AddVisitor();
161  private static final Map<TargetType, Set<TargetType>> APPLICABLE_TARGETS;
162
163  static {
164    APPLICABLE_TARGETS = new EnumMap<>(TargetType.class);
165    APPLICABLE_TARGETS.put(TargetType.CONTROL, Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
166    APPLICABLE_TARGETS.put(TargetType.PARAM, Set.of(TargetType.PARAM));
167    APPLICABLE_TARGETS.put(TargetType.PART, Set.of(TargetType.PART));
168  }
169
170  private static Set<TargetType> getApplicableTypes(@NonNull TargetType type) {
171    return APPLICABLE_TARGETS.getOrDefault(type, CollectionUtil.emptySet());
172  }
173
174  /**
175   * Apply the add directive.
176   *
177   * @param control
178   *          the control target
179   * @param position
180   *          the position to apply the content or {@code null}
181   * @param byId
182   *          the identifier of the target or {@code null}
183   * @param title
184   *          a title to set
185   * @param params
186   *          parameters to add
187   * @param props
188   *          properties to add
189   * @param links
190   *          links to add
191   * @param parts
192   *          parts to add
193   * @return {@code true} if the modification was made or {@code false} otherwise
194   * @throws ProfileResolutionEvaluationException
195   *           if a processing error occurred during profile resolution
196   */
197  public static boolean add(
198      @NonNull Control control,
199      @Nullable Position position,
200      @Nullable String byId,
201      @Nullable MarkupLine title,
202      @NonNull List<Parameter> params,
203      @NonNull List<Property> props,
204      @NonNull List<Link> links,
205      @NonNull List<ControlPart> parts) {
206    return INSTANCE.visitControl(
207        control,
208        new Context(
209            control,
210            position == null ? Position.ENDING : position,
211            byId,
212            title,
213            params,
214            props,
215            links,
216            parts));
217  }
218
219  @Override
220  public Boolean visitCatalog(Catalog catalog, Context context) {
221    // not required
222    throw new UnsupportedOperationException("not needed");
223  }
224
225  @Override
226  public Boolean visitGroup(CatalogGroup group, Context context) {
227    // not required
228    throw new UnsupportedOperationException("not needed");
229  }
230
231  /**
232   * If the add applies to the current object, then apply the child objects.
233   * <p>
234   * An add applies if:
235   * <ol>
236   * <li>the {@code targetItem} supports all of the children</li>
237   * <li>the context matches if:
238   * <ul>
239   * <li>the target item's id matches the "by-id"; or</li>
240   * <li>the "by-id" is not defined and the target item is the control matching
241   * the target context</li>
242   * </ul>
243   * </li>
244   * </ol>
245   *
246   * @param <T>
247   *          the type of the {@code targetItem}
248   * @param targetItem
249   *          the current target to process
250   * @param titleConsumer
251   *          a consumer to apply a title to or {@code null} if the object has no
252   *          title field
253   * @param paramsSupplier
254   *          a supplier for the child {@link Parameter} collection
255   * @param propsSupplier
256   *          a supplier for the child {@link Property} collection
257   * @param linksSupplier
258   *          a supplier for the child {@link Link} collection
259   * @param partsSupplier
260   *          a supplier for the child {@link ControlPart} collection
261   * @param context
262   *          the add context
263   * @return {@code true} if a modification was made or {@code false} otherwise
264   */
265  private static <T> boolean handleCurrent(
266      @NonNull T targetItem,
267      @Nullable Consumer<MarkupLine> titleConsumer,
268      @Nullable Supplier<? extends List<Parameter>> paramsSupplier,
269      @Nullable Supplier<? extends List<Property>> propsSupplier,
270      @Nullable Supplier<? extends List<Link>> linksSupplier,
271      @Nullable Supplier<? extends List<ControlPart>> partsSupplier,
272      @NonNull Context context) {
273    boolean retval = false;
274    Position position = context.getPosition();
275    if (context.appliesTo(targetItem) && !context.isSequenceTargeted(targetItem)) {
276      // the target item is the target of the add
277      MarkupLine newTitle = context.getTitle();
278      if (newTitle != null) {
279        assert titleConsumer != null;
280        titleConsumer.accept(newTitle);
281      }
282
283      handleCollection(position, context.getParams(), paramsSupplier);
284      handleCollection(position, context.getProps(), propsSupplier);
285      handleCollection(position, context.getLinks(), linksSupplier);
286      handleCollection(position, context.getParts(), partsSupplier);
287      retval = true;
288    }
289    return retval;
290  }
291
292  private static <T> void handleCollection(
293      @NonNull Position position,
294      @NonNull List<T> newItems,
295      @Nullable Supplier<? extends List<T>> originalCollectionSupplier) {
296    if (originalCollectionSupplier != null) {
297      List<T> oldItems = originalCollectionSupplier.get();
298      if (!newItems.isEmpty()) {
299        if (Position.STARTING.equals(position)) {
300          oldItems.addAll(0, newItems);
301        } else { // ENDING
302          oldItems.addAll(newItems);
303        }
304      }
305    }
306  }
307
308  // private static <T> void handleChild(
309  // @NonNull TargetType itemType,
310  // @NonNull Supplier<? extends List<T>> collectionSupplier,
311  // @Nullable Consumer<T> handler,
312  // @NonNull Context context) {
313  // boolean handleChildren = !Collections.disjoint(context.getTargetItemTypes(),
314  // getApplicableTypes(itemType));
315  // if (handleChildren && handler != null) {
316  // // if the child item type is applicable and there is a handler, iterate over
317  // children
318  // Iterator<T> iter = collectionSupplier.get().iterator();
319  // while (iter.hasNext()) {
320  // T item = iter.next();
321  // if (item != null) {
322  // handler.accept(item);
323  // }
324  // }
325  // }
326  // }
327
328  private static <T> boolean handleChild(
329      @NonNull TargetType itemType,
330      @NonNull Supplier<? extends List<T>> originalCollectionSupplier,
331      @NonNull Supplier<? extends List<T>> newItemsSupplier,
332      @Nullable Function<T, Boolean> handler,
333      @NonNull Context context) {
334
335    // determine if this child type can match
336    boolean isItemTypeMatch = context.isMatchingType(itemType);
337
338    Set<TargetType> applicableTypes = getApplicableTypes(itemType);
339    boolean descendChild = handler != null && !Collections.disjoint(context.getTargetItemTypes(), applicableTypes);
340
341    boolean retval = false;
342    if (isItemTypeMatch || descendChild) {
343      // if the item type is applicable, attempt to match by id
344      List<T> collection = originalCollectionSupplier.get();
345      ListIterator<T> iter = collection.listIterator();
346      boolean deferred = false;
347      while (iter.hasNext()) {
348        T item = ObjectUtils.requireNonNull(iter.next());
349
350        if (isItemTypeMatch && context.appliesTo(item) && context.isSequenceTargeted(item)) {
351          // if id match, inject the new items into the collection
352          switch (context.getPosition()) {
353          case AFTER: {
354            newItemsSupplier.get().forEach(iter::add);
355            retval = true;
356            break;
357          }
358          case BEFORE: {
359            iter.previous();
360            List<T> adds = newItemsSupplier.get();
361            adds.forEach(iter::add);
362            item = iter.next();
363            retval = true;
364            break;
365          }
366          case STARTING:
367          case ENDING:
368            deferred = true;
369            break;
370          default:
371            throw new UnsupportedOperationException(context.getPosition().name().toLowerCase(Locale.ROOT));
372          }
373        }
374
375        if (descendChild) {
376          assert handler != null;
377
378          // handle child items since they are applicable to the search criteria
379          retval = retval || handler.apply(item);
380        }
381      }
382
383      if (deferred) {
384        List<T> newItems = newItemsSupplier.get();
385        if (Position.ENDING.equals(context.getPosition())) {
386          collection.addAll(newItems);
387          retval = true;
388        } else if (Position.STARTING.equals(context.getPosition())) {
389          collection.addAll(0, newItems);
390          retval = true;
391        }
392      }
393    }
394    return retval;
395  }
396
397  @Override
398  public Boolean visitControl(Control control, Context context) {
399    assert context != null;
400
401    if (control.getParams() == null) {
402      control.setParams(new LinkedList<>());
403    }
404
405    if (control.getProps() == null) {
406      control.setProps(new LinkedList<>());
407    }
408
409    if (control.getLinks() == null) {
410      control.setLinks(new LinkedList<>());
411    }
412
413    if (control.getParts() == null) {
414      control.setParts(new LinkedList<>());
415    }
416
417    boolean retval = handleCurrent(
418        control,
419        control::setTitle,
420        control::getParams,
421        control::getProps,
422        control::getLinks,
423        control::getParts,
424        context);
425
426    // visit params
427    retval = retval || handleChild(
428        TargetType.PARAM,
429        control::getParams,
430        context::getParams,
431        child -> visitParameter(ObjectUtils.notNull(child), context),
432        context);
433
434    // visit parts
435    retval = retval || handleChild(
436        TargetType.PART,
437        control::getParts,
438        context::getParts,
439        child -> visitPart(child, context),
440        context);
441
442    // visit control children
443    for (Control childControl : CollectionUtil.listOrEmpty(control.getControls())) {
444      Set<TargetType> applicableTypes = getApplicableTypes(TargetType.CONTROL);
445      if (!Collections.disjoint(context.getTargetItemTypes(), applicableTypes)) {
446        retval = retval || visitControl(ObjectUtils.requireNonNull(childControl), context);
447      }
448    }
449    return retval;
450  }
451
452  @Override
453  public Boolean visitParameter(Parameter parameter, Context context) {
454    assert context != null;
455    if (parameter.getProps() == null) {
456      parameter.setProps(new LinkedList<>());
457    }
458
459    if (parameter.getLinks() == null) {
460      parameter.setLinks(new LinkedList<>());
461    }
462
463    return handleCurrent(
464        parameter,
465        null,
466        null,
467        parameter::getProps,
468        parameter::getLinks,
469        null,
470        context);
471  }
472
473  /**
474   * Visit the control part.
475   *
476   * @param part
477   *          the bound part object
478   * @param context
479   *          the visitor context
480   * @return {@code true} if the removal was applied or {@code false} otherwise
481   */
482  public boolean visitPart(ControlPart part, Context context) {
483    assert context != null;
484    if (part.getProps() == null) {
485      part.setProps(new LinkedList<>());
486    }
487
488    if (part.getLinks() == null) {
489      part.setLinks(new LinkedList<>());
490    }
491
492    if (part.getParts() == null) {
493      part.setParts(new LinkedList<>());
494    }
495
496    boolean retval = handleCurrent(
497        part,
498        null,
499        null,
500        part::getProps,
501        part::getLinks,
502        part::getParts,
503        context);
504
505    return retval || handleChild(
506        TargetType.PART,
507        part::getParts,
508        context::getParts,
509        child -> visitPart(child, context),
510        context);
511  }
512
513  static final class Context {
514    @NonNull
515    private static final Set<TargetType> TITLE_TYPES = ObjectUtils.notNull(
516        Set.of(TargetType.CONTROL, TargetType.PART));
517    @NonNull
518    private static final Set<TargetType> PARAM_TYPES = ObjectUtils.notNull(
519        Set.of(TargetType.CONTROL, TargetType.PARAM));
520    @NonNull
521    private static final Set<TargetType> PROP_TYPES = ObjectUtils.notNull(
522        Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
523    @NonNull
524    private static final Set<TargetType> LINK_TYPES = ObjectUtils.notNull(
525        Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
526    @NonNull
527    private static final Set<TargetType> PART_TYPES = ObjectUtils.notNull(
528        Set.of(TargetType.CONTROL, TargetType.PART));
529
530    @NonNull
531    private final Control control;
532    @NonNull
533    private final Position position;
534    @Nullable
535    private final String byId;
536    @Nullable
537    private final MarkupLine title;
538    @NonNull
539    private final List<Parameter> params;
540    @NonNull
541    private final List<Property> props;
542    @NonNull
543    private final List<Link> links;
544    @NonNull
545    private final List<ControlPart> parts;
546    @NonNull
547    private final Set<TargetType> targetItemTypes;
548
549    public Context(
550        @NonNull Control control,
551        @NonNull Position position,
552        @Nullable String byId,
553        @Nullable MarkupLine title,
554        @NonNull List<Parameter> params,
555        @NonNull List<Property> props,
556        @NonNull List<Link> links,
557        @NonNull List<ControlPart> parts) {
558
559      Set<TargetType> targetItemTypes = ObjectUtils.notNull(EnumSet.allOf(TargetType.class));
560
561      List<String> additionObjects = new LinkedList<>();
562
563      boolean sequenceTarget = true;
564      if (title != null) {
565        targetItemTypes.retainAll(TITLE_TYPES);
566        additionObjects.add("title");
567        sequenceTarget = false;
568      }
569
570      if (!params.isEmpty()) {
571        targetItemTypes.retainAll(PARAM_TYPES);
572        additionObjects.add("param");
573      }
574
575      if (!props.isEmpty()) {
576        targetItemTypes.retainAll(PROP_TYPES);
577        additionObjects.add("prop");
578        sequenceTarget = false;
579      }
580
581      if (!links.isEmpty()) {
582        targetItemTypes.retainAll(LINK_TYPES);
583        additionObjects.add("link");
584        sequenceTarget = false;
585      }
586
587      if (!parts.isEmpty()) {
588        targetItemTypes.retainAll(PART_TYPES);
589        additionObjects.add("part");
590      }
591
592      if (Position.BEFORE.equals(position) || Position.AFTER.equals(position)) {
593        if (!sequenceTarget) {
594          throw new ProfileResolutionEvaluationException(
595              "When using position before or after, one collection of parameters or parts can be specified."
596                  + " Other additions must not be used.");
597        }
598        if (sequenceTarget && !params.isEmpty() && parts.isEmpty()) {
599          targetItemTypes.retainAll(Set.of(TargetType.PARAM));
600        } else if (sequenceTarget && !parts.isEmpty() && params.isEmpty()) {
601          targetItemTypes.retainAll(Set.of(TargetType.PART));
602        } else {
603          throw new ProfileResolutionEvaluationException(
604              "When using position before or after, only one collection of parameters or parts can be specified.");
605        }
606      }
607
608      if (targetItemTypes.isEmpty()) {
609        throw new ProfileResolutionEvaluationException("No parent object supports the requested objects to add: " +
610            additionObjects.stream().collect(CustomCollectors.joiningWithOxfordComma("or")));
611      }
612
613      this.control = control;
614      this.position = position;
615      this.byId = byId;
616      this.title = title;
617      this.params = params;
618      this.props = props;
619      this.links = links;
620      this.parts = parts;
621      this.targetItemTypes = CollectionUtil.unmodifiableSet(targetItemTypes);
622    }
623
624    public Control getControl() {
625      return control;
626    }
627
628    @NonNull
629    public Position getPosition() {
630      return position;
631    }
632
633    @Nullable
634    public String getById() {
635      return byId;
636    }
637
638    @Nullable
639    public MarkupLine getTitle() {
640      return title;
641    }
642
643    @NonNull
644    public List<Parameter> getParams() {
645      return params;
646    }
647
648    @NonNull
649    public List<Property> getProps() {
650      return props;
651    }
652
653    @NonNull
654    public List<Link> getLinks() {
655      return links;
656    }
657
658    @NonNull
659    public List<ControlPart> getParts() {
660      return parts;
661    }
662
663    @NonNull
664    public Set<TargetType> getTargetItemTypes() {
665      return targetItemTypes;
666    }
667
668    public boolean isMatchingType(@NonNull TargetType type) {
669      return getTargetItemTypes().contains(type);
670    }
671
672    public <T> boolean isSequenceTargeted(T targetItem) {
673      TargetType objectType = TargetType.forClass(targetItem.getClass());
674      return (Position.BEFORE.equals(position) || Position.AFTER.equals(position))
675          && (TargetType.PARAM.equals(objectType) && isMatchingType(TargetType.PARAM)
676              || TargetType.PART.equals(objectType) && isMatchingType(TargetType.PART));
677    }
678
679    /**
680     * Determine if the provided {@code obj} is the target of the add.
681     *
682     * @param obj
683     *          the current object
684     * @return {@code true} if the current object applies or {@code false} otherwise
685     */
686    public boolean appliesTo(@NonNull Object obj) {
687      TargetType objectType = TargetType.forClass(obj.getClass());
688
689      boolean retval = objectType != null && isMatchingType(objectType);
690      if (retval) {
691        assert objectType != null;
692
693        // check other criteria
694        String actualId = null;
695
696        switch (objectType) {
697        case CONTROL: {
698          Control control = (Control) obj;
699          actualId = control.getId();
700          break;
701        }
702        case PARAM: {
703          Parameter param = (Parameter) obj;
704          actualId = param.getId();
705          break;
706        }
707        case PART: {
708          ControlPart part = (ControlPart) obj;
709          actualId = part.getId() == null ? null : part.getId();
710          break;
711        }
712        default:
713          throw new UnsupportedOperationException(objectType.fieldName());
714        }
715
716        String byId = getById();
717        if (getById() == null && TargetType.CONTROL.equals(objectType)) {
718          retval = getControl().equals(obj);
719        } else {
720          retval = byId != null && byId.equals(actualId);
721        }
722      }
723      return retval;
724    }
725  }
726}