001/*
002 * SPDX-FileCopyrightText: none
003 * SPDX-License-Identifier: CC0-1.0
004 */
005
006package gov.nist.secauto.oscal.lib.profile.resolver.alter;
007
008import gov.nist.secauto.metaschema.core.datatype.markup.MarkupLine;
009import gov.nist.secauto.metaschema.core.util.CollectionUtil;
010import gov.nist.secauto.metaschema.core.util.CustomCollectors;
011import gov.nist.secauto.metaschema.core.util.ObjectUtils;
012import gov.nist.secauto.oscal.lib.model.Catalog;
013import gov.nist.secauto.oscal.lib.model.CatalogGroup;
014import gov.nist.secauto.oscal.lib.model.Control;
015import gov.nist.secauto.oscal.lib.model.ControlPart;
016import gov.nist.secauto.oscal.lib.model.Link;
017import gov.nist.secauto.oscal.lib.model.Parameter;
018import gov.nist.secauto.oscal.lib.model.Property;
019import gov.nist.secauto.oscal.lib.model.control.catalog.ICatalogVisitor;
020import gov.nist.secauto.oscal.lib.profile.resolver.ProfileResolutionEvaluationException;
021
022import java.util.Collections;
023import java.util.EnumMap;
024import java.util.EnumSet;
025import java.util.LinkedList;
026import java.util.List;
027import java.util.ListIterator;
028import java.util.Locale;
029import java.util.Map;
030import java.util.Set;
031import java.util.concurrent.ConcurrentHashMap;
032import java.util.function.Consumer;
033import java.util.function.Function;
034import java.util.function.Supplier;
035
036import edu.umd.cs.findbugs.annotations.NonNull;
037import edu.umd.cs.findbugs.annotations.Nullable;
038
039@SuppressWarnings("PMD.CouplingBetweenObjects")
040public class AddVisitor implements ICatalogVisitor<Boolean, AddVisitor.Context> {
041  public enum TargetType {
042    CONTROL("control", Control.class),
043    PARAM("param", Parameter.class),
044    PART("part", ControlPart.class);
045
046    @NonNull
047    private static final Map<Class<?>, TargetType> CLASS_TO_TYPE;
048    @NonNull
049    private static final Map<String, TargetType> NAME_TO_TYPE;
050    @NonNull
051    private final String fieldName;
052    @NonNull
053    private final Class<?> clazz;
054
055    static {
056      {
057        Map<Class<?>, TargetType> map = new ConcurrentHashMap<>();
058        for (TargetType type : values()) {
059          map.put(type.getClazz(), type);
060        }
061        CLASS_TO_TYPE = CollectionUtil.unmodifiableMap(map);
062      }
063
064      {
065        Map<String, TargetType> map = new ConcurrentHashMap<>();
066        for (TargetType type : values()) {
067          map.put(type.fieldName(), type);
068        }
069        NAME_TO_TYPE = CollectionUtil.unmodifiableMap(map);
070      }
071    }
072
073    /**
074     * Get the target type associated with the provided {@code clazz}.
075     *
076     * @param clazz
077     *          the class to identify the target type for
078     * @return the associated target type or {@code null} if the class is not
079     *         associated with a target type
080     */
081    @Nullable
082    public static TargetType forClass(@NonNull Class<?> clazz) {
083      Class<?> target = clazz;
084      TargetType retval;
085      // recurse over parent classes to find a match
086      do {
087        retval = CLASS_TO_TYPE.get(target);
088      } while (retval == null && (target = target.getSuperclass()) != null);
089      return retval;
090    }
091
092    /**
093     * Get the target type associated with the provided field {@code name}.
094     *
095     * @param name
096     *          the field name to identify the target type for
097     * @return the associated target type or {@code null} if the name is not
098     *         associated with a target type
099     */
100    @Nullable
101    public static TargetType forFieldName(@Nullable String name) {
102      return name == null ? null : NAME_TO_TYPE.get(name);
103    }
104
105    TargetType(@NonNull String fieldName, @NonNull Class<?> clazz) {
106      this.fieldName = fieldName;
107      this.clazz = clazz;
108    }
109
110    /**
111     * Get the field name associated with the target type.
112     *
113     * @return the name
114     */
115    public String fieldName() {
116      return fieldName;
117    }
118
119    /**
120     * Get the bound class associated with the target type.
121     *
122     * @return the class
123     */
124    public Class<?> getClazz() {
125      return clazz;
126    }
127  }
128
129  public enum Position {
130    BEFORE,
131    AFTER,
132    STARTING,
133    ENDING;
134
135    @NonNull
136    private static final Map<String, Position> NAME_TO_POSITION;
137
138    static {
139      Map<String, Position> map = new ConcurrentHashMap<>();
140      for (Position position : values()) {
141        map.put(position.name().toLowerCase(Locale.ROOT), position);
142      }
143      NAME_TO_POSITION = CollectionUtil.unmodifiableMap(map);
144    }
145
146    /**
147     * Get the position associated with the provided {@code name}.
148     *
149     * @param name
150     *          the name to identify the position for
151     * @return the associated position or {@code null} if the name is not associated
152     *         with a position
153     */
154    @Nullable
155    public static Position forName(@Nullable String name) {
156      return name == null ? null : NAME_TO_POSITION.get(name);
157    }
158  }
159
160  @NonNull
161  private static final AddVisitor INSTANCE = new AddVisitor();
162  private static final Map<TargetType, Set<TargetType>> APPLICABLE_TARGETS;
163
164  static {
165    APPLICABLE_TARGETS = new EnumMap<>(TargetType.class);
166    APPLICABLE_TARGETS.put(TargetType.CONTROL, Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
167    APPLICABLE_TARGETS.put(TargetType.PARAM, Set.of(TargetType.PARAM));
168    APPLICABLE_TARGETS.put(TargetType.PART, Set.of(TargetType.PART));
169  }
170
171  private static Set<TargetType> getApplicableTypes(@NonNull TargetType type) {
172    return APPLICABLE_TARGETS.getOrDefault(type, CollectionUtil.emptySet());
173  }
174
175  /**
176   * Apply the add directive.
177   *
178   * @param control
179   *          the control target
180   * @param position
181   *          the position to apply the content or {@code null}
182   * @param byId
183   *          the identifier of the target or {@code null}
184   * @param title
185   *          a title to set
186   * @param params
187   *          parameters to add
188   * @param props
189   *          properties to add
190   * @param links
191   *          links to add
192   * @param parts
193   *          parts to add
194   * @return {@code true} if the modification was made or {@code false} otherwise
195   * @throws ProfileResolutionEvaluationException
196   *           if a processing error occurred during profile resolution
197   */
198  public static boolean add(
199      @NonNull Control control,
200      @Nullable Position position,
201      @Nullable String byId,
202      @Nullable MarkupLine title,
203      @NonNull List<Parameter> params,
204      @NonNull List<Property> props,
205      @NonNull List<Link> links,
206      @NonNull List<ControlPart> parts) {
207    return INSTANCE.visitControl(
208        control,
209        Context.newContext(
210            control,
211            position == null ? Position.ENDING : position,
212            byId,
213            title,
214            params,
215            props,
216            links,
217            parts));
218  }
219
220  @Override
221  public Boolean visitCatalog(Catalog catalog, Context context) {
222    // not required
223    throw new UnsupportedOperationException("not needed");
224  }
225
226  @Override
227  public Boolean visitGroup(CatalogGroup group, Context context) {
228    // not required
229    throw new UnsupportedOperationException("not needed");
230  }
231
232  /**
233   * If the add applies to the current object, then apply the child objects.
234   * <p>
235   * An add applies if:
236   * <ol>
237   * <li>the {@code targetItem} supports all of the children</li>
238   * <li>the context matches if:
239   * <ul>
240   * <li>the target item's id matches the "by-id"; or</li>
241   * <li>the "by-id" is not defined and the target item is the control matching
242   * the target context</li>
243   * </ul>
244   * </li>
245   * </ol>
246   *
247   * @param <T>
248   *          the type of the {@code targetItem}
249   * @param targetItem
250   *          the current target to process
251   * @param titleConsumer
252   *          a consumer to apply a title to or {@code null} if the object has no
253   *          title field
254   * @param paramsSupplier
255   *          a supplier for the child {@link Parameter} collection
256   * @param propsSupplier
257   *          a supplier for the child {@link Property} collection
258   * @param linksSupplier
259   *          a supplier for the child {@link Link} collection
260   * @param partsSupplier
261   *          a supplier for the child {@link ControlPart} collection
262   * @param context
263   *          the add context
264   * @return {@code true} if a modification was made or {@code false} otherwise
265   */
266  private static <T> boolean handleCurrent(
267      @NonNull T targetItem,
268      @Nullable Consumer<MarkupLine> titleConsumer,
269      @Nullable Supplier<? extends List<Parameter>> paramsSupplier,
270      @Nullable Supplier<? extends List<Property>> propsSupplier,
271      @Nullable Supplier<? extends List<Link>> linksSupplier,
272      @Nullable Supplier<? extends List<ControlPart>> partsSupplier,
273      @NonNull Context context) {
274    boolean retval = false;
275    Position position = context.getPosition();
276    if (context.appliesTo(targetItem) && !context.isSequenceTargeted(targetItem)) {
277      // the target item is the target of the add
278      MarkupLine newTitle = context.getTitle();
279      if (newTitle != null) {
280        assert titleConsumer != null;
281        titleConsumer.accept(newTitle);
282      }
283
284      handleCollection(position, context.getParams(), paramsSupplier);
285      handleCollection(position, context.getProps(), propsSupplier);
286      handleCollection(position, context.getLinks(), linksSupplier);
287      handleCollection(position, context.getParts(), partsSupplier);
288      retval = true;
289    }
290    return retval;
291  }
292
293  private static <T> void handleCollection(
294      @NonNull Position position,
295      @NonNull List<T> newItems,
296      @Nullable Supplier<? extends List<T>> originalCollectionSupplier) {
297    if (originalCollectionSupplier != null) {
298      List<T> oldItems = originalCollectionSupplier.get();
299      if (!newItems.isEmpty()) {
300        if (Position.STARTING.equals(position)) {
301          oldItems.addAll(0, newItems);
302        } else { // ENDING
303          oldItems.addAll(newItems);
304        }
305      }
306    }
307  }
308
309  // private static <T> void handleChild(
310  // @NonNull TargetType itemType,
311  // @NonNull Supplier<? extends List<T>> collectionSupplier,
312  // @Nullable Consumer<T> handler,
313  // @NonNull Context context) {
314  // boolean handleChildren = !Collections.disjoint(context.getTargetItemTypes(),
315  // getApplicableTypes(itemType));
316  // if (handleChildren && handler != null) {
317  // // if the child item type is applicable and there is a handler, iterate over
318  // children
319  // Iterator<T> iter = collectionSupplier.get().iterator();
320  // while (iter.hasNext()) {
321  // T item = iter.next();
322  // if (item != null) {
323  // handler.accept(item);
324  // }
325  // }
326  // }
327  // }
328
329  @SuppressWarnings({ "PMD.CyclomaticComplexity", "PMD.CognitiveComplexity" })
330  private static <T> boolean handleChild(
331      @NonNull TargetType itemType,
332      @NonNull Supplier<? extends List<T>> originalCollectionSupplier,
333      @NonNull Supplier<? extends List<T>> newItemsSupplier,
334      @Nullable Function<T, Boolean> handler,
335      @NonNull Context context) {
336
337    // determine if this child type can match
338    boolean isItemTypeMatch = context.isMatchingType(itemType);
339
340    Set<TargetType> applicableTypes = getApplicableTypes(itemType);
341    boolean descendChild = handler != null && !Collections.disjoint(context.getTargetItemTypes(), applicableTypes);
342
343    boolean retval = false;
344    if (isItemTypeMatch || descendChild) {
345      // if the item type is applicable, attempt to match by id
346      List<T> collection = originalCollectionSupplier.get();
347      ListIterator<T> iter = collection.listIterator();
348      boolean deferred = false;
349      while (iter.hasNext()) {
350        T item = ObjectUtils.requireNonNull(iter.next());
351
352        if (isItemTypeMatch && context.appliesTo(item) && context.isSequenceTargeted(item)) {
353          // if id match, inject the new items into the collection
354          switch (context.getPosition()) {
355          case AFTER: {
356            newItemsSupplier.get().forEach(iter::add);
357            retval = true;
358            break;
359          }
360          case BEFORE: {
361            iter.previous();
362            List<T> adds = newItemsSupplier.get();
363            adds.forEach(iter::add);
364            item = iter.next();
365            retval = true;
366            break;
367          }
368          case STARTING:
369          case ENDING:
370            deferred = true;
371            break;
372          default:
373            throw new UnsupportedOperationException(context.getPosition().name().toLowerCase(Locale.ROOT));
374          }
375        }
376
377        if (descendChild) {
378          assert handler != null;
379
380          // handle child items since they are applicable to the search criteria
381          retval = retval || handler.apply(item);
382        }
383      }
384
385      if (deferred) {
386        List<T> newItems = newItemsSupplier.get();
387        if (Position.ENDING.equals(context.getPosition())) {
388          collection.addAll(newItems);
389          retval = true;
390        } else if (Position.STARTING.equals(context.getPosition())) {
391          collection.addAll(0, newItems);
392          retval = true;
393        }
394      }
395    }
396    return retval;
397  }
398
399  @Override
400  public Boolean visitControl(Control control, Context context) {
401    assert context != null;
402
403    if (control.getParams() == null) {
404      control.setParams(new LinkedList<>());
405    }
406
407    if (control.getProps() == null) {
408      control.setProps(new LinkedList<>());
409    }
410
411    if (control.getLinks() == null) {
412      control.setLinks(new LinkedList<>());
413    }
414
415    if (control.getParts() == null) {
416      control.setParts(new LinkedList<>());
417    }
418
419    boolean retval = handleCurrent(
420        control,
421        control::setTitle,
422        control::getParams,
423        control::getProps,
424        control::getLinks,
425        control::getParts,
426        context);
427
428    // visit params
429    retval = retval || handleChild(
430        TargetType.PARAM,
431        control::getParams,
432        context::getParams,
433        child -> visitParameter(ObjectUtils.notNull(child), context),
434        context);
435
436    // visit parts
437    retval = retval || handleChild(
438        TargetType.PART,
439        control::getParts,
440        context::getParts,
441        child -> visitPart(child, context),
442        context);
443
444    // visit control children
445    for (Control childControl : CollectionUtil.listOrEmpty(control.getControls())) {
446      Set<TargetType> applicableTypes = getApplicableTypes(TargetType.CONTROL);
447      if (!Collections.disjoint(context.getTargetItemTypes(), applicableTypes)) {
448        retval = retval || visitControl(ObjectUtils.requireNonNull(childControl), context);
449      }
450    }
451    return retval;
452  }
453
454  @Override
455  public Boolean visitParameter(Parameter parameter, Context context) {
456    assert context != null;
457    if (parameter.getProps() == null) {
458      parameter.setProps(new LinkedList<>());
459    }
460
461    if (parameter.getLinks() == null) {
462      parameter.setLinks(new LinkedList<>());
463    }
464
465    return handleCurrent(
466        parameter,
467        null,
468        null,
469        parameter::getProps,
470        parameter::getLinks,
471        null,
472        context);
473  }
474
475  /**
476   * Visit the control part.
477   *
478   * @param part
479   *          the bound part object
480   * @param context
481   *          the visitor context
482   * @return {@code true} if the removal was applied or {@code false} otherwise
483   */
484  public boolean visitPart(ControlPart part, Context context) {
485    assert context != null;
486    if (part.getProps() == null) {
487      part.setProps(new LinkedList<>());
488    }
489
490    if (part.getLinks() == null) {
491      part.setLinks(new LinkedList<>());
492    }
493
494    if (part.getParts() == null) {
495      part.setParts(new LinkedList<>());
496    }
497
498    boolean retval = handleCurrent(
499        part,
500        null,
501        null,
502        part::getProps,
503        part::getLinks,
504        part::getParts,
505        context);
506
507    return retval || handleChild(
508        TargetType.PART,
509        part::getParts,
510        context::getParts,
511        child -> visitPart(child, context),
512        context);
513  }
514
515  static final class Context {
516    @NonNull
517    private static final Set<TargetType> TITLE_TYPES = ObjectUtils.notNull(
518        Set.of(TargetType.CONTROL, TargetType.PART));
519    @NonNull
520    private static final Set<TargetType> PARAM_TYPES = ObjectUtils.notNull(
521        Set.of(TargetType.CONTROL, TargetType.PARAM));
522    @NonNull
523    private static final Set<TargetType> PROP_TYPES = ObjectUtils.notNull(
524        Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
525    @NonNull
526    private static final Set<TargetType> LINK_TYPES = ObjectUtils.notNull(
527        Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
528    @NonNull
529    private static final Set<TargetType> PART_TYPES = ObjectUtils.notNull(
530        Set.of(TargetType.CONTROL, TargetType.PART));
531
532    @NonNull
533    private final Control control;
534    @NonNull
535    private final Position position;
536    @Nullable
537    private final String byId;
538    @Nullable
539    private final MarkupLine title;
540    @NonNull
541    private final List<Parameter> params;
542    @NonNull
543    private final List<Property> props;
544    @NonNull
545    private final List<Link> links;
546    @NonNull
547    private final List<ControlPart> parts;
548    @NonNull
549    private final Set<TargetType> targetItemTypes;
550
551    @SuppressWarnings({ "PMD.CyclomaticComplexity", "PMD.CognitiveComplexity", "PMD.NPathComplexity" })
552    public static Context newContext(
553        @NonNull Control control,
554        @NonNull Position position,
555        @Nullable String byId,
556        @Nullable MarkupLine title,
557        @NonNull List<Parameter> params,
558        @NonNull List<Property> props,
559        @NonNull List<Link> links,
560        @NonNull List<ControlPart> parts) {
561      Set<TargetType> targetItemTypes = ObjectUtils.notNull(EnumSet.allOf(TargetType.class));
562      List<String> additionObjects = new LinkedList<>();
563
564      boolean sequenceTarget = true;
565      if (title != null) {
566        targetItemTypes.retainAll(TITLE_TYPES);
567        additionObjects.add("title");
568        sequenceTarget = false;
569      }
570
571      if (!params.isEmpty()) {
572        targetItemTypes.retainAll(PARAM_TYPES);
573        additionObjects.add("param");
574      }
575
576      if (!props.isEmpty()) {
577        targetItemTypes.retainAll(PROP_TYPES);
578        additionObjects.add("prop");
579        sequenceTarget = false;
580      }
581
582      if (!links.isEmpty()) {
583        targetItemTypes.retainAll(LINK_TYPES);
584        additionObjects.add("link");
585        sequenceTarget = false;
586      }
587
588      if (!parts.isEmpty()) {
589        targetItemTypes.retainAll(PART_TYPES);
590        additionObjects.add("part");
591      }
592
593      if (Position.BEFORE.equals(position) || Position.AFTER.equals(position)) {
594        if (!sequenceTarget) {
595          throw new ProfileResolutionEvaluationException(
596              "When using position before or after, one collection of parameters or parts can be specified."
597                  + " Other additions must not be used.");
598        }
599        if (!params.isEmpty() && parts.isEmpty()) {
600          targetItemTypes.retainAll(Set.of(TargetType.PARAM));
601        } else if (!parts.isEmpty() && params.isEmpty()) {
602          targetItemTypes.retainAll(Set.of(TargetType.PART));
603        } else {
604          throw new ProfileResolutionEvaluationException(
605              "When using position before or after, only one collection of parameters or parts can be specified.");
606        }
607      }
608
609      if (targetItemTypes.isEmpty()) {
610        throw new ProfileResolutionEvaluationException("No parent object supports the requested objects to add: " +
611            additionObjects.stream().collect(CustomCollectors.joiningWithOxfordComma("or")));
612      }
613
614      return new Context(
615          control,
616          position,
617          byId,
618          title,
619          params,
620          props,
621          links,
622          parts,
623          targetItemTypes);
624    }
625
626    private Context(
627        @NonNull Control control,
628        @NonNull Position position,
629        @Nullable String byId,
630        @Nullable MarkupLine title,
631        @NonNull List<Parameter> params,
632        @NonNull List<Property> props,
633        @NonNull List<Link> links,
634        @NonNull List<ControlPart> parts,
635        @NonNull Set<TargetType> targetItemTypes) {
636      this.control = control;
637      this.position = position;
638      this.byId = byId;
639      this.title = title;
640      this.params = params;
641      this.props = props;
642      this.links = links;
643      this.parts = parts;
644      this.targetItemTypes = CollectionUtil.unmodifiableSet(targetItemTypes);
645    }
646
647    @NonNull
648    private Control getControl() {
649      return control;
650    }
651
652    @NonNull
653    private Position getPosition() {
654      return position;
655    }
656
657    @Nullable
658    private String getById() {
659      return byId;
660    }
661
662    @Nullable
663    private MarkupLine getTitle() {
664      return title;
665    }
666
667    @NonNull
668    private List<Parameter> getParams() {
669      return params;
670    }
671
672    @NonNull
673    private List<Property> getProps() {
674      return props;
675    }
676
677    @NonNull
678    private List<Link> getLinks() {
679      return links;
680    }
681
682    @NonNull
683    private List<ControlPart> getParts() {
684      return parts;
685    }
686
687    @NonNull
688    private Set<TargetType> getTargetItemTypes() {
689      return targetItemTypes;
690    }
691
692    private boolean isMatchingType(@NonNull TargetType type) {
693      return getTargetItemTypes().contains(type);
694    }
695
696    private <T> boolean isSequenceTargeted(T targetItem) {
697      TargetType objectType = TargetType.forClass(targetItem.getClass());
698      return (Position.BEFORE.equals(position) || Position.AFTER.equals(position))
699          && (TargetType.PARAM.equals(objectType) && isMatchingType(TargetType.PARAM)
700              || TargetType.PART.equals(objectType) && isMatchingType(TargetType.PART));
701    }
702
703    /**
704     * Determine if the provided {@code obj} is the target of the add.
705     *
706     * @param obj
707     *          the current object
708     * @return {@code true} if the current object applies or {@code false} otherwise
709     */
710    private boolean appliesTo(@NonNull Object obj) {
711      TargetType objectType = TargetType.forClass(obj.getClass());
712
713      boolean retval = objectType != null && isMatchingType(objectType);
714      if (retval) {
715        assert objectType != null;
716
717        // check other criteria
718        String actualId = null;
719        switch (objectType) {
720        case CONTROL: {
721          Control control = (Control) obj;
722          actualId = control.getId();
723          break;
724        }
725        case PARAM: {
726          Parameter param = (Parameter) obj;
727          actualId = param.getId();
728          break;
729        }
730        case PART: {
731          ControlPart part = (ControlPart) obj;
732          String partId = part.getId();
733          if (part.getId() != null) {
734            actualId = partId;
735          }
736          break;
737        }
738        default:
739          throw new UnsupportedOperationException(objectType.fieldName());
740        }
741
742        String byId = getById();
743        if (getById() == null && TargetType.CONTROL.equals(objectType)) {
744          retval = getControl().equals(obj);
745        } else {
746          retval = byId != null && byId.equals(actualId);
747        }
748      }
749      return retval;
750    }
751  }
752}